Use MVN in GroupNorm/InstanceNorm in ONNX importer (#2711)
* Use MVN in GroupNorm/InstanceNorm in ONNX importer * Remove mosaic_8 model from xfail list
This commit is contained in:
parent
458425ac9e
commit
d901bbfce3
@ -60,7 +60,6 @@ void op::MVN::validate_and_infer_types()
|
||||
if (m_reduction_axes.empty() && input_value(0).get_partial_shape().rank().is_static())
|
||||
{
|
||||
AxisSet reduction_axes;
|
||||
reduction_axes.insert(0);
|
||||
size_t start_axis = m_across_channels ? 1 : 2;
|
||||
for (size_t i = start_axis; i < input_value(0).get_partial_shape().rank().get_length(); ++i)
|
||||
{
|
||||
@ -90,11 +89,10 @@ OutputVector op::MVN::decompose_op() const
|
||||
{
|
||||
// calculate variance
|
||||
auto variance = builder::opset1::variance(data, m_reduction_axes);
|
||||
variance = make_shared<op::Sqrt>(variance);
|
||||
// add epsilon
|
||||
auto eps_node = op::Constant::create(
|
||||
data.get_element_type(), Output<Node>(variance).get_shape(), vector<double>{m_eps});
|
||||
variance = variance + eps_node;
|
||||
variance = std::make_shared<op::Sqrt>(variance + eps_node);
|
||||
variance = std::make_shared<op::Broadcast>(variance, data_shape, m_reduction_axes);
|
||||
|
||||
return OutputVector{mean_normalization / variance};
|
||||
|
@ -93,21 +93,7 @@ namespace ngraph
|
||||
const auto reduction_axes =
|
||||
common::get_monotonic_range_along_node_rank(data, 2);
|
||||
|
||||
const std::shared_ptr<ngraph::Node> eps_node =
|
||||
std::make_shared<default_opset::Constant>(
|
||||
data.get_element_type(), Shape{}, epsilon);
|
||||
|
||||
auto mean =
|
||||
std::make_shared<default_opset::ReduceMean>(data, reduction_axes, true);
|
||||
auto diff = std::make_shared<default_opset::Subtract>(data, mean);
|
||||
auto variance = std::make_shared<default_opset::ReduceMean>(
|
||||
std::make_shared<default_opset::Power>(
|
||||
diff,
|
||||
default_opset::Constant::create(data.get_element_type(), Shape{}, {2})),
|
||||
reduction_axes,
|
||||
true);
|
||||
const auto sqrt = std::make_shared<default_opset::Sqrt>(
|
||||
std::make_shared<default_opset::Add>(variance, eps_node));
|
||||
auto mvn = std::make_shared<default_opset::MVN>(data, false, true, epsilon);
|
||||
|
||||
std::shared_ptr<ngraph::Node> data_shape_node;
|
||||
if (data_pshape.is_static())
|
||||
@ -132,10 +118,9 @@ namespace ngraph
|
||||
data_shape_node,
|
||||
std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1));
|
||||
|
||||
// scale * (data - mean) / sqrt + bias
|
||||
// scale * mvn + bias
|
||||
std::shared_ptr<ngraph::Node> result =
|
||||
std::make_shared<default_opset::Divide>(scale, sqrt);
|
||||
result = std::make_shared<default_opset::Multiply>(diff, result);
|
||||
std::make_shared<default_opset::Multiply>(mvn, scale);
|
||||
result = std::make_shared<default_opset::Add>(result, bias);
|
||||
|
||||
return {result};
|
||||
|
@ -106,33 +106,19 @@ namespace ngraph
|
||||
}
|
||||
auto data_reshaped = std::make_shared<default_opset::Reshape>(
|
||||
data, detail::create_group_norm_shape(data, num_groups), true);
|
||||
const auto reduction_axes =
|
||||
common::get_monotonic_range_along_node_rank(data_reshaped, 2);
|
||||
auto mean = std::make_shared<default_opset::ReduceMean>(
|
||||
data_reshaped, reduction_axes, true);
|
||||
auto diff = std::make_shared<default_opset::Subtract>(data_reshaped, mean);
|
||||
auto variance = std::make_shared<default_opset::ReduceMean>(
|
||||
std::make_shared<default_opset::Power>(
|
||||
diff, default_opset::Constant::create(element::f32, Shape{}, {2})),
|
||||
reduction_axes,
|
||||
true);
|
||||
|
||||
const std::shared_ptr<ngraph::Node> eps_node =
|
||||
std::make_shared<default_opset::Constant>(element::f32, Shape{}, eps);
|
||||
const auto sqrt = std::make_shared<default_opset::Sqrt>(
|
||||
std::make_shared<default_opset::Add>(variance, eps_node));
|
||||
auto mvn =
|
||||
std::make_shared<default_opset::MVN>(data_reshaped, false, true, eps);
|
||||
std::shared_ptr<ngraph::Node> result =
|
||||
std::make_shared<default_opset::Reshape>(mvn, data_shape_node, true);
|
||||
|
||||
const auto& rank = data.get_partial_shape().rank();
|
||||
NGRAPH_CHECK(rank.is_static());
|
||||
auto data_rank_size = rank.get_length();
|
||||
|
||||
std::shared_ptr<ngraph::Node> result =
|
||||
std::make_shared<default_opset::Divide>(diff, sqrt);
|
||||
result =
|
||||
std::make_shared<default_opset::Reshape>(result, data_shape_node, true);
|
||||
result = std::make_shared<default_opset::Multiply>(
|
||||
reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size),
|
||||
result);
|
||||
result,
|
||||
reshape::reshape_channel_shaped_node_to_nchw(scale, data_rank_size));
|
||||
result = std::make_shared<default_opset::Add>(
|
||||
result, reshape::reshape_channel_shaped_node_to_nchw(bias, data_rank_size));
|
||||
|
||||
|
@ -155,7 +155,6 @@ if len(zoo_models) > 0:
|
||||
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_vgg_model_vgg19_bn_7_vgg19_bn_vgg19_bn_cpu"),
|
||||
(xfail_issue_36533, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov2_model_tinyyolov2_7_tiny_yolov2_model_cpu"),
|
||||
(xfail_issue_36533, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov2_model_tinyyolov2_8_tiny_yolov2_Model_cpu"),
|
||||
(xfail_issue_36533, "test_onnx_model_zoo_vision_style_transfer_fast_neural_style_model_mosaic_8_mosaic_mosaic_cpu"),
|
||||
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet18_v2_7_resnet18v2_resnet18_v2_7_cpu"),
|
||||
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet101_v1_7_resnet101v1_resnet101_v1_7_cpu"),
|
||||
(xfail_issue_36533, "test_onnx_model_zoo_vision_classification_resnet_model_resnet152_v1_7_resnet152v1_resnet152_v1_7_cpu"),
|
||||
|
@ -1380,6 +1380,52 @@ NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_split_channels)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_shared_across_channel_batch_size_2)
|
||||
{
|
||||
Shape data_shape{2, 2, 5};
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto mvn_func = make_shared<op::MVN>(data, true);
|
||||
auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// data
|
||||
vector<float> data_vector(shape_size(data_shape));
|
||||
iota(begin(data_vector), end(data_vector), 0);
|
||||
test_case.add_input<float>(data_vector);
|
||||
|
||||
// expected result
|
||||
test_case.add_expected_output<float>(
|
||||
data_shape,
|
||||
{-1.5666989f, -1.2185436f, -0.8703883f, -0.5222329f, -0.1740777f, 0.1740777f, 0.5222329f,
|
||||
0.8703883f, 1.2185436f, 1.5666989f, -1.5666989f, -1.2185436f, -0.8703883f, -0.5222329f,
|
||||
-0.1740777f, 0.1740777f, 0.5222329f, 0.8703883f, 1.2185436f, 1.5666989f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_not_shared_across_channel_batch_size_2)
|
||||
{
|
||||
Shape data_shape{2, 2, 5};
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto mvn_func = make_shared<op::MVN>(data, false);
|
||||
auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// data
|
||||
vector<float> data_vector(shape_size(data_shape));
|
||||
iota(begin(data_vector), end(data_vector), 0);
|
||||
test_case.add_input<float>(data_vector);
|
||||
|
||||
// expected result
|
||||
test_case.add_expected_output<float>(
|
||||
data_shape,
|
||||
{-1.4142135f, -0.7071068f, 0.0000000f, 0.7071068f, 1.4142135f, -1.4142135f, -0.7071068f,
|
||||
0.0000000f, 0.7071068f, 1.4142135f, -1.4142135f, -0.7071068f, 0.0000000f, 0.7071068f,
|
||||
1.4142135f, -1.4142135f, -0.7071068f, 0.0000000f, 0.7071068f, 1.4142135f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, grn_4d)
|
||||
{
|
||||
const Shape data_shape{1, 2, 3, 4};
|
||||
|
@ -34,12 +34,12 @@ TEST(type_prop, mvn_partial)
|
||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
|
||||
auto mvn_func = make_shared<op::MVN>(data);
|
||||
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
|
||||
EXPECT_EQ(mvn_func->get_reduction_axes(), (AxisSet{0, 1, 2}));
|
||||
EXPECT_EQ(mvn_func->get_reduction_axes(), (AxisSet{1, 2}));
|
||||
ASSERT_TRUE(mvn_func->get_output_partial_shape(0).same_scheme(
|
||||
(PartialShape{1, Dimension::dynamic(), 6})));
|
||||
|
||||
// across_channels = false
|
||||
EXPECT_EQ(make_shared<op::MVN>(data, false)->get_reduction_axes(), (AxisSet{0, 2}));
|
||||
EXPECT_EQ(make_shared<op::MVN>(data, false)->get_reduction_axes(), (AxisSet{2}));
|
||||
|
||||
// rank unknown
|
||||
auto mvn_partial =
|
||||
|
Loading…
Reference in New Issue
Block a user