Support Unsqueeze-13, Squeeze-13 in MO (#13391)
This commit is contained in:
parent
6913a8ad39
commit
464e0aae72
@ -0,0 +1,78 @@
|
||||
|
||||
ir_version: 8
|
||||
producer_name: "onnx-importer-test"
|
||||
graph {
|
||||
node {
|
||||
output: "AXIS"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 6
|
||||
int32_data: 0
|
||||
name: "const_tensor"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "X"
|
||||
input: "AXIS"
|
||||
output: "Y"
|
||||
op_type: "Squeeze"
|
||||
}
|
||||
name: "test-model-unsqueeze"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "AXIS"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: "ai.onnx"
|
||||
version: 13
|
||||
}
|
@ -6062,3 +6062,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_is_nan) {
|
||||
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_default_domain_opset13) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/squeeze_default_domain_opset13.onnx"));
|
||||
|
||||
auto input =
|
||||
test::NDArray<float, 3>({{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}}).get_vector();
|
||||
auto expected_output =
|
||||
test::NDArray<float, 2>({{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}).get_vector();
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input(input);
|
||||
test_case.add_expected_output(expected_output);
|
||||
test_case.run();
|
||||
}
|
@ -17,6 +17,13 @@ from unit_tests.utils.extractors import PB
|
||||
class TestUnsqueezeONNXExt(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _create_unsqueeze_node(axes):
|
||||
if axes is None:
|
||||
pb = onnx.helper.make_node(
|
||||
'Unsqueeze',
|
||||
inputs=['x'],
|
||||
outputs=['y'],
|
||||
)
|
||||
else:
|
||||
pb = onnx.helper.make_node(
|
||||
'Unsqueeze',
|
||||
inputs=['x'],
|
||||
@ -31,7 +38,7 @@ class TestUnsqueezeONNXExt(unittest.TestCase):
|
||||
def setUpClass(cls):
|
||||
Op.registered_ops['Unsqueeze'] = Unsqueeze
|
||||
|
||||
@generate(*[[0, 1, 2, 3], [1]])
|
||||
@generate(*[[0, 1, 2, 3], [1], []])
|
||||
def test_unsqueeze_ext(self, axes):
|
||||
node = self._create_unsqueeze_node(axes)
|
||||
UnsqueezeFrontExtractor.extract(node)
|
||||
|
Loading…
Reference in New Issue
Block a user