[Transformation]hswish_fusion with clamp mul (#7414)
* [Transformation]hswish_fusion with clamp * add const for constant variable * [Transformation]fix review comments * [Transformation]fix opset version
This commit is contained in:
parent
65dcffe913
commit
a5250fd0fc
@ -17,6 +17,7 @@ class TRANSFORMATIONS_API HSwishFusion;
|
|||||||
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
|
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
|
||||||
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
|
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
|
||||||
class TRANSFORMATIONS_API HSwishFusionWithHSigmoid;
|
class TRANSFORMATIONS_API HSwishFusionWithHSigmoid;
|
||||||
|
class TRANSFORMATIONS_API HSwishFusionWithClamp;
|
||||||
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
@ -52,6 +53,17 @@ public:
|
|||||||
HSwishFusionWithHSigmoid();
|
HSwishFusionWithHSigmoid();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief HSwishFusion transformation replaces a sub-graph (Clamp(x + 3, 0, 6) * x) with a HSwish * 6.
|
||||||
|
*/
|
||||||
|
class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
HSwishFusionWithClamp();
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @ingroup ie_transformation_common_api
|
* @ingroup ie_transformation_common_api
|
||||||
* @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
|
* @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
|
||||||
@ -63,5 +75,6 @@ public:
|
|||||||
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
|
add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
|
||||||
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
|
add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
|
||||||
add_matcher<ngraph::pass::HSwishFusionWithHSigmoid>();
|
add_matcher<ngraph::pass::HSwishFusionWithHSigmoid>();
|
||||||
|
add_matcher<ngraph::pass::HSwishFusionWithClamp>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include <ngraph/opsets/opset7.hpp>
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
|
||||||
@ -20,22 +20,22 @@ ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
|
|||||||
MATCHER_SCOPE(HSwishFusionWithReluDiv);
|
MATCHER_SCOPE(HSwishFusionWithReluDiv);
|
||||||
// Replaces a sub-graph (x * (min(Relu(x + 3), 6)) / 6 with a HSwish op.
|
// Replaces a sub-graph (x * (min(Relu(x + 3), 6)) / 6 with a HSwish op.
|
||||||
auto input = ngraph::pattern::any_input();
|
auto input = ngraph::pattern::any_input();
|
||||||
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
|
auto add = std::make_shared<ngraph::opset8::Add>(input, add_constant);
|
||||||
auto relu = std::make_shared<ngraph::opset7::Relu>(add);
|
auto relu = std::make_shared<ngraph::opset8::Relu>(add);
|
||||||
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto min = std::make_shared<ngraph::opset7::Minimum>(relu, min_constant);
|
auto min = std::make_shared<ngraph::opset8::Minimum>(relu, min_constant);
|
||||||
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, min);
|
auto mul = std::make_shared<ngraph::opset8::Multiply>(input, min);
|
||||||
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto div = std::make_shared<ngraph::opset7::Divide>(mul, div_constant);
|
auto div = std::make_shared<ngraph::opset8::Divide>(mul, div_constant);
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||||
auto &pattern_to_output = m.get_pattern_value_map();
|
auto &pattern_to_output = m.get_pattern_value_map();
|
||||||
auto x_output = pattern_to_output.at(input);
|
auto x_output = pattern_to_output.at(input);
|
||||||
|
|
||||||
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
|
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
|
||||||
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
|
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
|
||||||
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
|
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
|
||||||
|
|
||||||
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0)
|
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0)
|
||||||
&& op::util::has_constant_value<float>(min_const_value, 6.0)
|
&& op::util::has_constant_value<float>(min_const_value, 6.0)
|
||||||
@ -45,7 +45,7 @@ ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto hswish = std::make_shared<ngraph::opset7::HSwish>(x_output);
|
auto hswish = std::make_shared<ngraph::opset8::HSwish>(x_output);
|
||||||
|
|
||||||
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
|
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||||
ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
|
ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
|
||||||
@ -72,22 +72,22 @@ ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
|
|||||||
MATCHER_SCOPE(HSwishFusionWithReluMul);
|
MATCHER_SCOPE(HSwishFusionWithReluMul);
|
||||||
// Replaces a sub-graph (x * (min(Relu(x + 3), 6)) * const(1/6) with a HSwish op.
|
// Replaces a sub-graph (x * (min(Relu(x + 3), 6)) * const(1/6) with a HSwish op.
|
||||||
auto input = ngraph::pattern::any_input();
|
auto input = ngraph::pattern::any_input();
|
||||||
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
|
auto add = std::make_shared<ngraph::opset8::Add>(input, add_constant);
|
||||||
auto relu = std::make_shared<ngraph::opset7::Relu>(add);
|
auto relu = std::make_shared<ngraph::opset8::Relu>(add);
|
||||||
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto min = std::make_shared<ngraph::opset7::Minimum>(relu, min_constant);
|
auto min = std::make_shared<ngraph::opset8::Minimum>(relu, min_constant);
|
||||||
auto mul_first = std::make_shared<ngraph::opset7::Multiply>(input, min);
|
auto mul_first = std::make_shared<ngraph::opset8::Multiply>(input, min);
|
||||||
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
auto mul_second = std::make_shared<ngraph::opset7::Multiply>(mul_first, mul_constant);
|
auto mul_second = std::make_shared<ngraph::opset8::Multiply>(mul_first, mul_constant);
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||||
auto &pattern_to_output = m.get_pattern_value_map();
|
auto &pattern_to_output = m.get_pattern_value_map();
|
||||||
auto x_output = pattern_to_output.at(input);
|
auto x_output = pattern_to_output.at(input);
|
||||||
|
|
||||||
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
|
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
|
||||||
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
|
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
|
||||||
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
|
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
|
||||||
|
|
||||||
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0f)
|
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0f)
|
||||||
&& op::util::has_constant_value<float>(min_const_value, 6.0f)
|
&& op::util::has_constant_value<float>(min_const_value, 6.0f)
|
||||||
@ -97,7 +97,7 @@ ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto hswish = std::make_shared<ngraph::opset7::HSwish>(x_output);
|
auto hswish = std::make_shared<ngraph::opset8::HSwish>(x_output);
|
||||||
|
|
||||||
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
|
hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||||
ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
|
ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
|
||||||
@ -124,15 +124,15 @@ ngraph::pass::HSwishFusionWithHSigmoid::HSwishFusionWithHSigmoid() {
|
|||||||
MATCHER_SCOPE(HSwishFusionWithHSigmoid);
|
MATCHER_SCOPE(HSwishFusionWithHSigmoid);
|
||||||
// Replaces a sub-graph x * HSigmoid(x) with a HSwish op.
|
// Replaces a sub-graph x * HSigmoid(x) with a HSwish op.
|
||||||
auto input = pattern::any_input();
|
auto input = pattern::any_input();
|
||||||
auto hsigmoid_pattern = pattern::wrap_type<ngraph::opset7::HSigmoid>({input}, pattern::consumers_count(1));
|
auto hsigmoid_pattern = pattern::wrap_type<ngraph::opset8::HSigmoid>({input}, pattern::consumers_count(1));
|
||||||
auto mul_pattern = pattern::wrap_type<ngraph::opset7::Multiply>({input, hsigmoid_pattern});
|
auto mul_pattern = pattern::wrap_type<ngraph::opset8::Multiply>({input, hsigmoid_pattern});
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
auto hsigmoid = pattern_to_output.at(hsigmoid_pattern).get_node_shared_ptr();
|
auto hsigmoid = pattern_to_output.at(hsigmoid_pattern).get_node_shared_ptr();
|
||||||
auto mul = pattern_to_output.at(mul_pattern).get_node_shared_ptr();
|
auto mul = pattern_to_output.at(mul_pattern).get_node_shared_ptr();
|
||||||
|
|
||||||
auto hswish = std::make_shared<ngraph::opset7::HSwish>(pattern_to_output.at(input));
|
auto hswish = std::make_shared<ngraph::opset8::HSwish>(pattern_to_output.at(input));
|
||||||
hswish->set_friendly_name(mul->get_friendly_name());
|
hswish->set_friendly_name(mul->get_friendly_name());
|
||||||
ngraph::copy_runtime_info({hsigmoid, mul}, hswish);
|
ngraph::copy_runtime_info({hsigmoid, mul}, hswish);
|
||||||
ngraph::replace_node(mul, hswish);
|
ngraph::replace_node(mul, hswish);
|
||||||
@ -142,3 +142,44 @@ ngraph::pass::HSwishFusionWithHSigmoid::HSwishFusionWithHSigmoid() {
|
|||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_pattern, matcher_name);
|
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_pattern, matcher_name);
|
||||||
register_matcher(m, callback);
|
register_matcher(m, callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSwishFusionWithClamp, "HSwishFusionWithClamp", 0);
|
||||||
|
|
||||||
|
ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
|
||||||
|
MATCHER_SCOPE(HSwishFusionWithClampMul);
|
||||||
|
// Replaces a sub-graph (Clamp(x + 3, 0, 6) * x) with a HSwish * 6.
|
||||||
|
const auto input = ngraph::pattern::any_input();
|
||||||
|
const auto add_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||||
|
const auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({input, add_constant});
|
||||||
|
const auto clamp = ngraph::pattern::wrap_type<ngraph::opset8::Clamp>({add});
|
||||||
|
const auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({clamp, input});
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
|
||||||
|
const auto &pattern_to_output = m.get_pattern_value_map();
|
||||||
|
const auto x_output = pattern_to_output.at(input);
|
||||||
|
const auto add_const_value = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
|
||||||
|
if (!op::util::has_constant_value(add_const_value, 3.0)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto clamp_node = std::dynamic_pointer_cast<ngraph::opset8::Clamp>(pattern_to_output.at(clamp).get_node_shared_ptr());
|
||||||
|
if (!clamp_node || clamp_node->get_min() != 0 || clamp_node->get_max() != 6)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto hswish = std::make_shared<ngraph::opset8::HSwish>(x_output);
|
||||||
|
auto new_mul_const = std::make_shared<ngraph::opset8::Constant>(add_const_value->get_element_type(), Shape{}, std::vector<float>{6.0});
|
||||||
|
auto new_mul = std::make_shared<ngraph::opset8::Multiply>(hswish, new_mul_const);
|
||||||
|
|
||||||
|
new_mul->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(clamp).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(mul).get_node_shared_ptr()
|
||||||
|
},
|
||||||
|
{hswish, new_mul_const, new_mul});
|
||||||
|
ngraph::replace_node(m.get_match_root(), new_mul);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
|
||||||
|
register_matcher(m, callback);
|
||||||
|
}
|
@ -408,3 +408,69 @@ TEST(TransformationTests, HSwishFusionWithHSigmoidMul) {
|
|||||||
auto res = compare_functions(f, f_ref);
|
auto res = compare_functions(f, f_ref);
|
||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST(TransformationTests, HSwishFusionWithClamp) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||||
|
auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
|
||||||
|
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
|
||||||
|
auto clamp = std::make_shared<ngraph::opset7::Clamp>(add, 0.0f, 6.0f);
|
||||||
|
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, clamp);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
auto gr = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||||
|
gr->add_matcher<ngraph::pass::HSwishFusion>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||||
|
auto hswish = std::make_shared<ngraph::opset7::HSwish>(input);
|
||||||
|
auto mul_const = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
|
||||||
|
auto mul = std::make_shared<ngraph::opset7::Multiply>(hswish, mul_const);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, HSwishFusionWithClampWithWrongConstant) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||||
|
auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
|
||||||
|
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
|
||||||
|
auto clamp = std::make_shared<ngraph::opset7::Clamp>(add, 0.11f, 6.32f);
|
||||||
|
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, clamp);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
auto gr = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||||
|
gr->add_matcher<ngraph::pass::HSwishFusion>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
|
||||||
|
auto add_constant = ngraph::opset7::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
|
||||||
|
auto add = std::make_shared<ngraph::opset7::Add>(input, add_constant);
|
||||||
|
auto clamp = std::make_shared<ngraph::opset7::Clamp>(add, 0.11f, 6.32f);
|
||||||
|
auto mul = std::make_shared<ngraph::opset7::Multiply>(input, clamp);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user