Fix ONNX GroupNorm and ExperimentalDetectronGroupNorm (#4579)

This commit is contained in:
Bartosz Sledz 2021-03-09 13:40:45 +01:00 committed by GitHub
parent 5f12213a33
commit fc589572a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 150 additions and 9 deletions

View File

@ -57,6 +57,7 @@ namespace ngraph
"DetectionOutput",
"ExperimentalDetectronDetectionOutput",
"ExperimentalDetectronGenerateProposalsSingleImage",
"ExperimentalDetectronGroupNorm",
"ExperimentalDetectronPriorGridGenerator",
"ExperimentalDetectronROIFeatureExtractor",
"ExperimentalDetectronTopKROIs",

View File

@ -50,13 +50,14 @@ namespace ngraph
auto splits = builder::opset1::split(shape, rank_size);
auto num_groups_const =
default_opset::Constant::create(element::i64, Shape{1}, {num_groups});
NodeVector new_shape{
splits[0].get_node_shared_ptr(),
ngraph::OutputVector new_shape{
splits[0],
num_groups_const,
std::make_shared<default_opset::Divide>(splits[1], num_groups_const)};
for (size_t i = 2; i < rank_size; i++)
{
new_shape.push_back(splits[i].get_node_shared_ptr());
new_shape.push_back(splits[i]);
}
return std::make_shared<default_opset::Concat>(new_shape, 0);
}
@ -78,14 +79,20 @@ namespace ngraph
size_t num_groups =
static_cast<size_t>(node.get_attribute_value<int64_t>("num_groups"));
float eps = node.get_attribute_value<float>("eps", 1e-5);
float eps = node.get_attribute_value<float>("eps", 1e-6);
auto data_shape_node = std::make_shared<default_opset::ShapeOf>(data);
auto data_reshaped = std::make_shared<default_opset::Reshape>(
data, detail::create_group_norm_shape(data, num_groups), true);
const auto reduction_axes =
common::get_monotonic_range_along_node_rank(data_reshaped, 2);
auto mvn =
std::make_shared<ngraph::opset5::MVN>(data_reshaped, false, true, eps);
std::make_shared<default_opset::MVN>(data_reshaped,
reduction_axes,
true,
eps,
ngraph::op::MVNEpsMode::INSIDE_SQRT);
std::shared_ptr<ngraph::Node> result =
std::make_shared<default_opset::Reshape>(mvn, data_shape_node, true);

View File

@ -0,0 +1,114 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
input: "gamma"
input: "beta"
output: "y"
op_type: "GroupNorm"
domain: "org.openvinotoolkit"
attribute {
name: "num_groups"
i: 4
type: INT
}
attribute {
name: "eps"
f: 1e-5
type: FLOAT
}
}
name: "group_norm_example"
initializer {
dims: 8
data_type: 1
name: "gamma"
raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A"
}
initializer {
dims: 8
data_type: 1
name: "beta"
raw_data: "\0\0\200?\0\0\0@\0\0@@\0\0\200@\0\0\240@\0\0\300@\0\0\340@\0\0\0A"
}
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 8
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "gamma"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 8
}
}
}
}
}
input {
name: "beta"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 8
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 8
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -97,7 +97,3 @@ opset_import {
domain: ""
version: 10
}
opset_import {
domain: "org.openvinotoolkit"
version: 1
}

View File

@ -233,6 +233,29 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm_5d)
{
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/group_norm_5d.prototxt"));
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
Shape shape{2, 8, 1, 2, 1};
int size = shape_size(shape);
std::vector<float> data(size);
std::iota(data.begin(), data.end(), 0);
std::vector<float> output = {-0.34163546562, 0.55278813838, 2.89442372322, 4.68327093124,
-1.02490639686, 1.65836453437, 5.78884744644, 9.36654186248,
-1.70817732810, 2.76394081115, 8.68327140808, 14.04981231689,
-2.39144825935, 3.86951708793, 11.57769489288, 18.73308372497,
-0.34163546562, 0.55278813838, 2.89442372322, 4.68327093124,
-1.02490639686, 1.65836453437, 5.78884744644, 9.36654186248,
-1.70817732810, 2.76394081115, 8.68327140808, 14.04981231689,
-2.39144825935, 3.86951708793, 11.57769489288, 18.73308372497};
test_case.add_input<float>(data);
test_case.add_expected_output<float>(shape, output);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_normalize)
{
const auto function = onnx_import::import_onnx_model(