From f7e898893d9be15c2a68f5d945fa3c6294c70993 Mon Sep 17 00:00:00 2001 From: Tomasz Jankowski Date: Wed, 29 Mar 2023 13:32:57 +0200 Subject: [PATCH] Add PRelu fusion (#16617) --- .../common_optimizations/prelu_fusion.hpp | 25 +++++++++++-- .../common_optimizations/prelu_fusion.cpp | 35 +++++++++++++++++++ .../transformations/tests/prelu_fusion.cpp | 25 +++++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/src/common/transformations/include/transformations/common_optimizations/prelu_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/prelu_fusion.hpp index 7e4939a2c5b..e9b090381ec 100644 --- a/src/common/transformations/include/transformations/common_optimizations/prelu_fusion.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/prelu_fusion.hpp @@ -18,6 +18,7 @@ class TRANSFORMATIONS_API PReluFusionNegativeSub; class TRANSFORMATIONS_API PReluFusionMultiplyAdd; class TRANSFORMATIONS_API PReluFusionMultiplySub; class TRANSFORMATIONS_API PReluFusionAbsSubMulMulAdd; +class TRANSFORMATIONS_API PReluFusionNegReluMulAdd; } // namespace pass } // namespace ov @@ -103,11 +104,11 @@ public: /** * @ingroup ie_transformation_common_api * @brief PReluFusionAbsSubMulMulAdd transformation replaces a sub-graph - * Op + * Op * / | \ * Relu | Abs * | \ | - * | Sub + * | Subtract * | | * | Multiply * | | @@ -121,6 +122,25 @@ public: PReluFusionAbsSubMulMulAdd(); }; +/** + * @ingroup ie_transformation_common_api + * @brief PReluFusionNegReluMulAdd transformation replaces a sub-graph + * Op + * / \ + * Relu Negative + * | | + * | Relu + * | | + * | Multiply + * \ / + * Add + */ +class ov::pass::PReluFusionNegReluMulAdd : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("PReluFusionNegReluMulAdd", "0"); + PReluFusionNegReluMulAdd(); +}; + /** * @ingroup ie_transformation_common_api * @brief PReluFusion transformation replaces various sub-graphs with a PRelu op. @@ -134,5 +154,6 @@ public: add_matcher(); add_matcher(); add_matcher(); + add_matcher(); } }; diff --git a/src/common/transformations/src/transformations/common_optimizations/prelu_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/prelu_fusion.cpp index 98474c2ad9a..4a2ed9729ec 100644 --- a/src/common/transformations/src/transformations/common_optimizations/prelu_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/prelu_fusion.cpp @@ -201,3 +201,38 @@ ov::pass::PReluFusionAbsSubMulMulAdd::PReluFusionAbsSubMulMulAdd() { auto m = make_shared(add, matcher_name); register_matcher(m, callback); } + +ov::pass::PReluFusionNegReluMulAdd::PReluFusionNegReluMulAdd() { + MATCHER_SCOPE(PReluFusionNegReluMulAdd); + + using namespace std; + using namespace ov; + using namespace ov::opset10; + + const auto input = pass::pattern::any_input(); + const auto relu_pos = pattern::wrap_type({input}); + const auto neg1 = pattern::wrap_type({input}); + const auto relu_neg = pattern::wrap_type({neg1}); + const auto mul_constant = pattern::wrap_type(); + const auto mul = pattern::wrap_type({relu_neg, mul_constant}); + const auto add = pattern::wrap_type({relu_pos, mul}); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + const auto input_output = pattern_to_output.at(input); + const auto add_node = pattern_to_output.at(add).get_node_shared_ptr(); + const auto slope = op::util::make_try_fold(pattern_to_output.at(mul_constant)); + const auto prelu = make_shared(input_output, slope); + prelu->set_friendly_name(m.get_match_root()->get_friendly_name()); + NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(), + pattern_to_output.at(neg1).get_node_shared_ptr(), + pattern_to_output.at(relu_neg).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr()}; + copy_runtime_info(copy_from, prelu); + replace_node(add_node, prelu); + return true; + }; + auto matcher = make_shared(add, matcher_name); + register_matcher(matcher, callback); +} diff --git a/src/common/transformations/tests/prelu_fusion.cpp b/src/common/transformations/tests/prelu_fusion.cpp index 3c2df005f3d..ba878750cdd 100644 --- a/src/common/transformations/tests/prelu_fusion.cpp +++ b/src/common/transformations/tests/prelu_fusion.cpp @@ -156,4 +156,29 @@ TEST_F(TransformationTestsF, PReluFusionAbsSubMulMulAdd) { const auto prelu = make_shared(data, prelu_const); function_ref = make_shared(NodeVector{prelu}, ParameterVector{data}); } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); +} + +TEST_F(TransformationTestsF, PReluFusionNegReluMulAdd) { + using namespace std; + using namespace ov::opset10; + { + const auto data = make_shared(element::f32, Shape{2, 12}); + const auto relu_pos = make_shared(data); + const auto neg = make_shared(data); + const auto relu_neg = make_shared(neg); + const auto mul_const = Constant::create(element::f32, Shape{1}, {0.235}); + const auto mul = make_shared(relu_neg, mul_const); + const auto add = make_shared(relu_pos, mul); + function = make_shared(NodeVector{add}, ParameterVector{data}); + + manager.register_pass(); + } + { + const auto data = make_shared(element::f32, Shape{2, 12}); + const auto prelu_const = Constant::create(element::f32, Shape{1}, {-0.235}); + const auto prelu = make_shared(data, prelu_const); + function_ref = make_shared(NodeVector{prelu}, ParameterVector{data}); + } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); }