[ MO: CVS-32286 ] IdentityN fix (#668)

This commit is contained in:
Evgenya Stepyreva
2020-05-29 09:11:22 +03:00
committed by GitHub
parent e51e1682ca
commit 5cc8114322
2 changed files with 31 additions and 1 deletions

View File

@@ -29,6 +29,11 @@ class IdentityN_to_Identity(FrontReplacementPattern):
IdentityN Identity Identity
/ \ | |
output_0 output_1 output_0 output_1
ATTENTION: not all in/outputs of the IdentityN may survive during ModelOptimizer pipeline.
And it breaks the original operation semantics.
For example, output_1 may be not be used during network output computations.
To preserve this unused in/output ports we disconnect the corresponding out/input port.
"""
enabled = True
@@ -41,12 +46,20 @@ class IdentityN_to_Identity(FrontReplacementPattern):
dtypes = node.data_types
for idx, port in node.in_ports().items():
assert node.is_out_port_connected(idx), 'IdentityN {} has inconsistent input and output ports'.format(name)
if not node.is_in_port_connected(idx) or not node.is_out_port_connected(idx):
# ATTENTION section in the description above
continue
assert idx < len(dtypes), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(name, dtypes)
identity = Identity(graph, {'name': '{}/{}_port'.format(name, idx), 'data_type': dtypes[idx]}).create_node()
port.get_connection().set_destination(identity.in_port(0))
node.out_port(idx).get_connection().set_source(identity.out_port(0))
# ATTENTION section in the description above
for in_port in node.in_ports().values():
in_port.disconnect()
for out_port in node.out_ports().values():
out_port.disconnect()
def find_and_replace_pattern(self, graph: Graph):
for identityN in graph.get_op_nodes(op='IdentityN'):
self.replace_identityN(identityN)

View File

@@ -61,3 +61,20 @@ class TestIdentityN(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_identityN_unused_ports(self):
graph = build_graph(nodes, [
*connect('placeholder_0', '0:identityN'),
*connect('placeholder_1', '1:identityN'),
*connect('identityN:0', 'output0'),
], nodes_with_edges_only=True)
IdentityN_to_Identity().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes, [
*connect('placeholder_0', 'identity0'),
*connect('identity0', 'output0'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
self.assertTrue(flag, resp)