Preserve input constant in Squeeze and Unsqueeze shape infer (#14432)

* Preserve input constant when Squeeze and Unsqueeze shape infer is sone in MO IR Reader

* Fix bom test

* Add tests

* Fix value

* Fix test

* Apply suggestions from code review

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Maxim Vafin 2022-12-16 12:08:28 +01:00 committed by GitHub
parent b850f422ba
commit fce4641a25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 103 additions and 1 deletions

View File

@ -1076,6 +1076,8 @@ openvino/tools/mo/utils/ir_reader/extenders/strided_slice_extender.py
openvino/tools/mo/utils/ir_reader/extenders/tensoriterator_extender.py
openvino/tools/mo/utils/ir_reader/extenders/topk_extender.py
openvino/tools/mo/utils/ir_reader/extenders/variadic_split_extender.py
openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py
openvino/tools/mo/utils/ir_reader/internal_ops/unsqueeze.py
openvino/tools/mo/utils/ir_reader/layer_to_class.py
openvino/tools/mo/utils/ir_reader/restore_graph.py
openvino/tools/mo/utils/json_schema.py

View File

@ -0,0 +1,14 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.squeeze import Squeeze
class SqueezeInternal(Squeeze):
@staticmethod
def infer(node: Node):
axis_value = node.in_port(1).data.get_value()
Squeeze.infer(node)
# preserve initial axis value
node.in_port(1).data.set_value(axis_value)

View File

@ -0,0 +1,14 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.unsqueeze import Unsqueeze
class UnsqueezeInternal(Unsqueeze):
@staticmethod
def infer(node: Node):
axis_value = node.in_port(1).data.get_value()
Unsqueeze.infer(node)
# preserve initial axis value
node.in_port(1).data.set_value(axis_value)

View File

@ -32,6 +32,8 @@ from openvino.tools.mo.ops.split import Split, VariadicSplit
from openvino.tools.mo.utils.class_registration import update_registration
from openvino.tools.mo.utils.import_extensions import import_by_path
from openvino.tools.mo.utils.ir_reader.extender import Extender
from openvino.tools.mo.utils.ir_reader.internal_ops.squeeze import SqueezeInternal
from openvino.tools.mo.utils.ir_reader.internal_ops.unsqueeze import UnsqueezeInternal
# Operations not registered in collect_ops() function
custom_ops = {
@ -50,9 +52,11 @@ custom_ops = {
'Power': Pow,
'Slice': OvSlice,
'Split': Split,
'Squeeze': SqueezeInternal,
'Subtract': Sub,
'VariadicSplit': VariadicSplit,
'Clamp': AttributedClamp,
'Unsqueeze': UnsqueezeInternal,
}

View File

@ -8,6 +8,8 @@ from generator import generator, generate
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from openvino.tools.mo.utils.ir_reader.internal_ops.squeeze import SqueezeInternal
from openvino.tools.mo.utils.ir_reader.internal_ops.unsqueeze import UnsqueezeInternal
from openvino.tools.mo.utils.ir_reader.layer_to_class import groupconv_to_conv, restore_tensor_names
from unit_tests.utils.graph import build_graph
@ -128,5 +130,71 @@ class TestFunction(unittest.TestCase):
assert node_1['fw_tensor_debug_info'] == [('abc', 'abc'), ('def', 'def')], 'Restored debug info is wrong!'
assert node_2['fw_tensor_debug_info'] == [('ghi,jkl', 'ghi,jkl')], 'Restored debug info is wrong!'
assert node_3['fw_tensor_debug_info'] == [('mno', 'mno'), ('pqr,stu', 'pqr,stu')],\
assert node_3['fw_tensor_debug_info'] == [('mno', 'mno'), ('pqr,stu', 'pqr,stu')], \
'Restored debug info is wrong!'
def test_squeeze(self):
nodes_attributes = {
'input': {'kind': 'op', 'type': 'Parameter'},
'input_data': {'shape': [2, 1, 3], 'kind': 'data'},
'axis': {'kind': 'op', 'type': 'Const', 'op': 'Const', 'value': np.array(1), 'shape': []},
'axis_data': {'shape': [], 'kind': 'data', 'value': np.array(1)},
'squeeze': {'kind': 'op', 'type': 'Squeeze'},
'squeeze_data': {'shape': [2, 3], 'kind': 'data', 'value': None},
'result': {'kind': 'op', 'type': 'Result'}
}
edges = [('input', 'input_data'),
('input_data', 'squeeze'),
('axis', 'axis_data'),
('axis_data', 'squeeze'),
('squeeze', 'squeeze_data'),
('squeeze_data', 'result'),
]
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
squeeze_node = Node(graph, 'squeeze')
SqueezeInternal.infer(squeeze_node)
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
# Check that graph wasn't changed after shape infer
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_unsqueeze(self):
nodes_attributes = {
'input': {'kind': 'op', 'type': 'Parameter'},
'input_data': {'shape': [2, 3], 'kind': 'data'},
'axis': {'kind': 'op', 'type': 'Const', 'op': 'Const', 'value': np.array(1), 'shape': []},
'axis_data': {'shape': [], 'kind': 'data', 'value': np.array(1)},
'unsqueeze': {'kind': 'op', 'type': 'Unsqueeze'},
'unsqueeze_data': {'shape': [2, 1, 3], 'kind': 'data', 'value': None},
'result': {'kind': 'op', 'type': 'Result'}
}
edges = [('input', 'input_data'),
('input_data', 'unsqueeze'),
('axis', 'axis_data'),
('axis_data', 'unsqueeze'),
('unsqueeze', 'unsqueeze_data'),
('unsqueeze_data', 'result'),
]
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
unsqueeze_node = Node(graph, 'unsqueeze')
UnsqueezeInternal.infer(unsqueeze_node)
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
# Check that graph wasn't changed after shape infer
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)