Disallow LeakyReluFusion when alpha is greater than one (#19446)

Tickets: CVS-118898, CVS-82454
This commit is contained in:
Mateusz Tabaka 2023-08-31 12:41:46 +02:00 committed by GitHub
parent 463ae19207
commit 120a81ff5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 8 deletions

View File

@ -18,7 +18,7 @@
ov::pass::LeakyReluFusion::LeakyReluFusion() {
MATCHER_SCOPE(LeakyReluFusion);
auto data_pattern = pass::pattern::any_input();
auto alpha_pattern = pass::pattern::any_input(pattern::has_static_shape());
auto alpha_pattern = pass::pattern::wrap_type<op::v0::Constant>();
auto multiply_pattern =
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({data_pattern, alpha_pattern}, pattern::consumers_count(1));
auto max_pattern = ov::pass::pattern::wrap_type<ov::op::v1::Maximum>({data_pattern, multiply_pattern});
@ -30,6 +30,17 @@ ov::pass::LeakyReluFusion::LeakyReluFusion() {
if (shape_size(original_alpha_pattern.get_shape()) != 1)
return false;
auto constant = ov::as_type_ptr<op::v0::Constant>(original_alpha_pattern.get_node_shared_ptr());
if (!constant)
return false;
float value;
if (!op::util::get_single_value(constant, value))
return false;
if (value > 1.0f)
return false;
auto leaky_relu = register_new_node<ov::op::v0::PRelu>(pattern_map.at(data_pattern), original_alpha_pattern);
auto maximum = pattern_map.at(max_pattern);
leaky_relu->set_friendly_name(maximum.get_node()->get_friendly_name());

View File

@ -40,6 +40,45 @@ TEST_F(TransformationTestsF, LeakyReluFusionConstant) {
}
}
TEST_F(TransformationTestsF, LeakyReluFusionConstantGreaterThanOne) {
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {1.1});
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
auto max = std::make_shared<opset8::Maximum>(data, multiply);
model = std::make_shared<Model>(NodeVector{max}, ParameterVector{data});
manager.register_pass<ov::pass::LeakyReluFusion>();
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, LeakyReluFusionConstantAlphaOnFirstInput) {
{
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
auto multiply = std::make_shared<opset8::Multiply>(alpha, data);
auto max = std::make_shared<opset8::Maximum>(multiply, data);
model = std::make_shared<Model>(NodeVector{max}, ParameterVector{data});
manager.register_pass<ov::pass::LeakyReluFusion>();
}
{
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
model_ref = std::make_shared<Model>(NodeVector{leaky_relu}, ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, LeakyReluFusionScalar) {
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
@ -69,11 +108,4 @@ TEST_F(TransformationTestsF, LeakyReluFusionParameter) {
manager.register_pass<ov::pass::LeakyReluFusion>();
}
{
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
model_ref = std::make_shared<Model>(NodeVector{leaky_relu}, ParameterVector{data, alpha});
}
}