Fix ResultRename transformation to preserve output order (#14956)
* Fix ResultRename transformation to preserve output order * Apply review feedback
This commit is contained in:
parent
ddd4f050c7
commit
51fb4fd2e3
@ -15,6 +15,7 @@ class ResultRename(BackReplacementPattern):
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
op_names = set()
|
||||
result_names_map = dict()
|
||||
for node in graph.get_op_nodes():
|
||||
if node.has_valid('name'):
|
||||
op_names.add(node['name'])
|
||||
@ -46,4 +47,10 @@ class ResultRename(BackReplacementPattern):
|
||||
log.warning("Tensor name for Result node with name {} wasn't found. "
|
||||
"Default renaming was used: {}".format(node.soft_get('name', node.id),
|
||||
result_name))
|
||||
result_names_map[node['name']] = result_name
|
||||
node['name'] = result_name
|
||||
|
||||
# Change names saved in graph.outputs_order
|
||||
for i in range(len(graph.outputs_order)):
|
||||
if graph.outputs_order[i] in result_names_map:
|
||||
graph.outputs_order[i] = result_names_map[graph.outputs_order[i]]
|
||||
|
@ -6,15 +6,15 @@ import unittest
|
||||
from openvino.tools.mo.back.ResultRename import ResultRename
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, regular_op, result
|
||||
from unit_tests.utils.graph import build_graph, result
|
||||
|
||||
nodes = {
|
||||
**regular_op('Op1', {'type': 'Op1', 'kind': 'op', 'op': 'Op1'}),
|
||||
**regular_op('Op2', {'type': 'Op2', 'kind': 'op', 'op': 'Op2'}),
|
||||
'Op1': {'type': 'Op1', 'kind': 'op', 'op': 'Op1'},
|
||||
'Op2': {'type': 'Op2', 'kind': 'op', 'op': 'Op2'},
|
||||
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op1_tensor')]},
|
||||
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op2', 'Op2_tensor')]},
|
||||
**result('result1'),
|
||||
**result('result2'),
|
||||
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op1_tensor')]},
|
||||
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op2_tensor')]},
|
||||
}
|
||||
|
||||
|
||||
@ -45,12 +45,16 @@ class ResultRenameTest(unittest.TestCase):
|
||||
graph = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result1'),
|
||||
('Op1_data', 'Op2'), ('Op2', 'Op2_data'),
|
||||
('Op2_data', 'result2')])
|
||||
graph.outputs_order = ['result1', 'result2']
|
||||
|
||||
ResultRename().find_and_replace_pattern(graph)
|
||||
res1_node = Node(graph, 'result1')
|
||||
res2_node = Node(graph, 'result2')
|
||||
self.assertTrue(res1_node['name'] == 'Op1_tensor')
|
||||
self.assertTrue(res2_node['name'] == 'Op2_tensor')
|
||||
|
||||
self.assertTrue(graph.outputs_order == ['Op1_tensor', 'Op2_tensor'])
|
||||
|
||||
def test_case5(self):
|
||||
graph = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result1'),
|
||||
('Op1_data', 'Op2'), ('Op2', 'Op2_data'),
|
||||
@ -64,3 +68,26 @@ class ResultRenameTest(unittest.TestCase):
|
||||
self.assertTrue(res1_node['name'] == 'Op1_tensor')
|
||||
self.assertTrue(res2_node['name'] == 'Op2_tensor')
|
||||
|
||||
def test_case6(self):
|
||||
_nodes = nodes.copy()
|
||||
_nodes.update({
|
||||
'Op3': {'type': 'Op3', 'kind': 'op', 'op': 'Op3'},
|
||||
'Op4': {'type': 'Op4', 'kind': 'op', 'op': 'Op4'},
|
||||
'Op3_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op3', 'Op3_tensor')]},
|
||||
'Op4_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op4', 'Op4_tensor')]},
|
||||
**result('result3'),
|
||||
**result('result4'),
|
||||
})
|
||||
graph = build_graph(_nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result1'), ('Op1_data', 'Op2'),
|
||||
('Op2', 'Op2_data'), ('Op2_data', 'result2'), ('Op2_data', 'Op3'),
|
||||
('Op3', 'Op3_data'), ('Op3_data', 'result3'), ('Op3_data', 'Op4'),
|
||||
('Op4', 'Op4_data'), ('Op4_data', 'result4')])
|
||||
graph.outputs_order = ['result1', 'result3', 'result4', 'result2']
|
||||
|
||||
ResultRename().find_and_replace_pattern(graph)
|
||||
self.assertTrue(Node(graph, 'result1')['name'] == 'Op1_tensor')
|
||||
self.assertTrue(Node(graph, 'result2')['name'] == 'Op2_tensor')
|
||||
self.assertTrue(Node(graph, 'result3')['name'] == 'Op3_tensor')
|
||||
self.assertTrue(Node(graph, 'result4')['name'] == 'Op4_tensor')
|
||||
|
||||
self.assertTrue(graph.outputs_order == ['Op1_tensor', 'Op3_tensor', 'Op4_tensor', 'Op2_tensor'])
|
||||
|
Loading…
Reference in New Issue
Block a user