From 7d56c75d65f0072ff9fc36c5e747ce35d15d031a Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Tue, 21 Mar 2023 10:28:58 +0100 Subject: [PATCH] Fix MO Reader for Squeeze without axes (#16398) * Fix MO Reader for Squeeze without axes * Fix style * Update tools/mo/openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py --- .../utils/ir_reader/internal_ops/squeeze.py | 19 ++++++++++--- .../mo/utils/ir_reader/layer_to_class_test.py | 28 +++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tools/mo/openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py b/tools/mo/openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py index 67bfc80dea5..5e9702e30f8 100644 --- a/tools/mo/openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py +++ b/tools/mo/openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py @@ -3,12 +3,23 @@ from openvino.tools.mo.graph.graph import Node from openvino.tools.mo.ops.squeeze import Squeeze +from openvino.tools.mo.front.common.partial_infer.utils import shape_array, is_fully_defined class SqueezeInternal(Squeeze): @staticmethod def infer(node: Node): - axis_value = node.in_port(1).data.get_value() - Squeeze.infer(node) - # preserve initial axis value - node.in_port(1).data.set_value(axis_value) + if node.is_in_port_connected(1): + axis_value = node.in_port(1).data.get_value() + Squeeze.infer(node) + # preserve initial axis value + node.in_port(1).data.set_value(axis_value) + else: + # Squeeze without axes provided + node_name = node.soft_get('name', node.id) + input_shape = node.in_port(0).data.get_shape() + assert is_fully_defined( + input_shape), 'Squeeze dimensions are not defined for op "{}"'.format(node_name) + output_shape = [s for s in shape_array(input_shape).tolist() if s != 1] + node.out_port(0).data.set_shape(shape_array(output_shape)) + diff --git a/tools/mo/unit_tests/mo/utils/ir_reader/layer_to_class_test.py b/tools/mo/unit_tests/mo/utils/ir_reader/layer_to_class_test.py index f86e4514ca7..8dd6a17aba6 100644 --- a/tools/mo/unit_tests/mo/utils/ir_reader/layer_to_class_test.py +++ b/tools/mo/unit_tests/mo/utils/ir_reader/layer_to_class_test.py @@ -166,6 +166,34 @@ class TestFunction(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + def test_squeeze_no_axes(self): + nodes_attributes = { + 'input': {'kind': 'op', 'type': 'Parameter'}, + 'input_data': {'shape': [2, 1, 3], 'kind': 'data'}, + + 'squeeze': {'kind': 'op', 'type': 'Squeeze'}, + 'squeeze_data': {'shape': [2, 3], 'kind': 'data', 'value': None}, + + 'result': {'kind': 'op', 'type': 'Result'} + } + + edges = [('input', 'input_data'), + ('input_data', 'squeeze'), + ('squeeze', 'squeeze_data'), + ('squeeze_data', 'result'), + ] + + graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True) + + squeeze_node = Node(graph, 'squeeze') + SqueezeInternal.infer(squeeze_node) + + graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True) + + # Check that graph wasn't changed after shape infer + (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) + self.assertTrue(flag, resp) + def test_unsqueeze(self): nodes_attributes = { 'input': {'kind': 'op', 'type': 'Parameter'},