[GNA] Add MVN decomposition (#8142)

Szymon Irzabek 2021-11-22 12:39:25 +01:00 committed by GitHub
parent a4854fa61d
commit f6a4dcb5ac
5 changed files with 703 additions and 0 deletions

@@ -70,6 +70,8 @@
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
#include "transformations/remove_single_input_concat.hpp"
#include "transformations/broadcast_const.hpp"
#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp"
#include "transformations/decompose_mvn.hpp"
#include "transformations/substitute_softsign.hpp"
#include <ngraph/opsets/opset7.hpp>
@@ -687,6 +689,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
fake_quantized = ngraph::op::util::has_op_with_type<ngraph::opset7::FakeQuantize>(graph);
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
manager.register_pass<DecomposeMVN>();
manager.register_pass<ngraph::pass::CommonOptimizations>();
manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
manager.register_pass<ConvertDWSCToScaleShifts>();
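
The two new passes run before CommonOptimizations, presumably so that any MVN node is first upgraded from MVN-1 to MVN-6 and then decomposed before the generic optimizations see it. A minimal sketch of applying the same pass ordering to a standalone ngraph::Function, mirroring the unit test at the end of this commit (the helper name RunMvnDecomposition is illustrative and not part of the change):

#include <ngraph/pass/manager.hpp>
#include "transformations/init_node_info.hpp"
#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp"
#include "transformations/decompose_mvn.hpp"

// Sketch: run the MVN decomposition on an arbitrary function outside GNAPlugin::LoadNetwork.
static void RunMvnDecomposition(std::shared_ptr<ngraph::Function> func) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    // MVN-1 is converted to MVN-6 first, so DecomposeMVN only has to match opset8::MVN.
    manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
    manager.register_pass<GNAPluginNS::DecomposeMVN>();
    manager.run_passes(func);
}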

@@ -0,0 +1,265 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/cc/ngraph/itt.hpp>
#include "transformations/decompose_mvn.hpp"
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <transformations/utils/utils.hpp>
#include "backend/gna_limitations.hpp"
using namespace GNAPluginNS;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(DecomposeMVN, "DecomposeMVN", 0);
struct MVNData {
size_t N;
size_t C;
size_t H;
size_t W;
size_t num_parts;
float eps;
op::MVNEpsMode eps_mode;
bool normalize_variance;
element::Type element_type;
std::string name;
};
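// The decomposition supports only MVN reducing over the innermost axes: {2, 3} for 4D inputs
// (or their negative equivalents) and {2} for 3D inputs; only the axes count and the first
// axis value are actually checked below.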
template <class T>
static bool ValidateAxes(const std::shared_ptr<opset8::Constant> axes_const, const size_t& mvn_shape_size) {
T axes_value;
size_t axes_vector_size;
std::vector<T> axes_const_vector = axes_const->cast_vector<T>();
IE_ASSERT(!axes_const_vector.empty());
axes_value = axes_const_vector[0];
axes_vector_size = axes_const_vector.size();
if (axes_vector_size != mvn_shape_size - 2) {
return false;
}
// Verify that the first axis value is supported
if (axes_value != 2 && axes_value != 2 - mvn_shape_size)
return false;
return true;
}
static bool GetVerifiedMVNData(const std::shared_ptr<opset8::MVN> mvn, MVNData& mvn_data) {
const auto mvn_shape = mvn->get_output_shape(0);
auto mvn_shape_size = mvn_shape.size();
// Validate axes parameter
auto axes_const = std::dynamic_pointer_cast<opset8::Constant>(mvn->input_value(1).get_node_shared_ptr());
IE_ASSERT(axes_const);
auto element_type = axes_const->get_element_type();
if (!(element_type == element::Type_t::i64 ? ValidateAxes<int64_t>(axes_const, mvn_shape_size) :
ValidateAxes<int32_t>(axes_const, mvn_shape_size)))
return false;
if (mvn_shape_size == 4) {
mvn_data.N = mvn_shape[0];
mvn_data.C = mvn_shape[1];
mvn_data.H = mvn_shape[2];
mvn_data.W = mvn_shape[3];
} else if (mvn_shape_size == 3) {
mvn_data.N = 1;
mvn_data.C = mvn_shape[0];
mvn_data.H = mvn_shape[1];
mvn_data.W = mvn_shape[2];
}
// Check if average must be split
mvn_data.num_parts = 1;
while (mvn_data.W / mvn_data.num_parts > GNALimitations::convFilterMaxSize) {
mvn_data.num_parts *= 2;
}
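// Splitting keeps each averaging filter at W / num_parts taps, so it fits within the GNA
// convolution filter size limit (GNALimitations::convFilterMaxSize); the partial averages
// are summed back together by the second convolution below.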
// Abort if W is not evenly divisible into num_parts parts
if ((mvn_data.W / mvn_data.num_parts) * mvn_data.num_parts != mvn_data.W) {
return false;
}
mvn_data.eps = mvn->get_eps();
mvn_data.eps_mode = mvn->get_eps_mode();
mvn_data.normalize_variance = mvn->get_normalize_variance();
mvn_data.element_type = mvn->get_element_type();
mvn_data.name = mvn->get_friendly_name();
return true;
}
static std::shared_ptr<Node> NormalizeVariance(const std::shared_ptr<opset8::MVN> mvn, const MVNData& mvn_data,
const std::shared_ptr<opset8::Add>& subtract_mean, const std::shared_ptr<opset8::Constant>& avg_broadcast_const) {
// Prepare consts
auto combined_C_H = mvn_data.C * mvn_data.H;
std::vector<float> avg_weights(8 * mvn_data.W / mvn_data.num_parts, 1.0f / mvn_data.W);
auto avg_weights_const = opset8::Constant::create(mvn_data.element_type, Shape{8, mvn_data.W / mvn_data.num_parts, 1, 1}, avg_weights);
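// Note: unlike the mean branch, these averaging filters use +1.0 / W, since the averaged
// squared differences directly form the variance fed into the 1/sqrt computation below.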
std::vector<float> eps_tensor(combined_C_H * mvn_data.W, mvn_data.eps);
auto eps_tensor_const = opset8::Constant::create(mvn_data.element_type, Shape{1, combined_C_H * mvn_data.W}, eps_tensor);
std::vector<float> minus_half(combined_C_H * mvn_data.W, -0.5f);
auto minus_half_const = opset8::Constant::create(mvn_data.element_type, Shape{1, combined_C_H * mvn_data.W}, minus_half);
// Calculate square of the difference between input and its mean
auto squared_diff = std::make_shared<opset8::Multiply>(subtract_mean, subtract_mean);
squared_diff->set_friendly_name(mvn_data.name + "_SqrDiff");
// Calculate sum of the squares
auto squared_diff_reshape = std::make_shared<opset8::Reshape>(squared_diff,
opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false);
auto transposed_input_3 = std::make_shared<opset8::Transpose>(squared_diff_reshape, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_3 = std::make_shared<opset8::Convolution>(transposed_input_3, avg_weights_const,
Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID);
transposed_avg_conv_3->set_friendly_name(mvn_data.name + "_Avg3");
auto avg_conv_3 = std::make_shared<opset8::Transpose>(transposed_avg_conv_3, opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1}));
auto reshape_avg_conv_3 = std::make_shared<opset8::Reshape>(avg_conv_3,
opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false);
auto transposed_input_4 = std::make_shared<opset8::Transpose>(reshape_avg_conv_3, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_4 = std::make_shared<opset8::Convolution>(transposed_input_4,
avg_broadcast_const, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID);
transposed_avg_conv_4->set_friendly_name(mvn_data.name + "_Avg4");
auto avg_conv_4 = std::make_shared<opset8::Transpose>(transposed_avg_conv_4,
opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1}));
auto reshape_avg_conv_4 = std::make_shared<opset8::Reshape>(avg_conv_4,
opset8::Constant::create(element::i32, Shape{2}, Shape{1ull, combined_C_H * mvn_data.W}), false);
std::shared_ptr<Node> inv_stdev;
// Create normalization part of the graph
// We ignore the inside/outside epsilon mode here and always apply epsilon inside the square root
// to get better accuracy, even though the built-in MVN-1 to MVN-6 transformation enforces the outside setting
// Add epsilon inside the square root
auto add_epsilon = std::make_shared<opset8::Add>(eps_tensor_const, reshape_avg_conv_4);
// Calculate square root and inversion
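// 1 / sqrt(variance + eps) is computed as exp(-0.5 * ln(variance + eps)), expressing the
// inverse square root purely through element-wise Log, Multiply and Exp operations.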
auto log_var_eps = std::make_shared<opset8::Log>(add_epsilon);
log_var_eps->set_friendly_name(mvn_data.name + "_LogVarEps");
auto log_inv_stdev = std::make_shared<opset8::Multiply>(log_var_eps, minus_half_const);
log_inv_stdev->set_friendly_name(mvn_data.name + "_LogInvStdev");
inv_stdev = std::make_shared<opset8::Exp>(log_inv_stdev);
inv_stdev->set_friendly_name(mvn_data.name + "_InvStdev");
copy_runtime_info(mvn, {add_epsilon, log_var_eps, log_inv_stdev, inv_stdev});
auto normalized_output = std::make_shared<opset8::Multiply>(subtract_mean, inv_stdev);
normalized_output->set_friendly_name(mvn_data.name + "_Output");
copy_runtime_info(mvn, {squared_diff, squared_diff_reshape, transposed_input_3, transposed_avg_conv_3, avg_conv_3, reshape_avg_conv_3,
transposed_input_4, transposed_avg_conv_4, avg_conv_4, reshape_avg_conv_4});
return normalized_output;
}
static void Decompose(const std::shared_ptr<opset8::MVN> mvn, const MVNData& mvn_data) {
// Prepare data
auto combined_C_H = mvn_data.C * mvn_data.H;
std::vector<float> neg_avg_weights(8 * mvn_data.W / mvn_data.num_parts, -1.0f / mvn_data.W);
auto neg_avg_weights_const = opset8::Constant::create(mvn_data.element_type, Shape{8, mvn_data.W / mvn_data.num_parts, 1, 1}, neg_avg_weights);
std::vector<float> avg_broadcast(8 * mvn_data.W * mvn_data.num_parts, 0.0f);
for (size_t i = 0; i < mvn_data.W * mvn_data.num_parts; i++) {
avg_broadcast[i * 8] = 1.0f;
}
auto avg_broadcast_const = opset8::Constant::create(mvn_data.element_type, Shape{mvn_data.W, 8 * mvn_data.num_parts, 1, 1}, avg_broadcast);
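// Each of the W output channels of avg_broadcast_const has a 1 at the first of every group of
// 8 input channels, so the second convolution sums the num_parts partial averages and
// broadcasts the resulting (negated) mean to every position in the row.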
// Create average calculation part of the graph
// We assume C = 1 case (combined channels)
const auto input = mvn->input_value(0);
auto reshape = std::make_shared<opset8::Reshape>(input,
opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, 1ull, combined_C_H, mvn_data.W}), false);
auto input_4d = std::make_shared<opset8::Reshape>(reshape,
opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false);
auto input_2d = std::make_shared<opset8::Reshape>(reshape,
opset8::Constant::create(element::i32, Shape{2}, Shape{1ull, combined_C_H * mvn_data.W}), false);
auto transposed_input_1 = std::make_shared<opset8::Transpose>(input_4d, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_1 = std::make_shared<opset8::Convolution>(transposed_input_1, neg_avg_weights_const,
Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID);
transposed_avg_conv_1->set_friendly_name(mvn_data.name + "_Avg1");
auto avg_conv_1 = std::make_shared<opset8::Transpose>(transposed_avg_conv_1, opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1}));
auto reshape_avg_conv_1 = std::make_shared<opset8::Reshape>(avg_conv_1,
opset8::Constant::create(element::i32, Shape{4}, Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false);
auto transposed_input_2 = std::make_shared<opset8::Transpose>(reshape_avg_conv_1, opset8::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_2 = std::make_shared<opset8::Convolution>(transposed_input_2,
avg_broadcast_const, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::VALID);
transposed_avg_conv_2->set_friendly_name(mvn_data.name + "_Avg2");
auto avg_conv_2 = std::make_shared<opset8::Transpose>(transposed_avg_conv_2,
opset8::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1}));
auto avg_conv_2_2d = std::make_shared<opset8::Reshape>(avg_conv_2,
opset8::Constant::create(element::i32, Shape{2}, Shape{1ull, combined_C_H * mvn_data.W}), false);
auto subtract_mean = std::make_shared<opset8::Add>(input_2d, avg_conv_2_2d);
subtract_mean->set_friendly_name(mvn_data.name + "_SubMean");
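// Shape walk-through for the 1x1x5x300 test case (num_parts = 1): input_4d {1,5,1,300} ->
// Transpose {1,300,5,1} -> Convolution (8 filters of -1/300) {1,8,5,1} -> Transpose/Reshape {1,1,5,8} ->
// Transpose {1,8,1,5} -> Convolution (300 broadcast filters) {1,300,1,5} -> Transpose/Reshape {1,1500};
// adding this to input_2d {1,1500} subtracts the computed average from every input element.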
std::shared_ptr<Node> mvn_output, pre_output = subtract_mean;
// Normalize variance if required
if (mvn_data.normalize_variance) {
pre_output = NormalizeVariance(mvn, mvn_data, subtract_mean, avg_broadcast_const);
}
// Reshape (combined channels) back to get the final output
if (mvn->get_output_shape(0).size() == 3) {
mvn_output = std::make_shared<opset8::Reshape>(pre_output,
opset8::Constant::create(element::i32, Shape{3}, {mvn_data.C, mvn_data.H, mvn_data.W}), false);
} else {
mvn_output = std::make_shared<opset8::Reshape>(pre_output,
opset8::Constant::create(element::i32, Shape{4}, {mvn_data.N, mvn_data.C, mvn_data.H, mvn_data.W}), false);
}
copy_runtime_info(mvn, {reshape, input_4d, input_2d, transposed_input_1, transposed_avg_conv_1, avg_conv_1, reshape_avg_conv_1,
transposed_input_2, transposed_avg_conv_2, avg_conv_2, avg_conv_2_2d, subtract_mean, mvn_output});
// We need to retain the MVN layer name so that its output can be used as a network result
replace_node(mvn, mvn_output);
mvn_output->set_friendly_name(mvn_data.name);
}
static bool Convert(std::shared_ptr<Node> mvn_node) {
const auto mvn = std::dynamic_pointer_cast<opset8::MVN>(mvn_node);
MVNData mvn_data;
if (!GetVerifiedMVNData(mvn, mvn_data))
return false;
Decompose(mvn, mvn_data);
return true;
}
static std::function<bool(Output<Node>)> verify_rank_batch() {
return [=](Output<Node> output) -> bool {
// Only ranks 3 and 4 with batch 1 are supported for now
auto rank = output.get_partial_shape().rank();
if (rank != 3 && rank != 4)
return false;
auto batch = (rank == 3 ? 1 : output.get_partial_shape()[0]);
if (batch != 1)
return false;
return true;
};
}
DecomposeMVN::DecomposeMVN() {
MATCHER_SCOPE(DecomposeMVN);
auto axes = pattern::wrap_type<opset8::Constant>();
auto mvn = pattern::wrap_type<opset8::MVN>({pattern::any_input(), axes}, verify_rank_batch());
matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
return Convert(pattern_map.at(mvn).get_node_shared_ptr());
};
auto m = std::make_shared<pattern::Matcher>(mvn, matcher_name);
this->register_matcher(m, callback);
}

@@ -0,0 +1,24 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace GNAPluginNS {
/**
* @brief Decompose MVN operation
* See official OpenVINO documentation for the MVN formula
* implemented partially by this decomposition:
* https://docs.openvino.ai/latest/openvino_docs_ops_normalization_MVN_6.html
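* For reference, MVN-6 with normalize_variance = true and eps_mode = INSIDE_SQRT computes
* y = (x - mean(x, axes)) / sqrt(variance(x, axes) + eps); with normalize_variance = false
* only the mean subtraction is applied.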
*
*/
class DecomposeMVN : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
DecomposeMVN();
};
} // namespace GNAPluginNS

@@ -0,0 +1,157 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/test_common.hpp"
#include <string>
#include <sstream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include "transformations/init_node_info.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
using namespace ngraph;
using namespace opset8;
namespace LayerTestsDefinitions {
typedef std::tuple<
bool, // Normalize variance
float, // Epsilon
op::MVNEpsMode, // Epsilon mode
bool, // Across channels
bool // MVN version, true = v6, false = v1
> mvnSpecificParams;
typedef std::tuple<
mvnSpecificParams, // MVN parameters
InferenceEngine::Precision, // Network Precision
std::string, // Target Device
std::map<std::string, std::string>, // Configuration
InferenceEngine::SizeVector // Input shapes
> decomposeMVNParams;
class DecomposeMVNTest : public testing::WithParamInterface<decomposeMVNParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<decomposeMVNParams> obj) {
mvnSpecificParams mvnParams;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::map<std::string, std::string> configuration;
InferenceEngine::SizeVector inputShape;
std::tie(mvnParams, netPrecision, targetDevice, configuration, inputShape) = obj.param;
float eps;
op::MVNEpsMode epsMode;
bool normalizeVariance, acrossChannels, mvnVersion6;
std::tie(normalizeVariance, eps, epsMode, acrossChannels, mvnVersion6) = mvnParams;
std::ostringstream result;
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
result << "NV=" << normalizeVariance << "_";
result << "eps=" << eps << "_";
result << "mode=" << static_cast<uint32_t>(epsMode) << "_";
result << "AC=" << acrossChannels << "_";
result << "version=" << mvnVersion6 << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
for (auto const& configItem : configuration) {
result << "_configItem=" << configItem.first << "_" << configItem.second;
}
return result.str();
}
protected:
void SetUp() override {
threshold = 0.2f;
mvnSpecificParams mvnParams;
InferenceEngine::Precision netPrecision;
InferenceEngine::SizeVector inputShape;
std::tie(mvnParams, netPrecision, targetDevice, configuration, inputShape) = this->GetParam();
float eps;
op::MVNEpsMode epsMode;
bool normalizeVariance, acrossChannels, mvnVersion6;
std::tie(normalizeVariance, eps, epsMode, acrossChannels, mvnVersion6) = mvnParams;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto input = builder::makeParams(ngPrc, {inputShape});
InferenceEngine::SizeVector axes(inputShape.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
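// Reduce over the innermost dimensions only: axes = {2, 3} for 4D input shapes, {2} for 3D ones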
std::shared_ptr<ngraph::Node> mvn;
if (mvnVersion6) {
const auto axesConst = std::make_shared<op::v0::Constant>(element::i64, Shape{axes.size()}, axes);
mvn = std::make_shared<opset8::MVN>(input[0], axesConst, normalizeVariance, eps, epsMode);
} else {
mvn = std::make_shared<opset2::MVN>(input[0], acrossChannels, normalizeVariance);
}
auto result = std::make_shared<Result>(mvn);
function = std::make_shared<Function>(ResultVector{result}, ParameterVector{input});
}
};
TEST_P(DecomposeMVNTest, CompareWithRefs) {
Run();
}
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
};
const std::vector<std::map<std::string, std::string>> configs = {
{
{"GNA_DEVICE_MODE", "GNA_SW_FP32"},
{"GNA_SCALE_FACTOR_0", "1"}
}
};
const std::vector<std::vector<size_t>> inputs = {{1, 1, 5, 300}, {1, 6, 256}};
const std::vector<bool> normalizeVariance = {true};
const std::vector<float> eps = {1.0e-09f};
const std::vector<op::MVNEpsMode> epsMode = {op::MVNEpsMode::INSIDE_SQRT};
const std::vector<bool> acrossChannels = {false};
const auto mvnParams_v6 = ::testing::Combine(
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(eps),
::testing::ValuesIn(epsMode),
::testing::Values(false),
::testing::Values(true)
);
const auto mvnParams_v1 = ::testing::Combine(
::testing::ValuesIn(normalizeVariance),
::testing::ValuesIn(eps),
::testing::ValuesIn(epsMode),
::testing::ValuesIn(acrossChannels),
::testing::Values(false)
);
INSTANTIATE_TEST_SUITE_P(smoke_DecomposeMVN_v6, DecomposeMVNTest,
::testing::Combine(
mvnParams_v6,
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(inputs)),
DecomposeMVNTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_DecomposeMVN_v1, DecomposeMVNTest,
::testing::Combine(
mvnParams_v1,
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(inputs)),
DecomposeMVNTest::getTestCaseName);
} // namespace LayerTestsDefinitions

@@ -0,0 +1,253 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <tuple>
#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp"
#include "transformations/decompose_mvn.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/opsets/opset2.hpp>
#include "backend/gna_limitations.hpp"
namespace decomposeMVN {
typedef std::tuple<
ngraph::Shape, // Input shape
bool, // Normalize variance
float, // Epsilon
ngraph::op::MVNEpsMode, // Epsilon mode
InferenceEngine::SizeVector, // Axes tensor
bool, // Across channels
bool // MVN version, true = v6, false = v1
> decomposeMVNParams;
struct MVNParams {
size_t N;
size_t C;
size_t H;
size_t W;
size_t num_parts;
float eps;
ngraph::op::MVNEpsMode eps_mode;
bool normalize_variance;
};
static std::shared_ptr<ngraph::Node> NormalizeVariance(const MVNParams& mvn_data, const std::shared_ptr<ngraph::opset8::Add>& subtract_mean,
const std::shared_ptr<ngraph::opset8::Constant>& avg_broadcast_const) {
// Prepare consts
auto combined_C_H = mvn_data.C * mvn_data.H;
std::vector<float> avg_weights(8 * mvn_data.W / mvn_data.num_parts, 1.0f / mvn_data.W);
auto avg_weights_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{8, mvn_data.W / mvn_data.num_parts, 1, 1}, avg_weights);
std::vector<float> eps_tensor(combined_C_H * mvn_data.W, mvn_data.eps);
auto eps_tensor_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1, combined_C_H * mvn_data.W}, eps_tensor);
std::vector<float> minus_half(combined_C_H * mvn_data.W, -0.5f);
auto minus_half_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{1, combined_C_H * mvn_data.W}, minus_half);
// Calculate square of the difference between input and its mean
auto squared_diff = std::make_shared<ngraph::opset8::Multiply>(subtract_mean, subtract_mean);
squared_diff->set_friendly_name("MvnSqrDiff");
// Calculate sum of the squares
auto squared_diff_reshape = std::make_shared<ngraph::opset8::Reshape>(squared_diff,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4},
ngraph::Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false);
auto transposed_input_3 = std::make_shared<ngraph::opset8::Transpose>(squared_diff_reshape,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_3 = std::make_shared<ngraph::opset8::Convolution>(transposed_input_3, avg_weights_const,
ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
transposed_avg_conv_3->set_friendly_name("MvnAvg3");
auto avg_conv_3 = std::make_shared<ngraph::opset8::Transpose>(transposed_avg_conv_3,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1}));
auto reshape_avg_conv_3 = std::make_shared<ngraph::opset8::Reshape>(avg_conv_3,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4},
ngraph::Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false);
auto transposed_input_4 = std::make_shared<ngraph::opset8::Transpose>(reshape_avg_conv_3,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_4 = std::make_shared<ngraph::opset8::Convolution>(transposed_input_4,
avg_broadcast_const, ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0},
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
transposed_avg_conv_4->set_friendly_name("MvnAvg4");
auto avg_conv_4 = std::make_shared<ngraph::opset8::Transpose>(transposed_avg_conv_4,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1}));
auto reshape_avg_conv_4 = std::make_shared<ngraph::opset8::Reshape>(avg_conv_4,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{2}, ngraph::Shape{1ull, combined_C_H * mvn_data.W}), false);
std::shared_ptr<ngraph::Node> inv_stdev;
// Create normalization part of the graph
// We ignore the inside/outside epsilon mode here and always apply epsilon inside the square root
// to get better accuracy, even though the built-in MVN-1 to MVN-6 transformation enforces the outside setting
// Add epsilon inside the square root
auto add_epsilon = std::make_shared<ngraph::opset8::Add>(eps_tensor_const, reshape_avg_conv_4);
// Calculate square root and inversion
auto log_var_eps = std::make_shared<ngraph::opset8::Log>(add_epsilon);
log_var_eps->set_friendly_name("MvnLogVarEps");
auto log_inv_stdev = std::make_shared<ngraph::opset8::Multiply>(log_var_eps, minus_half_const);
log_inv_stdev->set_friendly_name("MvnLogInvStdev");
inv_stdev = std::make_shared<ngraph::opset8::Exp>(log_inv_stdev);
inv_stdev->set_friendly_name("MvnInvStdev");
auto normalized_output = std::make_shared<ngraph::opset8::Multiply>(subtract_mean, inv_stdev);
normalized_output->set_friendly_name("MvnOutput");
return normalized_output;
}
static std::shared_ptr<ngraph::opset8::Result> Decompose(const std::shared_ptr<ngraph::Node> input_node, const MVNParams& mvn_data) {
// Prepare data
auto combined_C_H = mvn_data.C * mvn_data.H;
std::vector<float> neg_avg_weights(8 * mvn_data.W / mvn_data.num_parts, -1.0f / mvn_data.W);
auto neg_avg_weights_const = ngraph::opset8::Constant::create(ngraph::element::i32,
ngraph::Shape{8, mvn_data.W / mvn_data.num_parts, 1, 1}, neg_avg_weights);
std::vector<float> avg_broadcast(8 * mvn_data.W * mvn_data.num_parts, 0.0f);
for (size_t i = 0; i < mvn_data.W * mvn_data.num_parts; i++) {
avg_broadcast[i * 8] = 1.0f;
}
auto avg_broadcast_const = ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{mvn_data.W, 8 * mvn_data.num_parts, 1, 1}, avg_broadcast);
// Create average calculation part of the graph
// We assume C = 1 case (combined channels)
auto reshape = std::make_shared<ngraph::opset8::Reshape>(input_node,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4},
ngraph::Shape{mvn_data.N, 1ull, combined_C_H, mvn_data.W}), false);
auto input_4d = std::make_shared<ngraph::opset8::Reshape>(reshape,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4},
ngraph::Shape{mvn_data.N, combined_C_H * mvn_data.num_parts, 1ull, mvn_data.W / mvn_data.num_parts}), false);
auto input_2d = std::make_shared<ngraph::opset8::Reshape>(reshape,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{2},
ngraph::Shape{1ull, combined_C_H * mvn_data.W}), false);
auto transposed_input_1 = std::make_shared<ngraph::opset8::Transpose>(input_4d,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_1 = std::make_shared<ngraph::opset8::Convolution>(transposed_input_1, neg_avg_weights_const,
ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
transposed_avg_conv_1->set_friendly_name("MvnAvg1");
auto avg_conv_1 = std::make_shared<ngraph::opset8::Transpose>(transposed_avg_conv_1,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1}));
auto reshape_avg_conv_1 = std::make_shared<ngraph::opset8::Reshape>(avg_conv_1,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4},
ngraph::Shape{mvn_data.N, 1ull, combined_C_H, 8 * mvn_data.num_parts}), false);
auto transposed_input_2 = std::make_shared<ngraph::opset8::Transpose>(reshape_avg_conv_1,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 3, 1, 2}));
auto transposed_avg_conv_2 = std::make_shared<ngraph::opset8::Convolution>(transposed_input_2,
avg_broadcast_const, ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0},
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
transposed_avg_conv_2->set_friendly_name("MvnAvg2");
auto avg_conv_2 = std::make_shared<ngraph::opset8::Transpose>(transposed_avg_conv_2,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {0, 2, 3, 1}));
auto avg_conv_2_2d = std::make_shared<ngraph::opset8::Reshape>(avg_conv_2,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{2}, ngraph::Shape{1ull, combined_C_H * mvn_data.W}), false);
auto subtract_mean = std::make_shared<ngraph::opset8::Add>(input_2d, avg_conv_2_2d);
subtract_mean->set_friendly_name("MvnSubMean");
std::shared_ptr<ngraph::Node> mvn_output, pre_output = subtract_mean;
// Normalize variance if required
if (mvn_data.normalize_variance) {
pre_output = NormalizeVariance(mvn_data, subtract_mean, avg_broadcast_const);
}
// Reshape (combined channels) back to get the final output
if (input_node->get_output_shape(0).size() == 3) {
mvn_output = std::make_shared<ngraph::opset8::Reshape>(pre_output,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{3}, {mvn_data.C, mvn_data.H, mvn_data.W}), false);
} else {
mvn_output = std::make_shared<ngraph::opset8::Reshape>(pre_output,
ngraph::opset8::Constant::create(ngraph::element::i32, ngraph::Shape{4}, {mvn_data.N, mvn_data.C, mvn_data.H, mvn_data.W}), false);
}
return std::make_shared<ngraph::opset8::Result>(mvn_output);
}
std::shared_ptr<ngraph::Function> getReferenceFunction(const ngraph::Shape& input_shape, const bool& normalize_variance,
const float& eps, const ngraph::op::MVNEpsMode& eps_mode, const InferenceEngine::SizeVector& axes) {
MVNParams mvn_data;
auto mvn_shape_size = input_shape.size();
if (mvn_shape_size == 4) {
mvn_data.N = input_shape[0];
mvn_data.C = input_shape[1];
mvn_data.H = input_shape[2];
mvn_data.W = input_shape[3];
} else if (mvn_shape_size == 3) {
mvn_data.N = 1;
mvn_data.C = input_shape[0];
mvn_data.H = input_shape[1];
mvn_data.W = input_shape[2];
}
mvn_data.eps = eps;
mvn_data.eps_mode = eps_mode;
mvn_data.normalize_variance = normalize_variance;
mvn_data.num_parts = 1;
while (mvn_data.W / mvn_data.num_parts > GNAPluginNS::GNALimitations::convFilterMaxSize) {
mvn_data.num_parts *= 2;
}
// Create decomposed reference function
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i32, input_shape);
std::shared_ptr<ngraph::opset8::Result> result = Decompose(input_params, mvn_data);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}
std::shared_ptr<ngraph::Function> getInitialFunction(const ngraph::Shape& input_shape, const bool& normalize_variance,
const float& eps, const ngraph::op::MVNEpsMode& eps_mode, const InferenceEngine::SizeVector& axes,
const bool& across_channels, const bool& mvn_version_6) {
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i32, input_shape);
std::shared_ptr<ngraph::Node> mvn;
if (mvn_version_6) {
const auto axesConst = std::make_shared<ngraph::opset8::Constant>(ngraph::element::i32, ngraph::Shape{axes.size()}, axes);
mvn = std::make_shared<ngraph::opset8::MVN>(input_params, axesConst, normalize_variance, eps, eps_mode);
} else {
mvn = std::make_shared<ngraph::opset2::MVN>(input_params, across_channels, normalize_variance, eps);
}
auto result = std::make_shared<ngraph::opset8::Result>(mvn);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}
} // namespace decomposeMVN
// ---------------------------------------------------------------------------------------------------------------------
namespace {
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertMVN1ToMVN6>();
manager.register_pass<GNAPluginNS::DecomposeMVN>();
manager.run_passes(function);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
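// Enabling ATTRIBUTES makes the comparator also check node attributes, not just graph topology and precisions.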
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid);
}
} // namespace
TEST(TransformationTests, DecomposeMVNTest) {
for (auto mvn_version_6 : {true, false}) {
for (auto normalize_variance : {true, false}) {
execute_test(decomposeMVN::getInitialFunction(ngraph::Shape{1, 1, 5, 300}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT,
InferenceEngine::SizeVector{2, 1}, false, mvn_version_6),
decomposeMVN::getReferenceFunction(ngraph::Shape{1, 1, 5, 300}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT,
InferenceEngine::SizeVector{2, 1}));
execute_test(decomposeMVN::getInitialFunction(ngraph::Shape{1, 6, 256}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT,
InferenceEngine::SizeVector{2}, false, mvn_version_6),
decomposeMVN::getReferenceFunction(ngraph::Shape{1, 6, 256}, normalize_variance, 1.0e-09f, ngraph::op::MVNEpsMode::INSIDE_SQRT,
InferenceEngine::SizeVector{2}));
}
}
}