Fix MO Reader for Squeeze without axes (#16398)

* Fix MO Reader for Squeeze without axes

* Fix style

* Update tools/mo/openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py
This commit is contained in:
Maxim Vafin 2023-03-21 10:28:58 +01:00 committed by GitHub
parent 63797db257
commit 7d56c75d65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 4 deletions

View File

@ -3,12 +3,23 @@
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.squeeze import Squeeze
from openvino.tools.mo.front.common.partial_infer.utils import shape_array, is_fully_defined
class SqueezeInternal(Squeeze):
@staticmethod
def infer(node: Node):
if node.is_in_port_connected(1):
axis_value = node.in_port(1).data.get_value()
Squeeze.infer(node)
# preserve initial axis value
node.in_port(1).data.set_value(axis_value)
else:
# Squeeze without axes provided
node_name = node.soft_get('name', node.id)
input_shape = node.in_port(0).data.get_shape()
assert is_fully_defined(
input_shape), 'Squeeze dimensions are not defined for op "{}"'.format(node_name)
output_shape = [s for s in shape_array(input_shape).tolist() if s != 1]
node.out_port(0).data.set_shape(shape_array(output_shape))

View File

@ -166,6 +166,34 @@ class TestFunction(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_squeeze_no_axes(self):
nodes_attributes = {
'input': {'kind': 'op', 'type': 'Parameter'},
'input_data': {'shape': [2, 1, 3], 'kind': 'data'},
'squeeze': {'kind': 'op', 'type': 'Squeeze'},
'squeeze_data': {'shape': [2, 3], 'kind': 'data', 'value': None},
'result': {'kind': 'op', 'type': 'Result'}
}
edges = [('input', 'input_data'),
('input_data', 'squeeze'),
('squeeze', 'squeeze_data'),
('squeeze_data', 'result'),
]
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
squeeze_node = Node(graph, 'squeeze')
SqueezeInternal.infer(squeeze_node)
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
# Check that graph wasn't changed after shape infer
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_unsqueeze(self):
nodes_attributes = {
'input': {'kind': 'op', 'type': 'Parameter'},