Add MVN-6 related transformations (#3710)

* Add MVN decomposition transformation

* Add MVN-1 to MVN-6 transformation

* Apply review feedback

* Apply review feedback

* Fix build

* Fix if statement and add 5D tests

* Apply review feedback

* Apply review feedback

* Apply feedback

* Revert "Apply feedback"

This reverts commit 039fefbff9.

* Apply review feedback

* Apply review feedback

* Fix build issue

* Apply review feedback

* Apply review feedback

* Apply feedback
This commit is contained in:
Maxim Vafin 2021-01-18 10:34:01 +03:00 committed by GitHub
parent d462626826
commit af5eccc6ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 403 additions and 0 deletions

View File

@ -0,0 +1,27 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertMVN1ToMVN6;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief ConvertMVN1ToMVN6 covert v0:MVN into v6::MVN.
*/
class ngraph::pass::ConvertMVN1ToMVN6 : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertMVN1ToMVN6();
};

View File

@ -0,0 +1,27 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API MVN6Decomposition;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief MVN6Decomposition transformation into sub-graph x - ReduceMean(x, axes) if normalize_variance is false and
* into sub-graph (x - ReduceMean(x, axes)) / Sqrt(ReduceSum((x - ReduceMean(x, axes)) ^ 2)) if normalize_variance is true.
*/
class ngraph::pass::MVN6Decomposition : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MVN6Decomposition();
};

View File

@ -0,0 +1,52 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp"
#include <ngraph/rt_info.hpp>
#include <numeric>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMVN1ToMVN6, "ConvertMVN1ToMVN6", 0);
ngraph::pass::ConvertMVN1ToMVN6::ConvertMVN1ToMVN6() {
auto mvn = pattern::wrap_type<ngraph::opset2::MVN>();
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto mvn_node = std::dynamic_pointer_cast<ngraph::opset2::MVN>(m.get_match_root());
if (!mvn_node) {
return false;
}
const auto input = mvn_node->input_value(0);
auto input_rank = input.get_partial_shape().rank();
if (!input_rank.is_static()) {
return false;
}
int64_t start_axis = 1 + (!mvn_node->get_across_channels());
if (input_rank.get_length() <= start_axis) {
return false;
}
std::vector<int64_t> axes_v(input_rank.get_length() - start_axis);
std::iota(axes_v.begin(), axes_v.end(), start_axis);
auto axes = opset6::Constant::create(ngraph::element::i64, { axes_v.size() }, axes_v);
auto mvn6_node = std::make_shared<ngraph::opset6::MVN>(input,
axes,
mvn_node->get_normalize_variance(),
mvn_node->get_eps(),
ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
mvn6_node->set_friendly_name(mvn_node->get_friendly_name());
ngraph::copy_runtime_info(mvn_node, mvn6_node);
ngraph::replace_node(mvn_node, mvn6_node);
return true;
};
auto m = std::make_shared<pattern::Matcher>(mvn, "ConvertMVN1ToMVN6");
register_matcher(m, callback);
}

View File

@ -0,0 +1,72 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/mvn6_decomposition.hpp"
#include <memory>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVN6Decomposition, "MVN6Decomposition", 0);
ngraph::pass::MVN6Decomposition::MVN6Decomposition() {
// Decomposes MVN(x, axes) op if normalize_variance is false into sub-graph
// x - ReduceMean(x, axes), if normalize_variance is true into sub-graph
// (x - ReduceMean(x, axes)) / Sqrt(ReduceSum((x - ReduceMean(x, axes)) ^ 2))
auto mvn = ngraph::pattern::wrap_type<opset6::MVN>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto mvn_node = std::dynamic_pointer_cast<ngraph::opset6::MVN>(pattern_to_output.at(mvn).get_node_shared_ptr());
if (mvn_node == nullptr || transformation_callback(mvn_node)) {
return false;
}
const auto data = mvn_node->input_value(0);
const auto axes = mvn_node->input_value(1);
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(data, axes, true);
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(data, mean);
if (!mvn_node->get_normalize_variance()) {
mean_normalization->set_friendly_name(mvn_node->get_friendly_name());
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization });
ngraph::replace_node(mvn_node, mean_normalization);
} else {
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
auto sum = std::make_shared<ngraph::opset6::ReduceSum>(mul, axes, true);
auto eps = mvn_node->get_eps();
auto eps_node = ngraph::opset6::Constant::create(data.get_element_type(), ngraph::Shape{ 1 }, { eps });
auto eps_mode = mvn_node->get_eps_mode();
std::shared_ptr<ngraph::opset6::Add> eps_add;
std::shared_ptr<ngraph::opset6::Sqrt> sqrt;
std::shared_ptr<ngraph::opset6::Divide> div;
if (eps_mode == op::MVNEpsMode::INSIDE_SQRT) {
eps_add = std::make_shared<ngraph::opset6::Add>(sum, eps_node);
sqrt = std::make_shared<ngraph::opset6::Sqrt>(eps_add);
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
} else if (eps_mode == op::MVNEpsMode::OUTSIDE_SQRT) {
sqrt = std::make_shared<ngraph::opset6::Sqrt>(sum);
eps_add = std::make_shared<ngraph::opset6::Add>(sqrt, eps_node);
div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
} else {
return false;
}
div->set_friendly_name(mvn_node->get_friendly_name());
ngraph::copy_runtime_info(mvn_node, { mean, mean_normalization, mul, sum, eps_node, eps_add, sqrt, div });
ngraph::replace_node(mvn_node, div);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mvn, "MVN6Decomposition");
register_matcher(m, callback);
}

View File

@ -0,0 +1,101 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_mvn1_to_mvn6.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertMVN1ToMVN6) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset2::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto mvn = std::make_shared<ngraph::op::v0::MVN>(data, false, true, 1e-5);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
auto mvn = std::make_shared<ngraph::op::v6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, false, false, false, false);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertMVN1ToMVN6_across_channels) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset2::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto mvn = std::make_shared<ngraph::op::v0::MVN>(data, true, true, 1e-5);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::op::v6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, false, false, false, false);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertMVN1ToMVN6_5D) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset2::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4, 5 });
auto mvn = std::make_shared<ngraph::op::v0::MVN>(data, false, true, 1e-5);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4, 5 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 2, 3, 4 });
auto mvn = std::make_shared<ngraph::op::v6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, false, false, false, false);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -0,0 +1,124 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/mvn6_decomposition.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, MVN6Decomposition_No_Variance) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(data, axes_const, false, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVN6Decomposition>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mean_normalization }, ngraph::ParameterVector{ input0 });
}
auto res = compare_functions(f, f_ref, false, false, false, false);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVN6Decomposition_Inside_Sqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::INSIDE_SQRT);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVN6Decomposition>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
auto sum = std::make_shared<ngraph::opset6::ReduceSum>(mul, axes_const, true);
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });
auto eps_add = std::make_shared<ngraph::opset6::Add>(sum, eps_node);
auto sqrt = std::make_shared<ngraph::opset6::Sqrt>(eps_add);
auto div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input0 });
}
auto res = compare_functions(f, f_ref, false, false, false, false);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MVN6Decomposition_Outside_Sqrt) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(data, axes_const, true, 1e-5, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ data });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVN6Decomposition>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 3, 4 });
auto axes_const = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, { 2, 3 });
auto mean = std::make_shared<ngraph::opset6::ReduceMean>(input0, axes_const, true);
auto mean_normalization = std::make_shared<ngraph::opset6::Subtract>(input0, mean);
auto mul = std::make_shared<ngraph::opset6::Multiply>(mean_normalization, mean_normalization);
auto sum = std::make_shared<ngraph::opset6::ReduceSum>(mul, axes_const, true);
auto eps_node = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1e-5 });
auto sqrt = std::make_shared<ngraph::opset6::Sqrt>(sum);
auto eps_add = std::make_shared<ngraph::opset6::Add>(sqrt, eps_node);
auto div = std::make_shared<ngraph::opset6::Divide>(mean_normalization, sqrt);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input0 });
}
auto res = compare_functions(f, f_ref, false, false, false, false);
ASSERT_TRUE(res.first) << res.second;
}