Added MVN fusion for case with constants inside (#4961)

This commit is contained in:
Alexandra Sidorova 2021-05-14 07:52:05 +03:00 committed by GitHub
parent 61e9d020d4
commit 621e36ee79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 189 additions and 5 deletions

View File

@ -16,7 +16,9 @@
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API MVNFusion;
class TRANSFORMATIONS_API MVNFusion;
class TRANSFORMATIONS_API MVNFusionWithoutConstants;
class TRANSFORMATIONS_API MVNFusionWithConstantsInside;
} // namespace pass
} // namespace ngraph
@ -26,8 +28,32 @@ namespace pass {
* @brief MVNFusion transformation replaces group of
* operations: (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) to MVN op.
*/
class ngraph::pass::MVNFusion : public ngraph::pass::MatcherPass {
class ngraph::pass::MVNFusionWithoutConstants : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusion();
MVNFusionWithoutConstants();
};
/**
* @ingroup ie_transformation_common_api
* @brief MVNFusion transformation replaces group of
* operations: gamma * (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) - beta to MVN op.
*/
class ngraph::pass::MVNFusionWithConstantsInside : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusionWithConstantsInside();
};
/**
* @ingroup ie_transformation_common_api
* @brief MVNFusion transformation replaces various sub-graphs with a MVN op.
*/
class ngraph::pass::MVNFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusion() {
add_matcher<ngraph::pass::MVNFusionWithoutConstants>();
add_matcher<ngraph::pass::MVNFusionWithConstantsInside>();
}
};

View File

@ -27,8 +27,10 @@ std::function<bool(ngraph::Output<ngraph::Node>)> value_is_equal_to(const std::v
};
}
ngraph::pass::MVNFusion::MVNFusion() {
MATCHER_SCOPE(MVNFusion);
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusionWithoutConstants, "MVNFusionWithoutConstants", 0);
ngraph::pass::MVNFusionWithoutConstants::MVNFusionWithoutConstants() {
MATCHER_SCOPE(MVNFusionWithoutConstants);
// Detect MVN decomposition pattern:
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
auto x = pattern::any_input();
@ -188,3 +190,113 @@ ngraph::pass::MVNFusion::MVNFusion() {
auto m = std::make_shared<ngraph::pattern::Matcher>(powerMulOrDiv, matcher_name);
register_matcher(m, matcher_pass_callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusionWithConstantsInside, "MVNFusionWithConstantsInside", 0);
ngraph::pass::MVNFusionWithConstantsInside::MVNFusionWithConstantsInside() {
MATCHER_SCOPE(MVNFusionWithConstantsInside);
// Detect MVN decomposition pattern:
// (x - ReduceMean(x, axes)) * gamma / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) - beta
auto x = pattern::any_input();
// (x - ReduceMean(x, axes))^2
// `------mean1-------'
auto mean1_axes = pattern::wrap_type<opset6::Constant>();
auto mean1 = pattern::wrap_type<opset6::ReduceMean>({ x, mean1_axes });
// (x - ReduceMean(x, axes))^2
// `-squared_difference------'
auto squared_difference = pattern::wrap_type<opset6::SquaredDifference>({ x, mean1 });
// 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---mean2--------------------------------'
auto mean2_axes = pattern::wrap_type<opset6::Constant>();
auto mean2 = pattern::wrap_type<opset6::ReduceMean>({ squared_difference, mean2_axes });
// 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `------------------------------------------add--'
auto eps = pattern::wrap_type<opset6::Constant>();
auto add_eps = pattern::wrap_type<opset6::Add>({ mean2, eps });
// 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `-power-------------------------------------------------'
auto const_0_5 = pattern::wrap_type<opset6::Constant>(value_is_equal_to<float>({-0.5}));
auto power = pattern::wrap_type<opset6::Power>({ add_eps, const_0_5 });
// gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---mul1----------------------------------------------------'
auto gamma = pattern::wrap_type<opset6::Constant>();
auto mul1 = pattern::wrap_type<opset6::Multiply>({ power, gamma });
// x * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---mul2--------------------------------------------------------'
auto mul2 = pattern::wrap_type<opset6::Multiply>({ x, mul1 });
// ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) - beta
// `-------------------mul3----------------------------------------------------------'
auto mul3 = pattern::wrap_type<opset6::Multiply>({ mul1, mean1 });
// beta - ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---sub-----------------------------------------------------------------------------------'
auto beta = pattern::wrap_type<opset6::Constant>();
auto sub = pattern::wrap_type<opset6::Subtract>({ beta, mul3 });
// Final Add
// x * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) +
// beta - ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) =
// gamma * (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) - beta
auto add = pattern::wrap_type<opset6::Add>({ mul2, sub });
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(x);
auto const_0_5_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(const_0_5).get_node_shared_ptr());
if (!const_0_5_node) {
return false;
}
auto const_gamma_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(gamma).get_node_shared_ptr());
auto const_beta_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(beta).get_node_shared_ptr());
auto const_eps_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(eps).get_node_shared_ptr());
float eps_value;
bool valid_constant_values = op::util::has_constant_value<float>(const_0_5_node, -0.5) && op::util::get_single_value(const_eps_node, eps_value);
if (!valid_constant_values) {
return false;
}
auto axes_1_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean1_axes).get_node_shared_ptr());
auto axes_2_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean2_axes).get_node_shared_ptr());
if (!axes_1_node || !axes_2_node) {
return false;
}
auto axes_1_value = axes_1_node->cast_vector<int64_t>();
auto axes_2_value = axes_2_node->cast_vector<int64_t>();
if (axes_1_value != axes_2_value) {
return false;
}
auto mvn = std::make_shared<ngraph::opset6::MVN>(x_output, axes_1_node, true, eps_value, op::MVNEpsMode::INSIDE_SQRT);
auto mul_gamma = std::make_shared<ngraph::opset6::Multiply>(mvn, const_gamma_node);
auto add_beta = std::make_shared<ngraph::opset6::Add>(mul_gamma, const_beta_node);
ngraph::copy_runtime_info({ pattern_to_output.at(mean1).get_node_shared_ptr(),
pattern_to_output.at(squared_difference).get_node_shared_ptr(),
pattern_to_output.at(add_eps).get_node_shared_ptr(),
pattern_to_output.at(power).get_node_shared_ptr(),
pattern_to_output.at(mul1).get_node_shared_ptr(),
pattern_to_output.at(mul2).get_node_shared_ptr(),
pattern_to_output.at(mul3).get_node_shared_ptr(),
pattern_to_output.at(sub).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr() },
{ mvn, const_gamma_node, mul_gamma, const_beta_node, add_beta });
add_beta->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::replace_node(m.get_match_root(), add_beta);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -419,3 +419,49 @@ TEST(TransformationTests, MVNFusionTestAltDivInsideSqrt) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestWithParametersInside) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes, true);
auto squared_difference = std::make_shared<ngraph::opset6::SquaredDifference>(input, mean1);
auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(squared_difference, mean2_axes, true);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean2, eps);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
auto gamma = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1 });
auto mul_gamma = std::make_shared<ngraph::opset6::Multiply>(power_sqrt, gamma);
auto mul1 = std::make_shared<ngraph::opset6::Multiply>(input, mul_gamma);
auto mul2 = std::make_shared<ngraph::opset6::Multiply>(mul_gamma, mean1);
auto beta = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto sub = std::make_shared<ngraph::opset6::Subtract>(beta, mul2);
auto add = std::make_shared<ngraph::opset6::Add>(mul1, sub);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
auto gamma = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1 });
auto mul_gamma = std::make_shared<ngraph::opset6::Multiply>(mvn, gamma);
auto beta = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto add = std::make_shared<ngraph::opset6::Add>(mul_gamma, beta);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}