Add Add->Clamp->Div->Mul to HSwish fusion (#4027)

* Add Add->Clamp->Div->Mul to HSwish fusion

* use opset4 NS instead of op::v0, don't copy constants RT info

* use opset4 in tests
This commit is contained in:
Mateusz Tabaka 2021-02-02 09:12:21 +01:00 committed by GitHub
parent c1b0b03750
commit ecb6d8604e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 13 deletions

View File

@ -17,7 +17,8 @@ class TRANSFORMATIONS_API HSwishFusion;
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
class TRANSFORMATIONS_API HSwishFusionWithoutRelu;
class TRANSFORMATIONS_API HSwishFusionWithClamp;
class TRANSFORMATIONS_API HSwishFusionWithClampMul;
class TRANSFORMATIONS_API HSwishFusionWithClampDiv;
} // namespace pass
@ -57,10 +58,20 @@ public:
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op.
*/
class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass {
class ngraph::pass::HSwishFusionWithClampMul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSwishFusionWithClamp();
HSwishFusionWithClampMul();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces a sub-graph x * (Clamp(x + 3, 0, 6) / 6) with a HSwish op.
*/
class ngraph::pass::HSwishFusionWithClampDiv: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSwishFusionWithClampDiv();
};
/**
@ -74,6 +85,7 @@ public:
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
add_matcher<ngraph::pass::HSwishFusionWithClampMul>();
add_matcher<ngraph::pass::HSwishFusionWithClampDiv>();
}
};
};

View File

@ -174,10 +174,10 @@ ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClamp, "HSwishFusionWithClamp", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClampMul, "HSwishFusionWithClampMul", 0);
ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
MATCHER_SCOPE(HSwishFusionWithClamp);
ngraph::pass::HSwishFusionWithClampMul::HSwishFusionWithClampMul() {
MATCHER_SCOPE(HSwishFusionWithClampMul);
// Replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
@ -219,3 +219,46 @@ ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, matcher_name);
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClampDiv, "HSwishFusionWithClampDiv", 0);
ngraph::pass::HSwishFusionWithClampDiv::HSwishFusionWithClampDiv() {
MATCHER_SCOPE(HSwishFusionWithClampDiv);
// Replaces a sub-graph x * (Clamp(x + 3, 0, 6) / 6) with a HSwish op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.0f, 6.0f);
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto div = std::make_shared<ngraph::opset4::Divide>(clamp, div_constant);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
bool valid_constant_values = op::util::has_constant_value(add_const_value, 3.0)
&& op::util::has_constant_value(div_const_value, 6.0);
if (!valid_constant_values) {
return false;
}
auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(clamp).get_node_shared_ptr(),
pattern_to_output.at(div).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr()
},
hswish);
ngraph::replace_node(m.get_match_root(), hswish);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
register_matcher(m, callback);
}

View File

@ -151,13 +151,13 @@ TEST(TransformationTests, HSwishFusionWithoutRelu) {
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSwishFusionWithClamp) {
TEST(TransformationTests, HSwishFusionWithClampMul) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.0f, 6.0f);
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.0f, 6.0f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0 / 6.0});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
@ -166,7 +166,38 @@ TEST(TransformationTests, HSwishFusionWithClamp) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSwishFusionWithClamp>();
manager.register_pass<ngraph::pass::HSwishFusionWithClampMul>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSwishFusionWithClampDiv) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.0f, 6.0f);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto div = std::make_shared<ngraph::opset4::Divide>(clamp, div_constant);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSwishFusionWithClampDiv>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -310,7 +341,7 @@ TEST(TransformationTests, HSwishFusionWithClampWrongConstValue) {
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.11f, 6.02f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
@ -328,7 +359,7 @@ TEST(TransformationTests, HSwishFusionWithClampWrongConstValue) {
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.11f, 6.02f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);