Add Add->Clamp->Div->Mul to HSwish fusion (#4027)
* Add Add->Clamp->Div->Mul to HSwish fusion * use opset4 NS instead of op::v0, don't copy constants RT info * use opset4 in tests
This commit is contained in:
parent
c1b0b03750
commit
ecb6d8604e
@ -17,7 +17,8 @@ class TRANSFORMATIONS_API HSwishFusion;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithoutRelu;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithClamp;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithClampMul;
|
||||
class TRANSFORMATIONS_API HSwishFusionWithClampDiv;
|
||||
|
||||
|
||||
} // namespace pass
|
||||
@ -57,10 +58,20 @@ public:
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSwishFusion transformation replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op.
|
||||
*/
|
||||
class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass {
|
||||
class ngraph::pass::HSwishFusionWithClampMul: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSwishFusionWithClamp();
|
||||
HSwishFusionWithClampMul();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief HSwishFusion transformation replaces a sub-graph x * (Clamp(x + 3, 0, 6) / 6) with a HSwish op.
|
||||
*/
|
||||
class ngraph::pass::HSwishFusionWithClampDiv: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
HSwishFusionWithClampDiv();
|
||||
};
|
||||
|
||||
/**
|
||||
@ -74,6 +85,7 @@ public:
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithClampMul>();
|
||||
add_matcher<ngraph::pass::HSwishFusionWithClampDiv>();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
@ -174,10 +174,10 @@ ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClamp, "HSwishFusionWithClamp", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClampMul, "HSwishFusionWithClampMul", 0);
|
||||
|
||||
ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
|
||||
MATCHER_SCOPE(HSwishFusionWithClamp);
|
||||
ngraph::pass::HSwishFusionWithClampMul::HSwishFusionWithClampMul() {
|
||||
MATCHER_SCOPE(HSwishFusionWithClampMul);
|
||||
// Replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op.
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
@ -219,3 +219,46 @@ ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClampDiv, "HSwishFusionWithClampDiv", 0);
|
||||
|
||||
ngraph::pass::HSwishFusionWithClampDiv::HSwishFusionWithClampDiv() {
|
||||
MATCHER_SCOPE(HSwishFusionWithClampDiv);
|
||||
// Replaces a sub-graph x * (Clamp(x + 3, 0, 6) / 6) with a HSwish op.
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
|
||||
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.0f, 6.0f);
|
||||
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||
auto div = std::make_shared<ngraph::opset4::Divide>(clamp, div_constant);
|
||||
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||
auto &pattern_to_output = m.get_pattern_value_map();
|
||||
auto x_output = pattern_to_output.at(input);
|
||||
|
||||
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
|
||||
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
|
||||
|
||||
bool valid_constant_values = op::util::has_constant_value(add_const_value, 3.0)
|
||||
&& op::util::has_constant_value(div_const_value, 6.0);
|
||||
if (!valid_constant_values) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);
|
||||
|
||||
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
|
||||
pattern_to_output.at(clamp).get_node_shared_ptr(),
|
||||
pattern_to_output.at(div).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr()
|
||||
},
|
||||
hswish);
|
||||
ngraph::replace_node(m.get_match_root(), hswish);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -151,13 +151,13 @@ TEST(TransformationTests, HSwishFusionWithoutRelu) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, HSwishFusionWithClamp) {
|
||||
TEST(TransformationTests, HSwishFusionWithClampMul) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
|
||||
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.0f, 6.0f);
|
||||
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.0f, 6.0f);
|
||||
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0 / 6.0});
|
||||
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
|
||||
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
|
||||
@ -166,7 +166,38 @@ TEST(TransformationTests, HSwishFusionWithClamp) {
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::HSwishFusionWithClamp>();
|
||||
manager.register_pass<ngraph::pass::HSwishFusionWithClampMul>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||
auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, HSwishFusionWithClampDiv) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
|
||||
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.0f, 6.0f);
|
||||
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
|
||||
auto div = std::make_shared<ngraph::opset4::Divide>(clamp, div_constant);
|
||||
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::HSwishFusionWithClampDiv>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
@ -310,7 +341,7 @@ TEST(TransformationTests, HSwishFusionWithClampWrongConstValue) {
|
||||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
|
||||
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
|
||||
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.11f, 6.02f);
|
||||
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
|
||||
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
|
||||
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
|
||||
@ -328,7 +359,7 @@ TEST(TransformationTests, HSwishFusionWithClampWrongConstValue) {
|
||||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
|
||||
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
|
||||
auto clamp = std::make_shared<ngraph::opset4::Clamp>(add, 0.11f, 6.02f);
|
||||
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
|
||||
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
|
||||
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(input, mul_first);
|
||||
|
Loading…
Reference in New Issue
Block a user