Add MVN fusion (#4647)

* Add MVN fusion

* Fix map for Or pattern

* Use count instead of find in MVN fusion

* Apply review feedback

* Apply review feedback

* Fuse patterns

* Fix Win build

* Apply feedback
This commit is contained in:
Maxim Vafin 2021-03-23 19:27:53 +03:00 committed by GitHub
parent 3b3d9a0989
commit b2f3243387
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 665 additions and 7 deletions

View File

@ -0,0 +1,33 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API MVNFusion;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief MVNFusion transformation replaces group of
* operations: (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) to MVN op.
*/
class ngraph::pass::MVNFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusion();
};

View File

@ -30,6 +30,7 @@
#include "transformations/common_optimizations/pad_fusion.hpp"
#include "transformations/common_optimizations/eliminate_unsqueeze_gather.hpp"
#include "transformations/common_optimizations/softmax_fusion.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
@ -98,6 +99,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
common_fusions->add_matcher<ngraph::pass::ClampFusion>();
common_fusions->add_matcher<ngraph::pass::PadFusion>();
common_fusions->add_matcher<ngraph::pass::SoftmaxFusion>();
common_fusions->add_matcher<ngraph::pass::MVNFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();

View File

@ -0,0 +1,190 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "itt.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
#include "transformations/utils/utils.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusion, "MVNFusion", 0);
template <class T>
std::function<bool(ngraph::Output<ngraph::Node>)> value_is_equal_to(const std::vector<T>& ref_values) {
return [ref_values](ngraph::Output<ngraph::Node> output) -> bool {
auto node = output.get_node_shared_ptr();
if (auto const_node = std::dynamic_pointer_cast<ngraph::op::Constant>(node)) {
return const_node->template cast_vector<T>() == ref_values;
}
return false;
};
}
ngraph::pass::MVNFusion::MVNFusion() {
MATCHER_SCOPE(MVNFusion);
// Detect MVN decomposition pattern:
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
auto x = pattern::any_input();
// (x - ReduceMean(x, axes))
// `------mean1-------'
auto mean1_axes = pattern::wrap_type<opset6::Constant>();
auto mean1 = pattern::wrap_type<opset6::ReduceMean>({ x, mean1_axes });
// (x - ReduceMean(x, axes))
// `-sub1------------------'
auto sub1 = pattern::wrap_type<opset6::Subtract>({ x, mean1 });
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `---mean2----------'
auto mean2_axes = pattern::wrap_type<opset6::Constant>();
auto mean2 = pattern::wrap_type<opset6::ReduceMean>({ x, mean2_axes });
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `-sub2------------------'
auto sub2 = pattern::wrap_type<opset6::Subtract>({ x, mean2 });
const auto reuseSub1OrNot = std::make_shared<pattern::op::Or>(OutputVector{ sub1, sub2 });
auto cast = pattern::wrap_type<opset6::Convert>({ reuseSub1OrNot });
const auto hasConvertOrNot = std::make_shared<pattern::op::Or>(OutputVector{ cast, reuseSub1OrNot });
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `---------------------power--'
auto const_2 = pattern::wrap_type<opset6::Constant>(value_is_equal_to<float>({ 2.0 }));
auto power = pattern::wrap_type<opset6::Power>({ hasConvertOrNot, const_2 });
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `---mean3--------------------------------'
auto mean3_axes = pattern::wrap_type<opset6::Constant>();
auto mean3 = pattern::wrap_type<opset6::ReduceMean>({ power, mean3_axes });
auto const_0_5 = pattern::wrap_type<ngraph::opset6::Constant>(value_is_equal_to<float>({0.5}));
auto eps = pattern::wrap_type<opset6::Constant>();
// ------------------- OUTSIDE_SQRT ----------------------
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `--Power--------------------------------------'
auto power_sqrt_os = pattern::wrap_type<opset6::Power>({ mean3, const_0_5 });
auto sqrt_os = pattern::wrap_type<opset6::Sqrt>({ mean3 });
const auto powerOrSqrt_os = std::make_shared<pattern::op::Or>(OutputVector{ power_sqrt_os, sqrt_os });
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps
// `----------------------------------------------Add---'
auto add_eps_os = pattern::wrap_type<opset6::Add>({ powerOrSqrt_os, eps });
// ------------------- INSIDE_SQRT ----------------------
// (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps))
// `-----------------------------------------------Add---'
auto add_eps_is = pattern::wrap_type<opset6::Add>({ mean3, eps });
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `--Power--------------------------------------'
auto power_sqrt_is = pattern::wrap_type<opset6::Power>({ add_eps_is, const_0_5 });
auto sqrt_is = pattern::wrap_type<opset6::Sqrt>({ add_eps_is });
const auto powerOrSqrt_is = std::make_shared<pattern::op::Or>(OutputVector{ power_sqrt_is, sqrt_is });
auto outsideOrInside = std::make_shared<pattern::op::Or>(OutputVector{ add_eps_os, powerOrSqrt_is });
// Final Divide
auto const_neg_1 = pattern::wrap_type<opset6::Constant>(value_is_equal_to<float>({ -1 }));
auto power_div = pattern::wrap_type<opset6::Power>({ outsideOrInside, const_neg_1 });
auto div = pattern::wrap_type<opset6::Multiply>({ sub1, power_div });
auto div_alt = pattern::wrap_type<opset6::Divide>({ sub1, outsideOrInside });
const auto powerMulOrDiv = std::make_shared<pattern::op::Or>(OutputVector{ div, div_alt });
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(x);
auto const_eps_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(eps).get_node_shared_ptr());
float eps_value;
if (!op::util::get_single_value(const_eps_node, eps_value)) {
return false;
}
auto axes_1_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean1_axes).get_node_shared_ptr());
auto axes_3_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean3_axes).get_node_shared_ptr());
if (!axes_1_node || !axes_3_node) {
return false;
}
auto axes_1_value = axes_1_node->cast_vector<int64_t>();
auto axes_3_value = axes_3_node->cast_vector<int64_t>();
if (axes_1_value != axes_3_value) {
return false;
}
if (pattern_to_output.count(mean2_axes)) {
auto axes_2_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean2_axes).get_node_shared_ptr());
if (!axes_2_node) {
return false;
}
auto axes_2_value = axes_2_node->cast_vector<int64_t>();
if (axes_1_value != axes_2_value) {
return false;
}
}
ngraph::NodeVector nodes_to_copy_info({ pattern_to_output.at(mean1).get_node_shared_ptr(),
pattern_to_output.at(sub1).get_node_shared_ptr(),
pattern_to_output.at(power).get_node_shared_ptr(),
pattern_to_output.at(mean3).get_node_shared_ptr() });
op::MVNEpsMode mode;
if (pattern_to_output.count(add_eps_os)) {
mode = op::MVNEpsMode::OUTSIDE_SQRT;
nodes_to_copy_info.push_back(pattern_to_output.at(add_eps_os).get_node_shared_ptr());
if (pattern_to_output.count(power_sqrt_os)) {
nodes_to_copy_info.push_back(pattern_to_output.at(power_sqrt_os).get_node_shared_ptr());
} else if (pattern_to_output.count(sqrt_os)) {
nodes_to_copy_info.push_back(pattern_to_output.at(sqrt_os).get_node_shared_ptr());
}
} else if (pattern_to_output.count(powerOrSqrt_is)) {
mode = op::MVNEpsMode::INSIDE_SQRT;
nodes_to_copy_info.push_back(pattern_to_output.at(add_eps_is).get_node_shared_ptr());
if (pattern_to_output.count(power_sqrt_is)) {
nodes_to_copy_info.push_back(pattern_to_output.at(power_sqrt_is).get_node_shared_ptr());
} else if (pattern_to_output.count(sqrt_is)) {
nodes_to_copy_info.push_back(pattern_to_output.at(sqrt_is).get_node_shared_ptr());
}
} else {
return false;
}
auto mvn = std::make_shared<ngraph::opset6::MVN>(exp_input, axes_1_node, true, eps_value, mode);
if (pattern_to_output.count(mean2) && pattern_to_output.count(sub2)) {
nodes_to_copy_info.push_back(pattern_to_output.at(mean2).get_node_shared_ptr());
nodes_to_copy_info.push_back(pattern_to_output.at(sub2).get_node_shared_ptr());
}
if (pattern_to_output.count(cast)) {
nodes_to_copy_info.push_back(pattern_to_output.at(cast).get_node_shared_ptr());
}
if (pattern_to_output.count(div_alt)) {
nodes_to_copy_info.push_back(pattern_to_output.at(div_alt).get_node_shared_ptr());
} else if (pattern_to_output.count(power_div) && pattern_to_output.count(div)) {
nodes_to_copy_info.push_back(pattern_to_output.at(power_div).get_node_shared_ptr());
nodes_to_copy_info.push_back(pattern_to_output.at(div).get_node_shared_ptr());
}
mvn->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(nodes_to_copy_info, mvn);
ngraph::replace_node(m.get_match_root(), mvn);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(powerMulOrDiv, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -31,6 +31,7 @@ ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
const auto data = mvn_node->input_value(0);
const auto axes = mvn_node->input_value(1);
// (x - ReduceMean(x, axes))
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(data, axes, true);
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(data, mean);
@ -39,8 +40,11 @@ ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization });
ngraph::replace_node(mvn_node, mean_normalization);
} else {
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(mul, axes, true);
// (x - ReduceMean(x, axes)) ^ 2
auto sqr_const = ngraph::opset6::Constant::create(data.get_element_type(), ngraph::Shape{ 1 }, { 2 });
auto sqr = std::make_shared<ngraph::opset6::Power>(mean_normalization, sqr_const);
// ReduceMean((x - ReduceMean(x, axes)) ^ 2)
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(sqr, axes, true);
auto eps = mvn_node->get_eps();
auto eps_node = ngraph::opset6::Constant::create(data.get_element_type(), ngraph::Shape{ 1 }, { eps });
@ -51,19 +55,23 @@ ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
std::shared_ptr<ngraph::opset6::Divide> div;
if (eps_mode == op::MVNEpsMode::INSIDE_SQRT) {
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
eps_add = std::make_shared<ngraph::opset6::Add>(mean2, eps_node);
sqrt = std::make_shared<ngraph::opset6::Sqrt>(eps_add);
// (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
} else if (eps_mode == op::MVNEpsMode::OUTSIDE_SQRT) {
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps
sqrt = std::make_shared<ngraph::opset6::Sqrt>(mean2);
eps_add = std::make_shared<ngraph::opset6::Add>(sqrt, eps_node);
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, eps_add);
} else {
return false;
}
div->set_friendly_name(mvn_node->get_friendly_name());
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization, mul, mean2, eps_node, eps_add, sqrt, div });
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization, sqr, mean2, eps_node, eps_add, sqrt, div });
ngraph::replace_node(mvn_node, div);
}
return true;

View File

@ -69,8 +69,9 @@ TEST(TransformationTests, MVN6Decomposition_Inside_Sqrt) {
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(mul, axes_const, true);
auto sqr_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 2 });
auto sqr = std::make_shared<ngraph::opset6::Power>(mean_normalization, sqr_const);
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(sqr, axes_const, true);
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });
@ -107,8 +108,9 @@ TEST(TransformationTests, MVN6Decomposition_Outside_Sqrt) {
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(mul, axes_const, true);
auto sqr_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 2 });
auto sqr = std::make_shared<ngraph::opset6::Power>(mean_normalization, sqr_const);
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(sqr, axes_const, true);
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });

View File

@ -0,0 +1,421 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/mvn_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, MVNFusionTestOutside) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean2_axes);
auto sub2 = std::make_shared<ngraph::opset6::Subtract>(input, mean2);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub2, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestReuseSub) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestWithConvert) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto cast = std::make_shared<ngraph::opset6::Convert>(sub1, ngraph::element::f32);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(cast, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestSqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto power_sqrt = std::make_shared<ngraph::opset6::Sqrt>(mean3);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestAltDiv) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
auto div = std::make_shared<ngraph::opset6::Divide>(sub1, add_eps);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestInsideSqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean2_axes);
auto sub2 = std::make_shared<ngraph::opset6::Subtract>(input, mean2);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub2, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestReuseSubInsideSqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestWithConvertInsideSqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto cast = std::make_shared<ngraph::opset6::Convert>(sub1, ngraph::element::f32);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(cast, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestSqrtInsideSqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
auto power_sqrt = std::make_shared<ngraph::opset6::Sqrt>(add_eps);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVNFusionTestAltDivInsideSqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
auto div = std::make_shared<ngraph::opset6::Divide>(sub1, power_sqrt);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -36,6 +36,8 @@ bool pattern::op::Or::match_value(Matcher* matcher,
auto saved = matcher->start_match();
if (matcher->match_value(input_value, graph_value))
{
auto& pattern_map = matcher->get_pattern_value_map();
pattern_map[input_value.get_node_shared_ptr()] = graph_value;
return saved.finish(true);
}
}