Disallow LeakyReluFusion when alpha is greater than one (#19446)
Tickets: CVS-118898, CVS-82454
This commit is contained in:
parent
463ae19207
commit
120a81ff5e
@ -18,7 +18,7 @@
|
||||
ov::pass::LeakyReluFusion::LeakyReluFusion() {
|
||||
MATCHER_SCOPE(LeakyReluFusion);
|
||||
auto data_pattern = pass::pattern::any_input();
|
||||
auto alpha_pattern = pass::pattern::any_input(pattern::has_static_shape());
|
||||
auto alpha_pattern = pass::pattern::wrap_type<op::v0::Constant>();
|
||||
auto multiply_pattern =
|
||||
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({data_pattern, alpha_pattern}, pattern::consumers_count(1));
|
||||
auto max_pattern = ov::pass::pattern::wrap_type<ov::op::v1::Maximum>({data_pattern, multiply_pattern});
|
||||
@ -30,6 +30,17 @@ ov::pass::LeakyReluFusion::LeakyReluFusion() {
|
||||
if (shape_size(original_alpha_pattern.get_shape()) != 1)
|
||||
return false;
|
||||
|
||||
auto constant = ov::as_type_ptr<op::v0::Constant>(original_alpha_pattern.get_node_shared_ptr());
|
||||
if (!constant)
|
||||
return false;
|
||||
|
||||
float value;
|
||||
if (!op::util::get_single_value(constant, value))
|
||||
return false;
|
||||
|
||||
if (value > 1.0f)
|
||||
return false;
|
||||
|
||||
auto leaky_relu = register_new_node<ov::op::v0::PRelu>(pattern_map.at(data_pattern), original_alpha_pattern);
|
||||
auto maximum = pattern_map.at(max_pattern);
|
||||
leaky_relu->set_friendly_name(maximum.get_node()->get_friendly_name());
|
||||
|
@ -40,6 +40,45 @@ TEST_F(TransformationTestsF, LeakyReluFusionConstant) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, LeakyReluFusionConstantGreaterThanOne) {
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
|
||||
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {1.1});
|
||||
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
|
||||
auto max = std::make_shared<opset8::Maximum>(data, multiply);
|
||||
model = std::make_shared<Model>(NodeVector{max}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::LeakyReluFusion>();
|
||||
}
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, LeakyReluFusionConstantAlphaOnFirstInput) {
|
||||
{
|
||||
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
|
||||
auto multiply = std::make_shared<opset8::Multiply>(alpha, data);
|
||||
auto max = std::make_shared<opset8::Maximum>(multiply, data);
|
||||
model = std::make_shared<Model>(NodeVector{max}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::LeakyReluFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
|
||||
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
|
||||
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
|
||||
model_ref = std::make_shared<Model>(NodeVector{leaky_relu}, ParameterVector{data});
|
||||
}
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, LeakyReluFusionScalar) {
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
|
||||
@ -69,11 +108,4 @@ TEST_F(TransformationTestsF, LeakyReluFusionParameter) {
|
||||
|
||||
manager.register_pass<ov::pass::LeakyReluFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
|
||||
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
|
||||
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
|
||||
model_ref = std::make_shared<Model>(NodeVector{leaky_relu}, ParameterVector{data, alpha});
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user