Preserve input constant in Squeeze and Unsqueeze shape infer (#14432)
* Preserve input constant when Squeeze and Unsqueeze shape infer is sone in MO IR Reader * Fix bom test * Add tests * Fix value * Fix test * Apply suggestions from code review Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
b850f422ba
commit
fce4641a25
@ -1076,6 +1076,8 @@ openvino/tools/mo/utils/ir_reader/extenders/strided_slice_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/extenders/tensoriterator_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/extenders/topk_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/extenders/variadic_split_extender.py
|
||||
openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py
|
||||
openvino/tools/mo/utils/ir_reader/internal_ops/unsqueeze.py
|
||||
openvino/tools/mo/utils/ir_reader/layer_to_class.py
|
||||
openvino/tools/mo/utils/ir_reader/restore_graph.py
|
||||
openvino/tools/mo/utils/json_schema.py
|
||||
|
@ -0,0 +1,14 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.squeeze import Squeeze
|
||||
|
||||
|
||||
class SqueezeInternal(Squeeze):
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
axis_value = node.in_port(1).data.get_value()
|
||||
Squeeze.infer(node)
|
||||
# preserve initial axis value
|
||||
node.in_port(1).data.set_value(axis_value)
|
@ -0,0 +1,14 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.unsqueeze import Unsqueeze
|
||||
|
||||
|
||||
class UnsqueezeInternal(Unsqueeze):
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
axis_value = node.in_port(1).data.get_value()
|
||||
Unsqueeze.infer(node)
|
||||
# preserve initial axis value
|
||||
node.in_port(1).data.set_value(axis_value)
|
@ -32,6 +32,8 @@ from openvino.tools.mo.ops.split import Split, VariadicSplit
|
||||
from openvino.tools.mo.utils.class_registration import update_registration
|
||||
from openvino.tools.mo.utils.import_extensions import import_by_path
|
||||
from openvino.tools.mo.utils.ir_reader.extender import Extender
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.squeeze import SqueezeInternal
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.unsqueeze import UnsqueezeInternal
|
||||
|
||||
# Operations not registered in collect_ops() function
|
||||
custom_ops = {
|
||||
@ -50,9 +52,11 @@ custom_ops = {
|
||||
'Power': Pow,
|
||||
'Slice': OvSlice,
|
||||
'Split': Split,
|
||||
'Squeeze': SqueezeInternal,
|
||||
'Subtract': Sub,
|
||||
'VariadicSplit': VariadicSplit,
|
||||
'Clamp': AttributedClamp,
|
||||
'Unsqueeze': UnsqueezeInternal,
|
||||
}
|
||||
|
||||
|
||||
|
@ -8,6 +8,8 @@ from generator import generator, generate
|
||||
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.squeeze import SqueezeInternal
|
||||
from openvino.tools.mo.utils.ir_reader.internal_ops.unsqueeze import UnsqueezeInternal
|
||||
from openvino.tools.mo.utils.ir_reader.layer_to_class import groupconv_to_conv, restore_tensor_names
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
@ -130,3 +132,69 @@ class TestFunction(unittest.TestCase):
|
||||
assert node_2['fw_tensor_debug_info'] == [('ghi,jkl', 'ghi,jkl')], 'Restored debug info is wrong!'
|
||||
assert node_3['fw_tensor_debug_info'] == [('mno', 'mno'), ('pqr,stu', 'pqr,stu')], \
|
||||
'Restored debug info is wrong!'
|
||||
|
||||
def test_squeeze(self):
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
'input_data': {'shape': [2, 1, 3], 'kind': 'data'},
|
||||
|
||||
'axis': {'kind': 'op', 'type': 'Const', 'op': 'Const', 'value': np.array(1), 'shape': []},
|
||||
'axis_data': {'shape': [], 'kind': 'data', 'value': np.array(1)},
|
||||
|
||||
'squeeze': {'kind': 'op', 'type': 'Squeeze'},
|
||||
'squeeze_data': {'shape': [2, 3], 'kind': 'data', 'value': None},
|
||||
|
||||
'result': {'kind': 'op', 'type': 'Result'}
|
||||
}
|
||||
|
||||
edges = [('input', 'input_data'),
|
||||
('input_data', 'squeeze'),
|
||||
('axis', 'axis_data'),
|
||||
('axis_data', 'squeeze'),
|
||||
('squeeze', 'squeeze_data'),
|
||||
('squeeze_data', 'result'),
|
||||
]
|
||||
|
||||
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||
|
||||
squeeze_node = Node(graph, 'squeeze')
|
||||
SqueezeInternal.infer(squeeze_node)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||
|
||||
# Check that graph wasn't changed after shape infer
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_unsqueeze(self):
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
'input_data': {'shape': [2, 3], 'kind': 'data'},
|
||||
|
||||
'axis': {'kind': 'op', 'type': 'Const', 'op': 'Const', 'value': np.array(1), 'shape': []},
|
||||
'axis_data': {'shape': [], 'kind': 'data', 'value': np.array(1)},
|
||||
|
||||
'unsqueeze': {'kind': 'op', 'type': 'Unsqueeze'},
|
||||
'unsqueeze_data': {'shape': [2, 1, 3], 'kind': 'data', 'value': None},
|
||||
|
||||
'result': {'kind': 'op', 'type': 'Result'}
|
||||
}
|
||||
|
||||
edges = [('input', 'input_data'),
|
||||
('input_data', 'unsqueeze'),
|
||||
('axis', 'axis_data'),
|
||||
('axis_data', 'unsqueeze'),
|
||||
('unsqueeze', 'unsqueeze_data'),
|
||||
('unsqueeze_data', 'result'),
|
||||
]
|
||||
|
||||
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||
|
||||
unsqueeze_node = Node(graph, 'unsqueeze')
|
||||
UnsqueezeInternal.infer(unsqueeze_node)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||
|
||||
# Check that graph wasn't changed after shape infer
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
Loading…
Reference in New Issue
Block a user