Support Unsqueeze-13, Squeeze-13 in MO (#13391)
This commit is contained in:
parent
6913a8ad39
commit
464e0aae72
@ -0,0 +1,78 @@
|
|||||||
|
|
||||||
|
ir_version: 8
|
||||||
|
producer_name: "onnx-importer-test"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
output: "AXIS"
|
||||||
|
op_type: "Constant"
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
t {
|
||||||
|
dims: 1
|
||||||
|
data_type: 6
|
||||||
|
int32_data: 0
|
||||||
|
name: "const_tensor"
|
||||||
|
}
|
||||||
|
type: TENSOR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "X"
|
||||||
|
input: "AXIS"
|
||||||
|
output: "Y"
|
||||||
|
op_type: "Squeeze"
|
||||||
|
}
|
||||||
|
name: "test-model-unsqueeze"
|
||||||
|
input {
|
||||||
|
name: "X"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "AXIS"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 6
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "Y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
domain: "ai.onnx"
|
||||||
|
version: 13
|
||||||
|
}
|
@ -6062,3 +6062,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_is_nan) {
|
|||||||
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_default_domain_opset13) {
|
||||||
|
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||||
|
SERIALIZED_ZOO,
|
||||||
|
"onnx/squeeze_default_domain_opset13.onnx"));
|
||||||
|
|
||||||
|
auto input =
|
||||||
|
test::NDArray<float, 3>({{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}}).get_vector();
|
||||||
|
auto expected_output =
|
||||||
|
test::NDArray<float, 2>({{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}).get_vector();
|
||||||
|
|
||||||
|
auto test_case = test::TestCase(function, s_device);
|
||||||
|
test_case.add_input(input);
|
||||||
|
test_case.add_expected_output(expected_output);
|
||||||
|
test_case.run();
|
||||||
|
}
|
@ -17,12 +17,19 @@ from unit_tests.utils.extractors import PB
|
|||||||
class TestUnsqueezeONNXExt(unittest.TestCase):
|
class TestUnsqueezeONNXExt(unittest.TestCase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_unsqueeze_node(axes):
|
def _create_unsqueeze_node(axes):
|
||||||
pb = onnx.helper.make_node(
|
if axes is None:
|
||||||
'Unsqueeze',
|
pb = onnx.helper.make_node(
|
||||||
inputs=['x'],
|
'Unsqueeze',
|
||||||
outputs=['y'],
|
inputs=['x'],
|
||||||
axes=axes,
|
outputs=['y'],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
pb = onnx.helper.make_node(
|
||||||
|
'Unsqueeze',
|
||||||
|
inputs=['x'],
|
||||||
|
outputs=['y'],
|
||||||
|
axes=axes,
|
||||||
|
)
|
||||||
|
|
||||||
node = PB({'pb': pb})
|
node = PB({'pb': pb})
|
||||||
return node
|
return node
|
||||||
@ -31,7 +38,7 @@ class TestUnsqueezeONNXExt(unittest.TestCase):
|
|||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
Op.registered_ops['Unsqueeze'] = Unsqueeze
|
Op.registered_ops['Unsqueeze'] = Unsqueeze
|
||||||
|
|
||||||
@generate(*[[0, 1, 2, 3], [1]])
|
@generate(*[[0, 1, 2, 3], [1], []])
|
||||||
def test_unsqueeze_ext(self, axes):
|
def test_unsqueeze_ext(self, axes):
|
||||||
node = self._create_unsqueeze_node(axes)
|
node = self._create_unsqueeze_node(axes)
|
||||||
UnsqueezeFrontExtractor.extract(node)
|
UnsqueezeFrontExtractor.extract(node)
|
||||||
|
Loading…
Reference in New Issue
Block a user