Fix MO Reader for Squeeze without axes (#16398)
* Fix MO Reader for Squeeze without axes * Fix style * Update tools/mo/openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py
This commit is contained in:
parent
63797db257
commit
7d56c75d65
@ -3,12 +3,23 @@
|
||||
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.squeeze import Squeeze
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import shape_array, is_fully_defined
|
||||
|
||||
|
||||
class SqueezeInternal(Squeeze):
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
axis_value = node.in_port(1).data.get_value()
|
||||
Squeeze.infer(node)
|
||||
# preserve initial axis value
|
||||
node.in_port(1).data.set_value(axis_value)
|
||||
if node.is_in_port_connected(1):
|
||||
axis_value = node.in_port(1).data.get_value()
|
||||
Squeeze.infer(node)
|
||||
# preserve initial axis value
|
||||
node.in_port(1).data.set_value(axis_value)
|
||||
else:
|
||||
# Squeeze without axes provided
|
||||
node_name = node.soft_get('name', node.id)
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
assert is_fully_defined(
|
||||
input_shape), 'Squeeze dimensions are not defined for op "{}"'.format(node_name)
|
||||
output_shape = [s for s in shape_array(input_shape).tolist() if s != 1]
|
||||
node.out_port(0).data.set_shape(shape_array(output_shape))
|
||||
|
||||
|
@ -166,6 +166,34 @@ class TestFunction(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_squeeze_no_axes(self):
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
'input_data': {'shape': [2, 1, 3], 'kind': 'data'},
|
||||
|
||||
'squeeze': {'kind': 'op', 'type': 'Squeeze'},
|
||||
'squeeze_data': {'shape': [2, 3], 'kind': 'data', 'value': None},
|
||||
|
||||
'result': {'kind': 'op', 'type': 'Result'}
|
||||
}
|
||||
|
||||
edges = [('input', 'input_data'),
|
||||
('input_data', 'squeeze'),
|
||||
('squeeze', 'squeeze_data'),
|
||||
('squeeze_data', 'result'),
|
||||
]
|
||||
|
||||
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||
|
||||
squeeze_node = Node(graph, 'squeeze')
|
||||
SqueezeInternal.infer(squeeze_node)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||
|
||||
# Check that graph wasn't changed after shape infer
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_unsqueeze(self):
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
|
Loading…
Reference in New Issue
Block a user