[MO] Implementation of names uniqueness check (#5651)
* added new transformation to check the uniqueness of nodes names * added unittest * remove redundant line * conversation resolving * updated unittest * added new unittest, added check for uniqueness of new node name * added a description * added renaming of several results with the same name and unittest for this case * another implementation, updated unittests * added a comment * updated comments * added comment to the nodes_with_equal_names func * added a condition * added a result name check in unittests
This commit is contained in:
@@ -37,6 +37,7 @@ extensions/back/LRNToNorm.py
|
||||
extensions/back/MarkNodesWithShapeValues.py
|
||||
extensions/back/MatMulNormalizer.py
|
||||
extensions/back/MaxPool.py
|
||||
extensions/back/names_uniqueness_check.py
|
||||
extensions/back/NormalizeToNormalizeL2.py
|
||||
extensions/back/op_versioning.py
|
||||
extensions/back/OptimizeTransposeReshapeSequence.py
|
||||
|
||||
67
model-optimizer/extensions/back/names_uniqueness_check.py
Normal file
67
model-optimizer/extensions/back/names_uniqueness_check.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from collections import defaultdict
|
||||
from extensions.back.pass_separator import BackFinish
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph, rename_node
|
||||
|
||||
|
||||
def nodes_with_equal_names(graph: Graph):
|
||||
"""
|
||||
:param graph: Graph to operate on
|
||||
:return: Dictionary with node names as keys and a list of their corresponding nodes as values
|
||||
"""
|
||||
names_dict = defaultdict(list)
|
||||
for node in graph.get_op_nodes():
|
||||
node_name = node.soft_get('name', node.id)
|
||||
names_dict[node_name].append(node)
|
||||
return names_dict
|
||||
|
||||
|
||||
def make_node_names_unique(nodes: list, node_names: set):
|
||||
"""
|
||||
:param nodes: List with nodes matching a specific name
|
||||
:param node_names: Set with all node names contained in the graph
|
||||
:return: None
|
||||
|
||||
Result nodes will be renamed only when it is absolutely necessary(if there are several Result nodes with the same name).
|
||||
Function finds a position of Result nodes in the "nodes" list, take the first and rename all other nodes.
|
||||
If the "nodes" list does not contain Result nodes, then all nodes starting from the second one will be renamed.
|
||||
All new names are added to the "node_names" set.
|
||||
"""
|
||||
results_pos = [idx for idx, node in enumerate(nodes) if node.op == 'Result']
|
||||
node_position_to_keep = 0
|
||||
if len(results_pos) != 0:
|
||||
node_position_to_keep = results_pos[0]
|
||||
for idx, node in enumerate(nodes):
|
||||
if idx != node_position_to_keep:
|
||||
new_node_name = node.soft_get('name', node.id) + '_' + str(idx)
|
||||
# preparing a new unique name for the node
|
||||
while new_node_name in node_names:
|
||||
new_node_name += '_' + str(idx)
|
||||
node_names.add(new_node_name)
|
||||
rename_node(node, new_node_name)
|
||||
|
||||
|
||||
class NamesUniquenessCheck(BackReplacementPattern):
|
||||
"""
|
||||
If there are several layers with the same name in the original model and they are saved in the IR, IE will fail with
|
||||
the invalid IR error. IE checks the uniqueness of the names and, if it is not true, throws an exception. The way how
|
||||
to fix it on the MO side is to rename this nodes (one node will remain with the original name). Since we prefer to
|
||||
save framework names for the output nodes, nodes with op=Result will not be renamed, except the case when there are
|
||||
several Result nodes with the same name.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
return [BackFinish]
|
||||
|
||||
def run_before(self):
|
||||
return []
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
names_to_nodes = nodes_with_equal_names(graph)
|
||||
node_names = set(names_to_nodes.keys())
|
||||
for nodes in names_to_nodes.values():
|
||||
if len(nodes) > 1:
|
||||
make_node_names_unique(nodes, node_names)
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.back.names_uniqueness_check import NamesUniquenessCheck
|
||||
from mo.graph.graph import Node
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
|
||||
class TestNamesUniquenessCheck(unittest.TestCase):
|
||||
|
||||
def test_1(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
'input': {'kind': 'op', 'op': 'Parameter', 'name': 'node'},
|
||||
'cast': {'kind': 'op', 'op': 'Cast', 'name': 'node'},
|
||||
'result': {'kind': 'op', 'op': 'Result', 'name': 'node'}
|
||||
},
|
||||
edges=[
|
||||
('input', 'cast'),
|
||||
('cast', 'result')
|
||||
]
|
||||
)
|
||||
|
||||
NamesUniquenessCheck().find_and_replace_pattern(graph)
|
||||
names = [node.name for node in graph.get_op_nodes()]
|
||||
result_name = Node(graph, 'result').name
|
||||
|
||||
self.assertTrue(len(set(names)) == 3)
|
||||
self.assertTrue(result_name == 'node')
|
||||
|
||||
def test_2(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
'input': {'kind': 'op', 'op': 'Parameter', 'name': 'node'},
|
||||
'cast': {'kind': 'op', 'op': 'Cast', 'name': 'node_0'},
|
||||
'result': {'kind': 'op', 'op': 'Result', 'name': 'node'}
|
||||
},
|
||||
edges=[
|
||||
('input', 'cast'),
|
||||
('cast', 'result')
|
||||
]
|
||||
)
|
||||
|
||||
NamesUniquenessCheck().find_and_replace_pattern(graph)
|
||||
names = [node.name for node in graph.get_op_nodes()]
|
||||
result_name = Node(graph, 'result').name
|
||||
|
||||
self.assertTrue(len(set(names)) == 3)
|
||||
self.assertTrue(result_name == 'node')
|
||||
|
||||
def test_3(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs={
|
||||
'input': {'kind': 'op', 'op': 'Parameter', 'name': 'node_0'},
|
||||
'cast': {'kind': 'op', 'op': 'Cast', 'name': 'node_1'},
|
||||
'result_1': {'kind': 'op', 'op': 'Result', 'name': 'node'},
|
||||
'result_2': {'kind': 'op', 'op': 'Result', 'name': 'node'}
|
||||
},
|
||||
edges=[
|
||||
('input', 'cast'),
|
||||
('cast', 'result_1'),
|
||||
('cast', 'result_2'),
|
||||
]
|
||||
)
|
||||
NamesUniquenessCheck().find_and_replace_pattern(graph)
|
||||
names = [node.name for node in graph.get_op_nodes()]
|
||||
|
||||
self.assertTrue('node' in names)
|
||||
self.assertTrue(len(set(names)) == 4)
|
||||
Reference in New Issue
Block a user