Use MVN in GroupNorm/InstanceNorm in ONNX importer (#2711)

* Use MVN in GroupNorm/InstanceNorm in ONNX importer

* Remove mosaic_8 model from xfail list
This commit is contained in:
Mateusz Tabaka 2020-10-21 12:48:53 +02:00 committed by GitHub
parent 458425ac9e
commit d901bbfce3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 58 additions and 44 deletions

View File

@ -60,7 +60,6 @@ void op::MVN::validate_and_infer_types()
if (m_reduction_axes.empty() && input_value(0).get_partial_shape().rank().is_static())
{
AxisSet reduction_axes;
reduction_axes.insert(0);
size_t start_axis = m_across_channels ? 1 : 2;
for (size_t i = start_axis; i < input_value(0).get_partial_shape().rank().get_length(); ++i)
{
@ -90,11 +89,10 @@ OutputVector op::MVN::decompose_op() const
{
// calculate variance
auto variance = builder::opset1::variance(data, m_reduction_axes);
variance = make_shared<op::Sqrt>(variance);
// add epsilon
auto eps_node = op::Constant::create(
data.get_element_type(), Output<Node>(variance).get_shape(), vector<double>{m_eps});
variance = variance + eps_node;
variance = std::make_shared<op::Sqrt>(variance + eps_node);
variance = std::make_shared<op::Broadcast>(variance, data_shape, m_reduction_axes);
return OutputVector{mean_normalization / variance};

View File

@ -93,21 +93,7 @@ namespace ngraph
const auto reduction_axes =
common::get_monotonic_range_along_node_rank(data, 2);
const std::shared_ptr<ngraph::Node> eps_node =
std::make_shared<default_opset::Constant>(
data.get_element_type(), Shape{}, epsilon);
auto mean =
std::make_shared<default_opset::ReduceMean>(data, reduction_axes, true);
auto diff = std::make_shared<default_opset::Subtract>(data, mean);
auto variance = std::make_shared<default_opset::ReduceMean>(
std::make_shared<default_opset::Power>(
diff,
default_opset::Constant::create(data.get_element_type(), Shape{}, {2})),
reduction_axes,
true);
const auto sqrt = std::make_shared<default_opset::Sqrt>(
std::make_shared<default_opset::Add>(variance, eps_node));
auto mvn = std::make_shared<default_opset::MVN>(data, false, true, epsilon);
std::shared_ptr<ngraph::Node> data_shape_node;
if (data_pshape.is_static())
@ -132,10 +118,9 @@ namespace ngraph
data_shape_node,
std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1));
// scale * (data - mean) / sqrt + bias
// scale * mvn + bias
std::shared_ptr<ngraph::Node> result =
std::make_shared<default_opset::Divide>(scale, sqrt);
result = std::make_shared<default_opset::Multiply>(diff, result);
std::make_shared<default_opset::Multiply>(mvn, scale);
result = std::make_shared<default_opset::Add>(result, bias);
return {result};

View File

@ -106,33 +106,19 @@ namespace ngraph
}
auto data_reshaped = std::make_shared<default_opset::Reshape>(
data, detail::create_group_norm_shape(data, num_groups), true);
const auto reduction_axes =
common::get_monotonic_range_along_node_rank(data_reshaped, 2);
auto mean = std::make_shared<default_opset::ReduceMean>(
data_reshaped, reduction_axes, true);
auto diff = std::make_shared<default_opset::Subtract>(data_reshaped, mean);
auto variance = std::make_shared<default_opset::ReduceMean>(
std::make_shared<default_opset::Power>(
diff, default_opset::Constant::create(element::f32, Shape{}, {2})),
reduction_axes,
true);
const std::shared_ptr<ngraph::Node> eps_node =
std::make_shared<default_opset::Constant>(element::f32, Shape{}, eps);
const auto sqrt = std::make_shared<default_opset::Sqrt>(
std::make_shared<default_opset::Add>(variance, eps_node));
auto mvn =
std::make_shared<default_opset::MVN>(data_reshaped, false, true, eps);
std::shared_ptr<ngraph::Node> result =
std::make_shared<default_opset::Reshape>(mvn, data_shape_node, true);
const auto& rank = data.get_partial_shape().rank();
NGRAPH_CHECK(rank.is_static());
auto data_rank_size = rank.get_length();
std::shared_ptr<ngraph::Node> result =
std::make_shared<default_opset::Divide>(diff, sqrt);
result =
std::make_shared<default_opset::Reshape>(result, data_shape_node, true);
result = std::make_shared<default_opset::Multiply>(
reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size),
result);
result,
reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size));
result = std::make_shared<default_opset::Add>(
result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size));

View File

@ -155,7 +155,6 @@ if len(zoo_models) > 0:
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_vgg_model_vgg19_bn_7_vgg19_bn_vgg19_bn_cpu"),
(xfail_issue_36533, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov2_model_tinyyolov2_7_tiny_yolov2_model_cpu"),
(xfail_issue_36533, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov2_model_tinyyolov2_8_tiny_yolov2_Model_cpu"),
(xfail_issue_36533, "test_onnx_model_zoo_vision_style_transfer_fast_neural_style_model_mosaic_8_mosaic_mosaic_cpu"),
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet18_v2_7_resnet18v2_resnet18_v2_7_cpu"),
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet101_v1_7_resnet101v1_resnet101_v1_7_cpu"),
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet152_v1_7_resnet152v1_resnet152_v1_7_cpu"),

View File

@ -1380,6 +1380,52 @@ NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_split_channels)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_shared_across_channel_batch_size_2)
{
Shape data_shape{2, 2, 5};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto mvn_func = make_shared<op::MVN>(data, true);
auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
auto test_case = test::TestCase<TestEngine>(function);
// data
vector<float> data_vector(shape_size(data_shape));
iota(begin(data_vector), end(data_vector), 0);
test_case.add_input<float>(data_vector);
// expected result
test_case.add_expected_output<float>(
data_shape,
{-1.5666989f, -1.2185436f, -0.8703883f, -0.5222329f, -0.1740777f, 0.1740777f, 0.5222329f,
0.8703883f, 1.2185436f, 1.5666989f, -1.5666989f, -1.2185436f, -0.8703883f, -0.5222329f,
-0.1740777f, 0.1740777f, 0.5222329f, 0.8703883f, 1.2185436f, 1.5666989f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_not_shared_across_channel_batch_size_2)
{
Shape data_shape{2, 2, 5};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto mvn_func = make_shared<op::MVN>(data, false);
auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
auto test_case = test::TestCase<TestEngine>(function);
// data
vector<float> data_vector(shape_size(data_shape));
iota(begin(data_vector), end(data_vector), 0);
test_case.add_input<float>(data_vector);
// expected result
test_case.add_expected_output<float>(
data_shape,
{-1.4142135f, -0.7071068f, 0.0000000f, 0.7071068f, 1.4142135f, -1.4142135f, -0.7071068f,
0.0000000f, 0.7071068f, 1.4142135f, -1.4142135f, -0.7071068f, 0.0000000f, 0.7071068f,
1.4142135f, -1.4142135f, -0.7071068f, 0.0000000f, 0.7071068f, 1.4142135f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, grn_4d)
{
const Shape data_shape{1, 2, 3, 4};

View File

@ -34,12 +34,12 @@ TEST(type_prop, mvn_partial)
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
auto mvn_func = make_shared<op::MVN>(data);
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
EXPECT_EQ(mvn_func->get_reduction_axes(), (AxisSet{0, 1, 2}));
EXPECT_EQ(mvn_func->get_reduction_axes(), (AxisSet{1, 2}));
ASSERT_TRUE(mvn_func->get_output_partial_shape(0).same_scheme(
(PartialShape{1, Dimension::dynamic(), 6})));
// across_channels = false
EXPECT_EQ(make_shared<op::MVN>(data, false)->get_reduction_axes(), (AxisSet{0, 2}));
EXPECT_EQ(make_shared<op::MVN>(data, false)->get_reduction_axes(), (AxisSet{2}));
// rank unknown
auto mvn_partial =