From a5250fd0fcee4f129d459d0b26ae75829f96bbfe Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 29 Sep 2021 00:22:53 +0800 Subject: [PATCH] [Transformation]hswish_fusion with clamp mul (#7414) * [Transformation]hswish_fusion with clamp * add const for constant variable * [Transformation]fix review comments * [Transformation]fix opset version --- .../common_optimizations/hswish_fusion.hpp | 13 +++ .../common_optimizations/hswish_fusion.cpp | 97 +++++++++++++------ .../transformations/hswish_fusion_test.cpp | 66 +++++++++++++ 3 files changed, 148 insertions(+), 28 deletions(-) diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/hswish_fusion.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/hswish_fusion.hpp index 94d76493be1..5adef99494a 100644 --- a/inference-engine/src/transformations/include/transformations/common_optimizations/hswish_fusion.hpp +++ b/inference-engine/src/transformations/include/transformations/common_optimizations/hswish_fusion.hpp @@ -17,6 +17,7 @@ class TRANSFORMATIONS_API HSwishFusion; class TRANSFORMATIONS_API HSwishFusionWithReluDiv; class TRANSFORMATIONS_API HSwishFusionWithReluMul; class TRANSFORMATIONS_API HSwishFusionWithHSigmoid; +class TRANSFORMATIONS_API HSwishFusionWithClamp; } // namespace pass } // namespace ngraph @@ -52,6 +53,17 @@ public: HSwishFusionWithHSigmoid(); }; + +/** + * @ingroup ie_transformation_common_api + * @brief HSwishFusion transformation replaces a sub-graph (Clamp(x + 3, 0, 6) * x) with a HSwish * 6. + */ +class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + HSwishFusionWithClamp(); +}; + /** * @ingroup ie_transformation_common_api * @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op. @@ -63,5 +75,6 @@ public: add_matcher(); add_matcher(); add_matcher(); + add_matcher(); } }; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/hswish_fusion.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/hswish_fusion.cpp index 038dbd2ae97..a0134a8c5c0 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/hswish_fusion.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/hswish_fusion.cpp @@ -8,7 +8,7 @@ #include -#include +#include #include #include @@ -20,22 +20,22 @@ ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() { MATCHER_SCOPE(HSwishFusionWithReluDiv); // Replaces a sub-graph (x * (min(Relu(x + 3), 6)) / 6 with a HSwish op. auto input = ngraph::pattern::any_input(); - auto add_constant = ngraph::pattern::wrap_type(); - auto add = std::make_shared(input, add_constant); - auto relu = std::make_shared(add); - auto min_constant = ngraph::pattern::wrap_type(); - auto min = std::make_shared(relu, min_constant); - auto mul = std::make_shared(input, min); - auto div_constant = ngraph::pattern::wrap_type(); - auto div = std::make_shared(mul, div_constant); + auto add_constant = ngraph::pattern::wrap_type(); + auto add = std::make_shared(input, add_constant); + auto relu = std::make_shared(add); + auto min_constant = ngraph::pattern::wrap_type(); + auto min = std::make_shared(relu, min_constant); + auto mul = std::make_shared(input, min); + auto div_constant = ngraph::pattern::wrap_type(); + auto div = std::make_shared(mul, div_constant); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { auto &pattern_to_output = m.get_pattern_value_map(); auto x_output = pattern_to_output.at(input); - auto add_const_value = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); - auto min_const_value = std::dynamic_pointer_cast(pattern_to_output.at(min_constant).get_node_shared_ptr()); - auto div_const_value = std::dynamic_pointer_cast(pattern_to_output.at(div_constant).get_node_shared_ptr()); + auto add_const_value = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); + auto min_const_value = std::dynamic_pointer_cast(pattern_to_output.at(min_constant).get_node_shared_ptr()); + auto div_const_value = std::dynamic_pointer_cast(pattern_to_output.at(div_constant).get_node_shared_ptr()); bool valid_constant_values = op::util::has_constant_value(add_const_value, 3.0) && op::util::has_constant_value(min_const_value, 6.0) @@ -45,7 +45,7 @@ ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() { return false; } - auto hswish = std::make_shared(x_output); + auto hswish = std::make_shared(x_output); hswish->set_friendly_name(m.get_match_root()->get_friendly_name()); ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(), @@ -72,22 +72,22 @@ ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() { MATCHER_SCOPE(HSwishFusionWithReluMul); // Replaces a sub-graph (x * (min(Relu(x + 3), 6)) * const(1/6) with a HSwish op. auto input = ngraph::pattern::any_input(); - auto add_constant = ngraph::pattern::wrap_type(); - auto add = std::make_shared(input, add_constant); - auto relu = std::make_shared(add); - auto min_constant = ngraph::pattern::wrap_type(); - auto min = std::make_shared(relu, min_constant); - auto mul_first = std::make_shared(input, min); - auto mul_constant = ngraph::pattern::wrap_type(); - auto mul_second = std::make_shared(mul_first, mul_constant); + auto add_constant = ngraph::pattern::wrap_type(); + auto add = std::make_shared(input, add_constant); + auto relu = std::make_shared(add); + auto min_constant = ngraph::pattern::wrap_type(); + auto min = std::make_shared(relu, min_constant); + auto mul_first = std::make_shared(input, min); + auto mul_constant = ngraph::pattern::wrap_type(); + auto mul_second = std::make_shared(mul_first, mul_constant); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { auto &pattern_to_output = m.get_pattern_value_map(); auto x_output = pattern_to_output.at(input); - auto add_const_value = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); - auto min_const_value = std::dynamic_pointer_cast(pattern_to_output.at(min_constant).get_node_shared_ptr()); - auto mul_const_value = std::dynamic_pointer_cast(pattern_to_output.at(mul_constant).get_node_shared_ptr()); + auto add_const_value = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); + auto min_const_value = std::dynamic_pointer_cast(pattern_to_output.at(min_constant).get_node_shared_ptr()); + auto mul_const_value = std::dynamic_pointer_cast(pattern_to_output.at(mul_constant).get_node_shared_ptr()); bool valid_constant_values = op::util::has_constant_value(add_const_value, 3.0f) && op::util::has_constant_value(min_const_value, 6.0f) @@ -97,7 +97,7 @@ ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() { return false; } - auto hswish = std::make_shared(x_output); + auto hswish = std::make_shared(x_output); hswish->set_friendly_name(m.get_match_root()->get_friendly_name()); ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(), @@ -124,15 +124,15 @@ ngraph::pass::HSwishFusionWithHSigmoid::HSwishFusionWithHSigmoid() { MATCHER_SCOPE(HSwishFusionWithHSigmoid); // Replaces a sub-graph x * HSigmoid(x) with a HSwish op. auto input = pattern::any_input(); - auto hsigmoid_pattern = pattern::wrap_type({input}, pattern::consumers_count(1)); - auto mul_pattern = pattern::wrap_type({input, hsigmoid_pattern}); + auto hsigmoid_pattern = pattern::wrap_type({input}, pattern::consumers_count(1)); + auto mul_pattern = pattern::wrap_type({input, hsigmoid_pattern}); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { const auto& pattern_to_output = m.get_pattern_value_map(); auto hsigmoid = pattern_to_output.at(hsigmoid_pattern).get_node_shared_ptr(); auto mul = pattern_to_output.at(mul_pattern).get_node_shared_ptr(); - auto hswish = std::make_shared(pattern_to_output.at(input)); + auto hswish = std::make_shared(pattern_to_output.at(input)); hswish->set_friendly_name(mul->get_friendly_name()); ngraph::copy_runtime_info({hsigmoid, mul}, hswish); ngraph::replace_node(mul, hswish); @@ -142,3 +142,44 @@ ngraph::pass::HSwishFusionWithHSigmoid::HSwishFusionWithHSigmoid() { auto m = std::make_shared(mul_pattern, matcher_name); register_matcher(m, callback); } + +NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClamp, "HSwishFusionWithClamp", 0); + +ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() { + MATCHER_SCOPE(HSwishFusionWithClampMul); + // Replaces a sub-graph (Clamp(x + 3, 0, 6) * x) with a HSwish * 6. + const auto input = ngraph::pattern::any_input(); + const auto add_constant = ngraph::pattern::wrap_type(); + const auto add = ngraph::pattern::wrap_type({input, add_constant}); + const auto clamp = ngraph::pattern::wrap_type({add}); + const auto mul = ngraph::pattern::wrap_type({clamp, input}); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { + const auto &pattern_to_output = m.get_pattern_value_map(); + const auto x_output = pattern_to_output.at(input); + const auto add_const_value = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); + if (!op::util::has_constant_value(add_const_value, 3.0)) { + return false; + } + + const auto clamp_node = std::dynamic_pointer_cast(pattern_to_output.at(clamp).get_node_shared_ptr()); + if (!clamp_node || clamp_node->get_min() != 0 || clamp_node->get_max() != 6) + return false; + + auto hswish = std::make_shared(x_output); + auto new_mul_const = std::make_shared(add_const_value->get_element_type(), Shape{}, std::vector{6.0}); + auto new_mul = std::make_shared(hswish, new_mul_const); + + new_mul->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(clamp).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr() + }, + {hswish, new_mul_const, new_mul}); + ngraph::replace_node(m.get_match_root(), new_mul); + return true; + }; + + auto m = std::make_shared(mul, matcher_name); + register_matcher(m, callback); +} \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp index 7f3954656f3..d6ec593f48a 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp @@ -408,3 +408,69 @@ TEST(TransformationTests, HSwishFusionWithHSigmoidMul) { auto res = compare_functions(f, f_ref); ASSERT_TRUE(res.first) << res.second; } + + +TEST(TransformationTests, HSwishFusionWithClamp) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0}); + auto add = std::make_shared(input, add_constant); + auto clamp = std::make_shared(add, 0.0f, 6.0f); + auto mul = std::make_shared(input, clamp); + + f = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + auto gr = manager.register_pass(); + gr->add_matcher(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto hswish = std::make_shared(input); + auto mul_const = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0}); + auto mul = std::make_shared(hswish, mul_const); + + f_ref = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, HSwishFusionWithClampWithWrongConstant) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11}); + auto add = std::make_shared(input, add_constant); + auto clamp = std::make_shared(add, 0.11f, 6.32f); + auto mul = std::make_shared(input, clamp); + + f = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + auto gr = manager.register_pass(); + gr->add_matcher(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11}); + auto add = std::make_shared(input, add_constant); + auto clamp = std::make_shared(add, 0.11f, 6.32f); + auto mul = std::make_shared(input, clamp); + + f_ref = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} \ No newline at end of file