Add MVN fusion (#4647)
* Add MVN fusion * Fix map for Or pattern * Use count instead of find in MVN fusion * Apply review feedback * Apply review feedback * Fuse patterns * Fix Win build * Apply feedback
This commit is contained in:
parent
3b3d9a0989
commit
b2f3243387
@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API MVNFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief MVNFusion transformation replaces group of
|
||||
* operations: (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) to MVN op.
|
||||
*/
|
||||
class ngraph::pass::MVNFusion : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
MVNFusion();
|
||||
};
|
@ -30,6 +30,7 @@
|
||||
#include "transformations/common_optimizations/pad_fusion.hpp"
|
||||
#include "transformations/common_optimizations/eliminate_unsqueeze_gather.hpp"
|
||||
#include "transformations/common_optimizations/softmax_fusion.hpp"
|
||||
#include "transformations/common_optimizations/mvn_fusion.hpp"
|
||||
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
|
||||
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
|
||||
#include "transformations/op_conversions/convert_divide.hpp"
|
||||
@ -98,6 +99,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
common_fusions->add_matcher<ngraph::pass::ClampFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::PadFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SoftmaxFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::MVNFusion>();
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
|
||||
|
@ -0,0 +1,190 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/mvn_fusion.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusion, "MVNFusion", 0);
|
||||
|
||||
template <class T>
|
||||
std::function<bool(ngraph::Output<ngraph::Node>)> value_is_equal_to(const std::vector<T>& ref_values) {
|
||||
return [ref_values](ngraph::Output<ngraph::Node> output) -> bool {
|
||||
auto node = output.get_node_shared_ptr();
|
||||
if (auto const_node = std::dynamic_pointer_cast<ngraph::op::Constant>(node)) {
|
||||
return const_node->template cast_vector<T>() == ref_values;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
}
|
||||
|
||||
ngraph::pass::MVNFusion::MVNFusion() {
|
||||
MATCHER_SCOPE(MVNFusion);
|
||||
// Detect MVN decomposition pattern:
|
||||
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
|
||||
auto x = pattern::any_input();
|
||||
|
||||
// (x - ReduceMean(x, axes))
|
||||
// `------mean1-------'
|
||||
auto mean1_axes = pattern::wrap_type<opset6::Constant>();
|
||||
auto mean1 = pattern::wrap_type<opset6::ReduceMean>({ x, mean1_axes });
|
||||
|
||||
// (x - ReduceMean(x, axes))
|
||||
// `-sub1------------------'
|
||||
auto sub1 = pattern::wrap_type<opset6::Subtract>({ x, mean1 });
|
||||
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
|
||||
// `---mean2----------'
|
||||
auto mean2_axes = pattern::wrap_type<opset6::Constant>();
|
||||
auto mean2 = pattern::wrap_type<opset6::ReduceMean>({ x, mean2_axes });
|
||||
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
|
||||
// `-sub2------------------'
|
||||
auto sub2 = pattern::wrap_type<opset6::Subtract>({ x, mean2 });
|
||||
|
||||
const auto reuseSub1OrNot = std::make_shared<pattern::op::Or>(OutputVector{ sub1, sub2 });
|
||||
|
||||
auto cast = pattern::wrap_type<opset6::Convert>({ reuseSub1OrNot });
|
||||
const auto hasConvertOrNot = std::make_shared<pattern::op::Or>(OutputVector{ cast, reuseSub1OrNot });
|
||||
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
|
||||
// `---------------------power--'
|
||||
auto const_2 = pattern::wrap_type<opset6::Constant>(value_is_equal_to<float>({ 2.0 }));
|
||||
auto power = pattern::wrap_type<opset6::Power>({ hasConvertOrNot, const_2 });
|
||||
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
|
||||
// `---mean3--------------------------------'
|
||||
auto mean3_axes = pattern::wrap_type<opset6::Constant>();
|
||||
auto mean3 = pattern::wrap_type<opset6::ReduceMean>({ power, mean3_axes });
|
||||
|
||||
auto const_0_5 = pattern::wrap_type<ngraph::opset6::Constant>(value_is_equal_to<float>({0.5}));
|
||||
auto eps = pattern::wrap_type<opset6::Constant>();
|
||||
// ------------------- OUTSIDE_SQRT ----------------------
|
||||
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
|
||||
// `--Power--------------------------------------'
|
||||
auto power_sqrt_os = pattern::wrap_type<opset6::Power>({ mean3, const_0_5 });
|
||||
auto sqrt_os = pattern::wrap_type<opset6::Sqrt>({ mean3 });
|
||||
const auto powerOrSqrt_os = std::make_shared<pattern::op::Or>(OutputVector{ power_sqrt_os, sqrt_os });
|
||||
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps
|
||||
// `----------------------------------------------Add---'
|
||||
auto add_eps_os = pattern::wrap_type<opset6::Add>({ powerOrSqrt_os, eps });
|
||||
|
||||
// ------------------- INSIDE_SQRT ----------------------
|
||||
|
||||
// (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps))
|
||||
// `-----------------------------------------------Add---'
|
||||
auto add_eps_is = pattern::wrap_type<opset6::Add>({ mean3, eps });
|
||||
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
|
||||
// `--Power--------------------------------------'
|
||||
auto power_sqrt_is = pattern::wrap_type<opset6::Power>({ add_eps_is, const_0_5 });
|
||||
auto sqrt_is = pattern::wrap_type<opset6::Sqrt>({ add_eps_is });
|
||||
const auto powerOrSqrt_is = std::make_shared<pattern::op::Or>(OutputVector{ power_sqrt_is, sqrt_is });
|
||||
|
||||
auto outsideOrInside = std::make_shared<pattern::op::Or>(OutputVector{ add_eps_os, powerOrSqrt_is });
|
||||
|
||||
// Final Divide
|
||||
auto const_neg_1 = pattern::wrap_type<opset6::Constant>(value_is_equal_to<float>({ -1 }));
|
||||
auto power_div = pattern::wrap_type<opset6::Power>({ outsideOrInside, const_neg_1 });
|
||||
auto div = pattern::wrap_type<opset6::Multiply>({ sub1, power_div });
|
||||
|
||||
auto div_alt = pattern::wrap_type<opset6::Divide>({ sub1, outsideOrInside });
|
||||
const auto powerMulOrDiv = std::make_shared<pattern::op::Or>(OutputVector{ div, div_alt });
|
||||
|
||||
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto exp_input = pattern_to_output.at(x);
|
||||
|
||||
auto const_eps_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(eps).get_node_shared_ptr());
|
||||
float eps_value;
|
||||
if (!op::util::get_single_value(const_eps_node, eps_value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axes_1_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean1_axes).get_node_shared_ptr());
|
||||
auto axes_3_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean3_axes).get_node_shared_ptr());
|
||||
|
||||
if (!axes_1_node || !axes_3_node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axes_1_value = axes_1_node->cast_vector<int64_t>();
|
||||
auto axes_3_value = axes_3_node->cast_vector<int64_t>();
|
||||
|
||||
if (axes_1_value != axes_3_value) {
|
||||
return false;
|
||||
}
|
||||
if (pattern_to_output.count(mean2_axes)) {
|
||||
auto axes_2_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean2_axes).get_node_shared_ptr());
|
||||
if (!axes_2_node) {
|
||||
return false;
|
||||
}
|
||||
auto axes_2_value = axes_2_node->cast_vector<int64_t>();
|
||||
if (axes_1_value != axes_2_value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ngraph::NodeVector nodes_to_copy_info({ pattern_to_output.at(mean1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(sub1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(power).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mean3).get_node_shared_ptr() });
|
||||
|
||||
op::MVNEpsMode mode;
|
||||
if (pattern_to_output.count(add_eps_os)) {
|
||||
mode = op::MVNEpsMode::OUTSIDE_SQRT;
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(add_eps_os).get_node_shared_ptr());
|
||||
if (pattern_to_output.count(power_sqrt_os)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(power_sqrt_os).get_node_shared_ptr());
|
||||
} else if (pattern_to_output.count(sqrt_os)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(sqrt_os).get_node_shared_ptr());
|
||||
}
|
||||
} else if (pattern_to_output.count(powerOrSqrt_is)) {
|
||||
mode = op::MVNEpsMode::INSIDE_SQRT;
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(add_eps_is).get_node_shared_ptr());
|
||||
if (pattern_to_output.count(power_sqrt_is)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(power_sqrt_is).get_node_shared_ptr());
|
||||
} else if (pattern_to_output.count(sqrt_is)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(sqrt_is).get_node_shared_ptr());
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(exp_input, axes_1_node, true, eps_value, mode);
|
||||
|
||||
if (pattern_to_output.count(mean2) && pattern_to_output.count(sub2)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(mean2).get_node_shared_ptr());
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(sub2).get_node_shared_ptr());
|
||||
}
|
||||
|
||||
if (pattern_to_output.count(cast)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(cast).get_node_shared_ptr());
|
||||
}
|
||||
|
||||
if (pattern_to_output.count(div_alt)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(div_alt).get_node_shared_ptr());
|
||||
} else if (pattern_to_output.count(power_div) && pattern_to_output.count(div)) {
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(power_div).get_node_shared_ptr());
|
||||
nodes_to_copy_info.push_back(pattern_to_output.at(div).get_node_shared_ptr());
|
||||
}
|
||||
|
||||
mvn->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info(nodes_to_copy_info, mvn);
|
||||
ngraph::replace_node(m.get_match_root(), mvn);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(powerMulOrDiv, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -31,6 +31,7 @@ ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
|
||||
const auto data = mvn_node->input_value(0);
|
||||
const auto axes = mvn_node->input_value(1);
|
||||
|
||||
// (x - ReduceMean(x, axes))
|
||||
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(data, axes, true);
|
||||
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(data, mean);
|
||||
|
||||
@ -39,8 +40,11 @@ ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
|
||||
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization });
|
||||
ngraph::replace_node(mvn_node, mean_normalization);
|
||||
} else {
|
||||
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(mul, axes, true);
|
||||
// (x - ReduceMean(x, axes)) ^ 2
|
||||
auto sqr_const = ngraph::opset6::Constant::create(data.get_element_type(), ngraph::Shape{ 1 }, { 2 });
|
||||
auto sqr = std::make_shared<ngraph::opset6::Power>(mean_normalization, sqr_const);
|
||||
// ReduceMean((x - ReduceMean(x, axes)) ^ 2)
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(sqr, axes, true);
|
||||
|
||||
auto eps = mvn_node->get_eps();
|
||||
auto eps_node = ngraph::opset6::Constant::create(data.get_element_type(), ngraph::Shape{ 1 }, { eps });
|
||||
@ -51,19 +55,23 @@ ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
|
||||
std::shared_ptr<ngraph::opset6::Divide> div;
|
||||
|
||||
if (eps_mode == op::MVNEpsMode::INSIDE_SQRT) {
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
|
||||
eps_add = std::make_shared<ngraph::opset6::Add>(mean2, eps_node);
|
||||
sqrt = std::make_shared<ngraph::opset6::Sqrt>(eps_add);
|
||||
// (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
|
||||
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
|
||||
} else if (eps_mode == op::MVNEpsMode::OUTSIDE_SQRT) {
|
||||
// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps
|
||||
sqrt = std::make_shared<ngraph::opset6::Sqrt>(mean2);
|
||||
eps_add = std::make_shared<ngraph::opset6::Add>(sqrt, eps_node);
|
||||
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
|
||||
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, eps_add);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
div->set_friendly_name(mvn_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization, mul, mean2, eps_node, eps_add, sqrt, div });
|
||||
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization, sqr, mean2, eps_node, eps_add, sqrt, div });
|
||||
ngraph::replace_node(mvn_node, div);
|
||||
}
|
||||
return true;
|
||||
|
@ -69,8 +69,9 @@ TEST(TransformationTests, MVN6Decomposition_Inside_Sqrt) {
|
||||
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
|
||||
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
|
||||
|
||||
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(mul, axes_const, true);
|
||||
auto sqr_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 2 });
|
||||
auto sqr = std::make_shared<ngraph::opset6::Power>(mean_normalization, sqr_const);
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(sqr, axes_const, true);
|
||||
|
||||
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });
|
||||
|
||||
@ -107,8 +108,9 @@ TEST(TransformationTests, MVN6Decomposition_Outside_Sqrt) {
|
||||
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
|
||||
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
|
||||
|
||||
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(mul, axes_const, true);
|
||||
auto sqr_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 2 });
|
||||
auto sqr = std::make_shared<ngraph::opset6::Power>(mean_normalization, sqr_const);
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(sqr, axes_const, true);
|
||||
|
||||
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });
|
||||
|
||||
|
@ -0,0 +1,421 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/mvn_fusion.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestOutside) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean2_axes);
|
||||
auto sub2 = std::make_shared<ngraph::opset6::Subtract>(input, mean2);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub2, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestReuseSub) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestWithConvert) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto cast = std::make_shared<ngraph::opset6::Convert>(sub1, ngraph::element::f32);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(cast, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestSqrt) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Sqrt>(mean3);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestAltDiv) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
|
||||
auto div = std::make_shared<ngraph::opset6::Divide>(sub1, add_eps);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestInsideSqrt) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean2_axes);
|
||||
auto sub2 = std::make_shared<ngraph::opset6::Subtract>(input, mean2);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub2, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestReuseSubInsideSqrt) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestWithConvertInsideSqrt) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto cast = std::make_shared<ngraph::opset6::Convert>(sub1, ngraph::element::f32);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(cast, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestSqrtInsideSqrt) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Sqrt>(add_eps);
|
||||
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
|
||||
auto power_div = std::make_shared<ngraph::opset6::Power>(power_sqrt, const_neg_1);
|
||||
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestAltDivInsideSqrt) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
|
||||
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
|
||||
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
|
||||
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub1, const_2);
|
||||
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
|
||||
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
|
||||
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean3, eps);
|
||||
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
|
||||
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
|
||||
auto div = std::make_shared<ngraph::opset6::Divide>(sub1, power_sqrt);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::MVNFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -36,6 +36,8 @@ bool pattern::op::Or::match_value(Matcher* matcher,
|
||||
auto saved = matcher->start_match();
|
||||
if (matcher->match_value(input_value, graph_value))
|
||||
{
|
||||
auto& pattern_map = matcher->get_pattern_value_map();
|
||||
pattern_map[input_value.get_node_shared_ptr()] = graph_value;
|
||||
return saved.finish(true);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user