[Transformation]hswish_fusion with clamp mul (#7414)

* [Transformation]hswish_fusion with clamp

* add const for constant variable

* [Transformation]fix review comments

* [Transformation]fix opset version
This commit is contained in:
Zhang Yi 2021-09-29 00:22:53 +08:00 committed by GitHub
parent 65dcffe913
commit a5250fd0fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 148 additions and 28 deletions

View File

@ -17,6 +17,7 @@ class TRANSFORMATIONS_API HSwishFusion;
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
class TRANSFORMATIONS_API HSwishFusionWithHSigmoid;
class TRANSFORMATIONS_API HSwishFusionWithClamp;
} // namespace pass
} // namespace ngraph
@ -52,6 +53,17 @@ public:
HSwishFusionWithHSigmoid();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces a sub-graph (Clamp(x + 3, 0, 6) * x) with a HSwish * 6.
*/
class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSwishFusionWithClamp();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
@ -63,5 +75,6 @@ public:
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
add_matcher<ngraph::pass::HSwishFusionWithHSigmoid>();
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
}
};

View File

@ -8,7 +8,7 @@
#include <memory>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
@ -20,22 +20,22 @@ ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
MATCHER_SCOPE(HSwishFusionWithReluDiv);
// Replaces a sub-graph (x * (min(Relu(x + 3), 6)) / 6 with a HSwish op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset7::Relu>(add);
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto min = std::make_shared<ngraph::opset7::Minimum>(relu, min_constant);
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, min);
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto div = std::make_shared<ngraph::opset7::Divide>(mul, div_constant);
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = std::make_shared<ngraph::opset8::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset8::Relu>(add);
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto min = std::make_shared<ngraph::opset8::Minimum>(relu, min_constant);
auto mul = std::make_shared<ngraph::opset8::Multiply>(input, min);
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto div = std::make_shared<ngraph::opset8::Divide>(mul, div_constant);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0)
&& op::util::has_constant_value<float>(min_const_value, 6.0)
@ -45,7 +45,7 @@ ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
return false;
}
auto hswish = std::make_shared<ngraph::opset7::HSwish>(x_output);
auto hswish = std::make_shared<ngraph::opset8::HSwish>(x_output);
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
@ -72,22 +72,22 @@ ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
MATCHER_SCOPE(HSwishFusionWithReluMul);
// Replaces a sub-graph (x * (min(Relu(x + 3), 6)) * const(1/6) with a HSwish op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset7::Relu>(add);
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto min = std::make_shared<ngraph::opset7::Minimum>(relu, min_constant);
auto mul_first = std::make_shared<ngraph::opset7::Multiply>(input, min);
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
auto mul_second = std::make_shared<ngraph::opset7::Multiply>(mul_first, mul_constant);
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto add = std::make_shared<ngraph::opset8::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset8::Relu>(add);
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto min = std::make_shared<ngraph::opset8::Minimum>(relu, min_constant);
auto mul_first = std::make_shared<ngraph::opset8::Multiply>(input, min);
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto mul_second = std::make_shared<ngraph::opset8::Multiply>(mul_first, mul_constant);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0f)
&& op::util::has_constant_value<float>(min_const_value, 6.0f)
@ -97,7 +97,7 @@ ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
return false;
}
auto hswish = std::make_shared<ngraph::opset7::HSwish>(x_output);
auto hswish = std::make_shared<ngraph::opset8::HSwish>(x_output);
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
@ -124,15 +124,15 @@ ngraph::pass::HSwishFusionWithHSigmoid::HSwishFusionWithHSigmoid() {
MATCHER_SCOPE(HSwishFusionWithHSigmoid);
// Replaces a sub-graph x * HSigmoid(x) with a HSwish op.
auto input = pattern::any_input();
auto hsigmoid_pattern = pattern::wrap_type<ngraph::opset7::HSigmoid>({input}, pattern::consumers_count(1));
auto mul_pattern = pattern::wrap_type<ngraph::opset7::Multiply>({input, hsigmoid_pattern});
auto hsigmoid_pattern = pattern::wrap_type<ngraph::opset8::HSigmoid>({input}, pattern::consumers_count(1));
auto mul_pattern = pattern::wrap_type<ngraph::opset8::Multiply>({input, hsigmoid_pattern});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto hsigmoid = pattern_to_output.at(hsigmoid_pattern).get_node_shared_ptr();
auto mul = pattern_to_output.at(mul_pattern).get_node_shared_ptr();
auto hswish = std::make_shared<ngraph::opset7::HSwish>(pattern_to_output.at(input));
auto hswish = std::make_shared<ngraph::opset8::HSwish>(pattern_to_output.at(input));
hswish->set_friendly_name(mul->get_friendly_name());
ngraph::copy_runtime_info({hsigmoid, mul}, hswish);
ngraph::replace_node(mul, hswish);
@ -142,3 +142,44 @@ ngraph::pass::HSwishFusionWithHSigmoid::HSwishFusionWithHSigmoid() {
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_pattern, matcher_name);
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClamp, "HSwishFusionWithClamp", 0);
ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
MATCHER_SCOPE(HSwishFusionWithClampMul);
// Replaces a sub-graph (Clamp(x + 3, 0, 6) * x) with a HSwish * 6.
const auto input = ngraph::pattern::any_input();
const auto add_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
const auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({input, add_constant});
const auto clamp = ngraph::pattern::wrap_type<ngraph::opset8::Clamp>({add});
const auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({clamp, input});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
const auto &pattern_to_output = m.get_pattern_value_map();
const auto x_output = pattern_to_output.at(input);
const auto add_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
if (!op::util::has_constant_value(add_const_value, 3.0)) {
return false;
}
const auto clamp_node = std::dynamic_pointer_cast<ngraph::opset8::Clamp>(pattern_to_output.at(clamp).get_node_shared_ptr());
if (!clamp_node || clamp_node->get_min() != 0 || clamp_node->get_max() != 6)
return false;
auto hswish = std::make_shared<ngraph::opset8::HSwish>(x_output);
auto new_mul_const = std::make_shared<ngraph::opset8::Constant>(add_const_value->get_element_type(), Shape{}, std::vector<float>{6.0});
auto new_mul = std::make_shared<ngraph::opset8::Multiply>(hswish, new_mul_const);
new_mul->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(clamp).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr()
},
{hswish, new_mul_const, new_mul});
ngraph::replace_node(m.get_match_root(), new_mul);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
register_matcher(m, callback);
}

View File

@ -408,3 +408,69 @@ TEST(TransformationTests, HSwishFusionWithHSigmoidMul) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSwishFusionWithClamp) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::opset7::Clamp>(add, 0.0f, 6.0f);
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, clamp);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
auto gr = manager.register_pass<ngraph::pass::GraphRewrite>();
gr->add_matcher<ngraph::pass::HSwishFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto hswish = std::make_shared<ngraph::opset7::HSwish>(input);
auto mul_const = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto mul = std::make_shared<ngraph::opset7::Multiply>(hswish, mul_const);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSwishFusionWithClampWithWrongConstant) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::opset7::Clamp>(add, 0.11f, 6.32f);
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, clamp);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
auto gr = manager.register_pass<ngraph::pass::GraphRewrite>();
gr->add_matcher<ngraph::pass::HSwishFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::opset7::Clamp>(add, 0.11f, 6.32f);
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, clamp);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}