Add MVN-6 related transformations (#3710)
* Add MVN decomposition transformation
* Add MVN-1 to MVN-6 transformation
* Apply review feedback
* Apply review feedback
* Fix build
* Fix if statement and add 5D tests
* Apply review feedback
* Apply review feedback
* Apply feedback
* Revert "Apply feedback"
This reverts commit 039fefbff9
.
* Apply review feedback
* Apply review feedback
* Fix build issue
* Apply review feedback
* Apply review feedback
* Apply feedback
This commit is contained in:
parent
d462626826
commit
af5eccc6ae
@ -0,0 +1,27 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <transformations_visibility.hpp>
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API ConvertMVN1ToMVN6;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief ConvertMVN1ToMVN6 covert v0:MVN into v6::MVN.
|
||||||
|
*/
|
||||||
|
class ngraph::pass::ConvertMVN1ToMVN6 : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
ConvertMVN1ToMVN6();
|
||||||
|
};
|
@ -0,0 +1,27 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <transformations_visibility.hpp>
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API MVN6Decomposition;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief MVN6Decomposition transformation into sub-graph x - ReduceMean(x, axes) if normalize_variance is false and
|
||||||
|
* into sub-graph (x - ReduceMean(x, axes)) / Sqrt(ReduceSum((x - ReduceMean(x, axes)) ^ 2)) if normalize_variance is true.
|
||||||
|
*/
|
||||||
|
class ngraph::pass::MVN6Decomposition : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
MVN6Decomposition();
|
||||||
|
};
|
@ -0,0 +1,52 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp"
|
||||||
|
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include <ngraph/opsets/opset2.hpp>
|
||||||
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMVN1ToMVN6, "ConvertMVN1ToMVN6", 0);
|
||||||
|
|
||||||
|
ngraph::pass::ConvertMVN1ToMVN6::ConvertMVN1ToMVN6() {
|
||||||
|
auto mvn = pattern::wrap_type<ngraph::opset2::MVN>();
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||||
|
auto mvn_node = std::dynamic_pointer_cast<ngraph::opset2::MVN>(m.get_match_root());
|
||||||
|
if (!mvn_node) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto input = mvn_node->input_value(0);
|
||||||
|
auto input_rank = input.get_partial_shape().rank();
|
||||||
|
if (!input_rank.is_static()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int64_t start_axis = 1 + (!mvn_node->get_across_channels());
|
||||||
|
if (input_rank.get_length() <= start_axis) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<int64_t> axes_v(input_rank.get_length() - start_axis);
|
||||||
|
std::iota(axes_v.begin(), axes_v.end(), start_axis);
|
||||||
|
auto axes = opset6::Constant::create(ngraph::element::i64, { axes_v.size() }, axes_v);
|
||||||
|
auto mvn6_node = std::make_shared<ngraph::opset6::MVN>(input,
|
||||||
|
axes,
|
||||||
|
mvn_node->get_normalize_variance(),
|
||||||
|
mvn_node->get_eps(),
|
||||||
|
ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
|
||||||
|
|
||||||
|
mvn6_node->set_friendly_name(mvn_node->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info(mvn_node, mvn6_node);
|
||||||
|
ngraph::replace_node(mvn_node, mvn6_node);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<pattern::Matcher>(mvn, "ConvertMVN1ToMVN6");
|
||||||
|
register_matcher(m, callback);
|
||||||
|
}
|
@ -0,0 +1,72 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/op_conversions/mvn6_decomposition.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVN6Decomposition, "MVN6Decomposition", 0);
|
||||||
|
|
||||||
|
ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
|
||||||
|
// Decomposes MVN(x, axes) op if normalize_variance is false into sub-graph
|
||||||
|
// x - ReduceMean(x, axes), if normalize_variance is true into sub-graph
|
||||||
|
// (x - ReduceMean(x, axes)) / Sqrt(ReduceSum((x - ReduceMean(x, axes)) ^ 2))
|
||||||
|
auto mvn = ngraph::pattern::wrap_type<opset6::MVN>();
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||||
|
auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
auto mvn_node = std::dynamic_pointer_cast<ngraph::opset6::MVN>(pattern_to_output.at(mvn).get_node_shared_ptr());
|
||||||
|
|
||||||
|
if (mvn_node == nullptr || transformation_callback(mvn_node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto data = mvn_node->input_value(0);
|
||||||
|
const auto axes = mvn_node->input_value(1);
|
||||||
|
|
||||||
|
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(data, axes, true);
|
||||||
|
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(data, mean);
|
||||||
|
|
||||||
|
if (!mvn_node->get_normalize_variance()) {
|
||||||
|
mean_normalization->set_friendly_name(mvn_node->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization });
|
||||||
|
ngraph::replace_node(mvn_node, mean_normalization);
|
||||||
|
} else {
|
||||||
|
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
|
||||||
|
auto sum = std::make_shared<ngraph::opset6::ReduceSum>(mul, axes, true);
|
||||||
|
|
||||||
|
auto eps = mvn_node->get_eps();
|
||||||
|
auto eps_node = ngraph::opset6::Constant::create(data.get_element_type(), ngraph::Shape{ 1 }, { eps });
|
||||||
|
auto eps_mode = mvn_node->get_eps_mode();
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::opset6::Add> eps_add;
|
||||||
|
std::shared_ptr<ngraph::opset6::Sqrt> sqrt;
|
||||||
|
std::shared_ptr<ngraph::opset6::Divide> div;
|
||||||
|
|
||||||
|
if (eps_mode == op::MVNEpsMode::INSIDE_SQRT) {
|
||||||
|
eps_add = std::make_shared<ngraph::opset6::Add>(sum, eps_node);
|
||||||
|
sqrt = std::make_shared<ngraph::opset6::Sqrt>(eps_add);
|
||||||
|
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
|
||||||
|
} else if (eps_mode == op::MVNEpsMode::OUTSIDE_SQRT) {
|
||||||
|
sqrt = std::make_shared<ngraph::opset6::Sqrt>(sum);
|
||||||
|
eps_add = std::make_shared<ngraph::opset6::Add>(sqrt, eps_node);
|
||||||
|
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
div->set_friendly_name(mvn_node->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization, mul, sum, eps_node, eps_add, sqrt, div });
|
||||||
|
ngraph::replace_node(mvn_node, div);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(mvn, "MVN6Decomposition");
|
||||||
|
register_matcher(m, callback);
|
||||||
|
}
|
@ -0,0 +1,101 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset2.hpp>
|
||||||
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <transformations/op_conversions/convert_mvn1_to_mvn6.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMVN1ToMVN6) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset2::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto mvn = std::make_shared<ngraph::op::v0::MVN>(data, false, true, 1e-5);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
|
||||||
|
auto mvn = std::make_shared<ngraph::op::v6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, false, false, false, false);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMVN1ToMVN6_across_channels) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset2::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto mvn = std::make_shared<ngraph::op::v0::MVN>(data, true, true, 1e-5);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 1, 2, 3 });
|
||||||
|
auto mvn = std::make_shared<ngraph::op::v6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, false, false, false, false);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertMVN1ToMVN6_5D) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset2::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4, 5 });
|
||||||
|
auto mvn = std::make_shared<ngraph::op::v0::MVN>(data, false, true, 1e-5);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4, 5 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 2, 3, 4 });
|
||||||
|
auto mvn = std::make_shared<ngraph::op::v6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, false, false, false, false);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
@ -0,0 +1,124 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <transformations/op_conversions/mvn6_decomposition.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
TEST(TransformationTests, MVN6Decomposition_No_Variance) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
|
||||||
|
auto mvn = std::make_shared<ngraph::opset6::MVN>(data, axes_const, false, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::MVN6Decomposition>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
|
||||||
|
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
|
||||||
|
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mean_normalization }, ngraph::ParameterVector{ input0 });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, false, false, false, false);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, MVN6Decomposition_Inside_Sqrt) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
|
||||||
|
auto mvn = std::make_shared<ngraph::opset6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::MVN6Decomposition>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
|
||||||
|
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
|
||||||
|
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
|
||||||
|
|
||||||
|
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
|
||||||
|
auto sum = std::make_shared<ngraph::opset6::ReduceSum>(mul, axes_const, true);
|
||||||
|
|
||||||
|
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });
|
||||||
|
|
||||||
|
auto eps_add = std::make_shared<ngraph::opset6::Add>(sum, eps_node);
|
||||||
|
auto sqrt = std::make_shared<ngraph::opset6::Sqrt>(eps_add);
|
||||||
|
auto div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input0 });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, false, false, false, false);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, MVN6Decomposition_Outside_Sqrt) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
|
||||||
|
auto mvn = std::make_shared<ngraph::opset6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::MVN6Decomposition>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
|
||||||
|
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
|
||||||
|
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
|
||||||
|
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
|
||||||
|
|
||||||
|
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
|
||||||
|
auto sum = std::make_shared<ngraph::opset6::ReduceSum>(mul, axes_const, true);
|
||||||
|
|
||||||
|
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });
|
||||||
|
|
||||||
|
auto sqrt = std::make_shared<ngraph::opset6::Sqrt>(sum);
|
||||||
|
auto eps_add = std::make_shared<ngraph::opset6::Add>(sqrt, eps_node);
|
||||||
|
auto div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input0 });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, false, false, false, false);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user