Preserve input constant in Squeeze and Unsqueeze shape infer (#14432)
* Preserve input constant when Squeeze and Unsqueeze shape infer is sone in MO IR Reader * Fix bom test * Add tests * Fix value * Fix test * Apply suggestions from code review Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
b850f422ba
commit
fce4641a25
@ -1076,6 +1076,8 @@ openvino/tools/mo/utils/ir_reader/extenders/strided_slice_extender.py
|
|||||||
openvino/tools/mo/utils/ir_reader/extenders/tensoriterator_extender.py
|
openvino/tools/mo/utils/ir_reader/extenders/tensoriterator_extender.py
|
||||||
openvino/tools/mo/utils/ir_reader/extenders/topk_extender.py
|
openvino/tools/mo/utils/ir_reader/extenders/topk_extender.py
|
||||||
openvino/tools/mo/utils/ir_reader/extenders/variadic_split_extender.py
|
openvino/tools/mo/utils/ir_reader/extenders/variadic_split_extender.py
|
||||||
|
openvino/tools/mo/utils/ir_reader/internal_ops/squeeze.py
|
||||||
|
openvino/tools/mo/utils/ir_reader/internal_ops/unsqueeze.py
|
||||||
openvino/tools/mo/utils/ir_reader/layer_to_class.py
|
openvino/tools/mo/utils/ir_reader/layer_to_class.py
|
||||||
openvino/tools/mo/utils/ir_reader/restore_graph.py
|
openvino/tools/mo/utils/ir_reader/restore_graph.py
|
||||||
openvino/tools/mo/utils/json_schema.py
|
openvino/tools/mo/utils/json_schema.py
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from openvino.tools.mo.graph.graph import Node
|
||||||
|
from openvino.tools.mo.ops.squeeze import Squeeze
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeInternal(Squeeze):
|
||||||
|
@staticmethod
|
||||||
|
def infer(node: Node):
|
||||||
|
axis_value = node.in_port(1).data.get_value()
|
||||||
|
Squeeze.infer(node)
|
||||||
|
# preserve initial axis value
|
||||||
|
node.in_port(1).data.set_value(axis_value)
|
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from openvino.tools.mo.graph.graph import Node
|
||||||
|
from openvino.tools.mo.ops.unsqueeze import Unsqueeze
|
||||||
|
|
||||||
|
|
||||||
|
class UnsqueezeInternal(Unsqueeze):
|
||||||
|
@staticmethod
|
||||||
|
def infer(node: Node):
|
||||||
|
axis_value = node.in_port(1).data.get_value()
|
||||||
|
Unsqueeze.infer(node)
|
||||||
|
# preserve initial axis value
|
||||||
|
node.in_port(1).data.set_value(axis_value)
|
@ -32,6 +32,8 @@ from openvino.tools.mo.ops.split import Split, VariadicSplit
|
|||||||
from openvino.tools.mo.utils.class_registration import update_registration
|
from openvino.tools.mo.utils.class_registration import update_registration
|
||||||
from openvino.tools.mo.utils.import_extensions import import_by_path
|
from openvino.tools.mo.utils.import_extensions import import_by_path
|
||||||
from openvino.tools.mo.utils.ir_reader.extender import Extender
|
from openvino.tools.mo.utils.ir_reader.extender import Extender
|
||||||
|
from openvino.tools.mo.utils.ir_reader.internal_ops.squeeze import SqueezeInternal
|
||||||
|
from openvino.tools.mo.utils.ir_reader.internal_ops.unsqueeze import UnsqueezeInternal
|
||||||
|
|
||||||
# Operations not registered in collect_ops() function
|
# Operations not registered in collect_ops() function
|
||||||
custom_ops = {
|
custom_ops = {
|
||||||
@ -50,9 +52,11 @@ custom_ops = {
|
|||||||
'Power': Pow,
|
'Power': Pow,
|
||||||
'Slice': OvSlice,
|
'Slice': OvSlice,
|
||||||
'Split': Split,
|
'Split': Split,
|
||||||
|
'Squeeze': SqueezeInternal,
|
||||||
'Subtract': Sub,
|
'Subtract': Sub,
|
||||||
'VariadicSplit': VariadicSplit,
|
'VariadicSplit': VariadicSplit,
|
||||||
'Clamp': AttributedClamp,
|
'Clamp': AttributedClamp,
|
||||||
|
'Unsqueeze': UnsqueezeInternal,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@ from generator import generator, generate
|
|||||||
|
|
||||||
from openvino.tools.mo.graph.graph import Node
|
from openvino.tools.mo.graph.graph import Node
|
||||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from openvino.tools.mo.utils.ir_reader.internal_ops.squeeze import SqueezeInternal
|
||||||
|
from openvino.tools.mo.utils.ir_reader.internal_ops.unsqueeze import UnsqueezeInternal
|
||||||
from openvino.tools.mo.utils.ir_reader.layer_to_class import groupconv_to_conv, restore_tensor_names
|
from openvino.tools.mo.utils.ir_reader.layer_to_class import groupconv_to_conv, restore_tensor_names
|
||||||
from unit_tests.utils.graph import build_graph
|
from unit_tests.utils.graph import build_graph
|
||||||
|
|
||||||
@ -128,5 +130,71 @@ class TestFunction(unittest.TestCase):
|
|||||||
|
|
||||||
assert node_1['fw_tensor_debug_info'] == [('abc', 'abc'), ('def', 'def')], 'Restored debug info is wrong!'
|
assert node_1['fw_tensor_debug_info'] == [('abc', 'abc'), ('def', 'def')], 'Restored debug info is wrong!'
|
||||||
assert node_2['fw_tensor_debug_info'] == [('ghi,jkl', 'ghi,jkl')], 'Restored debug info is wrong!'
|
assert node_2['fw_tensor_debug_info'] == [('ghi,jkl', 'ghi,jkl')], 'Restored debug info is wrong!'
|
||||||
assert node_3['fw_tensor_debug_info'] == [('mno', 'mno'), ('pqr,stu', 'pqr,stu')],\
|
assert node_3['fw_tensor_debug_info'] == [('mno', 'mno'), ('pqr,stu', 'pqr,stu')], \
|
||||||
'Restored debug info is wrong!'
|
'Restored debug info is wrong!'
|
||||||
|
|
||||||
|
def test_squeeze(self):
|
||||||
|
nodes_attributes = {
|
||||||
|
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||||
|
'input_data': {'shape': [2, 1, 3], 'kind': 'data'},
|
||||||
|
|
||||||
|
'axis': {'kind': 'op', 'type': 'Const', 'op': 'Const', 'value': np.array(1), 'shape': []},
|
||||||
|
'axis_data': {'shape': [], 'kind': 'data', 'value': np.array(1)},
|
||||||
|
|
||||||
|
'squeeze': {'kind': 'op', 'type': 'Squeeze'},
|
||||||
|
'squeeze_data': {'shape': [2, 3], 'kind': 'data', 'value': None},
|
||||||
|
|
||||||
|
'result': {'kind': 'op', 'type': 'Result'}
|
||||||
|
}
|
||||||
|
|
||||||
|
edges = [('input', 'input_data'),
|
||||||
|
('input_data', 'squeeze'),
|
||||||
|
('axis', 'axis_data'),
|
||||||
|
('axis_data', 'squeeze'),
|
||||||
|
('squeeze', 'squeeze_data'),
|
||||||
|
('squeeze_data', 'result'),
|
||||||
|
]
|
||||||
|
|
||||||
|
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
squeeze_node = Node(graph, 'squeeze')
|
||||||
|
SqueezeInternal.infer(squeeze_node)
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
# Check that graph wasn't changed after shape infer
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_unsqueeze(self):
|
||||||
|
nodes_attributes = {
|
||||||
|
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||||
|
'input_data': {'shape': [2, 3], 'kind': 'data'},
|
||||||
|
|
||||||
|
'axis': {'kind': 'op', 'type': 'Const', 'op': 'Const', 'value': np.array(1), 'shape': []},
|
||||||
|
'axis_data': {'shape': [], 'kind': 'data', 'value': np.array(1)},
|
||||||
|
|
||||||
|
'unsqueeze': {'kind': 'op', 'type': 'Unsqueeze'},
|
||||||
|
'unsqueeze_data': {'shape': [2, 1, 3], 'kind': 'data', 'value': None},
|
||||||
|
|
||||||
|
'result': {'kind': 'op', 'type': 'Result'}
|
||||||
|
}
|
||||||
|
|
||||||
|
edges = [('input', 'input_data'),
|
||||||
|
('input_data', 'unsqueeze'),
|
||||||
|
('axis', 'axis_data'),
|
||||||
|
('axis_data', 'unsqueeze'),
|
||||||
|
('unsqueeze', 'unsqueeze_data'),
|
||||||
|
('unsqueeze_data', 'result'),
|
||||||
|
]
|
||||||
|
|
||||||
|
graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
unsqueeze_node = Node(graph, 'unsqueeze')
|
||||||
|
UnsqueezeInternal.infer(unsqueeze_node)
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes_attributes, edges, nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
# Check that graph wasn't changed after shape infer
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
Loading…
Reference in New Issue
Block a user