Support Unsqueeze-13, Squeeze-13 in MO (#13391)

This commit is contained in:
Bartek Szmelczynski 2022-10-28 13:01:21 +02:00 committed by GitHub
parent 6913a8ad39
commit 464e0aae72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 7 deletions

View File

@ -0,0 +1,78 @@
ir_version: 8
producer_name: "onnx-importer-test"
graph {
node {
output: "AXIS"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 6
int32_data: 0
name: "const_tensor"
}
type: TENSOR
}
}
node {
input: "X"
input: "AXIS"
output: "Y"
op_type: "Squeeze"
}
name: "test-model-unsqueeze"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "AXIS"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
domain: "ai.onnx"
version: 13
}

View File

@ -6062,3 +6062,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_is_nan) {
// clang-format on // clang-format on
} }
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_default_domain_opset13) {
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/squeeze_default_domain_opset13.onnx"));
auto input =
test::NDArray<float, 3>({{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}}).get_vector();
auto expected_output =
test::NDArray<float, 2>({{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}).get_vector();
auto test_case = test::TestCase(function, s_device);
test_case.add_input(input);
test_case.add_expected_output(expected_output);
test_case.run();
}

View File

@ -17,6 +17,13 @@ from unit_tests.utils.extractors import PB
class TestUnsqueezeONNXExt(unittest.TestCase): class TestUnsqueezeONNXExt(unittest.TestCase):
@staticmethod @staticmethod
def _create_unsqueeze_node(axes): def _create_unsqueeze_node(axes):
if axes is None:
pb = onnx.helper.make_node(
'Unsqueeze',
inputs=['x'],
outputs=['y'],
)
else:
pb = onnx.helper.make_node( pb = onnx.helper.make_node(
'Unsqueeze', 'Unsqueeze',
inputs=['x'], inputs=['x'],
@ -31,7 +38,7 @@ class TestUnsqueezeONNXExt(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
Op.registered_ops['Unsqueeze'] = Unsqueeze Op.registered_ops['Unsqueeze'] = Unsqueeze
@generate(*[[0, 1, 2, 3], [1]]) @generate(*[[0, 1, 2, 3], [1], []])
def test_unsqueeze_ext(self, axes): def test_unsqueeze_ext(self, axes):
node = self._create_unsqueeze_node(axes) node = self._create_unsqueeze_node(axes)
UnsqueezeFrontExtractor.extract(node) UnsqueezeFrontExtractor.extract(node)