Fix ResultRename transformation to preserve output order (#14956)

* Fix ResultRename transformation to preserve output order

* Apply review feedback
This commit is contained in:
Maxim Vafin 2023-01-06 14:48:13 +01:00 committed by GitHub
parent ddd4f050c7
commit 51fb4fd2e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 5 deletions

View File

@ -15,6 +15,7 @@ class ResultRename(BackReplacementPattern):
def find_and_replace_pattern(self, graph: Graph):
op_names = set()
result_names_map = dict()
for node in graph.get_op_nodes():
if node.has_valid('name'):
op_names.add(node['name'])
@ -46,4 +47,10 @@ class ResultRename(BackReplacementPattern):
log.warning("Tensor name for Result node with name {} wasn't found. "
"Default renaming was used: {}".format(node.soft_get('name', node.id),
result_name))
result_names_map[node['name']] = result_name
node['name'] = result_name
# Change names saved in graph.outputs_order
for i in range(len(graph.outputs_order)):
if graph.outputs_order[i] in result_names_map:
graph.outputs_order[i] = result_names_map[graph.outputs_order[i]]

View File

@ -6,15 +6,15 @@ import unittest
from openvino.tools.mo.back.ResultRename import ResultRename
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op, result
from unit_tests.utils.graph import build_graph, result
nodes = {
**regular_op('Op1', {'type': 'Op1', 'kind': 'op', 'op': 'Op1'}),
**regular_op('Op2', {'type': 'Op2', 'kind': 'op', 'op': 'Op2'}),
'Op1': {'type': 'Op1', 'kind': 'op', 'op': 'Op1'},
'Op2': {'type': 'Op2', 'kind': 'op', 'op': 'Op2'},
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op1_tensor')]},
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op2', 'Op2_tensor')]},
**result('result1'),
**result('result2'),
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op1_tensor')]},
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op2_tensor')]},
}
@ -45,12 +45,16 @@ class ResultRenameTest(unittest.TestCase):
graph = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result1'),
('Op1_data', 'Op2'), ('Op2', 'Op2_data'),
('Op2_data', 'result2')])
graph.outputs_order = ['result1', 'result2']
ResultRename().find_and_replace_pattern(graph)
res1_node = Node(graph, 'result1')
res2_node = Node(graph, 'result2')
self.assertTrue(res1_node['name'] == 'Op1_tensor')
self.assertTrue(res2_node['name'] == 'Op2_tensor')
self.assertTrue(graph.outputs_order == ['Op1_tensor', 'Op2_tensor'])
def test_case5(self):
graph = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result1'),
('Op1_data', 'Op2'), ('Op2', 'Op2_data'),
@ -64,3 +68,26 @@ class ResultRenameTest(unittest.TestCase):
self.assertTrue(res1_node['name'] == 'Op1_tensor')
self.assertTrue(res2_node['name'] == 'Op2_tensor')
def test_case6(self):
_nodes = nodes.copy()
_nodes.update({
'Op3': {'type': 'Op3', 'kind': 'op', 'op': 'Op3'},
'Op4': {'type': 'Op4', 'kind': 'op', 'op': 'Op4'},
'Op3_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op3', 'Op3_tensor')]},
'Op4_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op4', 'Op4_tensor')]},
**result('result3'),
**result('result4'),
})
graph = build_graph(_nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result1'), ('Op1_data', 'Op2'),
('Op2', 'Op2_data'), ('Op2_data', 'result2'), ('Op2_data', 'Op3'),
('Op3', 'Op3_data'), ('Op3_data', 'result3'), ('Op3_data', 'Op4'),
('Op4', 'Op4_data'), ('Op4_data', 'result4')])
graph.outputs_order = ['result1', 'result3', 'result4', 'result2']
ResultRename().find_and_replace_pattern(graph)
self.assertTrue(Node(graph, 'result1')['name'] == 'Op1_tensor')
self.assertTrue(Node(graph, 'result2')['name'] == 'Op2_tensor')
self.assertTrue(Node(graph, 'result3')['name'] == 'Op3_tensor')
self.assertTrue(Node(graph, 'result4')['name'] == 'Op4_tensor')
self.assertTrue(graph.outputs_order == ['Op1_tensor', 'Op3_tensor', 'Op4_tensor', 'Op2_tensor'])