[ MO: CVS-32286 ] IdentityN fix (#668)
This commit is contained in:
committed by
GitHub
parent
e51e1682ca
commit
5cc8114322
@@ -29,6 +29,11 @@ class IdentityN_to_Identity(FrontReplacementPattern):
|
||||
IdentityN Identity Identity
|
||||
/ \ | |
|
||||
output_0 output_1 output_0 output_1
|
||||
|
||||
ATTENTION: not all in/outputs of the IdentityN may survive during ModelOptimizer pipeline.
|
||||
And it breaks the original operation semantics.
|
||||
For example, output_1 may be not be used during network output computations.
|
||||
To preserve this unused in/output ports we disconnect the corresponding out/input port.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
@@ -41,12 +46,20 @@ class IdentityN_to_Identity(FrontReplacementPattern):
|
||||
dtypes = node.data_types
|
||||
|
||||
for idx, port in node.in_ports().items():
|
||||
assert node.is_out_port_connected(idx), 'IdentityN {} has inconsistent input and output ports'.format(name)
|
||||
if not node.is_in_port_connected(idx) or not node.is_out_port_connected(idx):
|
||||
# ATTENTION section in the description above
|
||||
continue
|
||||
assert idx < len(dtypes), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(name, dtypes)
|
||||
identity = Identity(graph, {'name': '{}/{}_port'.format(name, idx), 'data_type': dtypes[idx]}).create_node()
|
||||
port.get_connection().set_destination(identity.in_port(0))
|
||||
node.out_port(idx).get_connection().set_source(identity.out_port(0))
|
||||
|
||||
# ATTENTION section in the description above
|
||||
for in_port in node.in_ports().values():
|
||||
in_port.disconnect()
|
||||
for out_port in node.out_ports().values():
|
||||
out_port.disconnect()
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for identityN in graph.get_op_nodes(op='IdentityN'):
|
||||
self.replace_identityN(identityN)
|
||||
|
||||
@@ -61,3 +61,20 @@ class TestIdentityN(unittest.TestCase):
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_identityN_unused_ports(self):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder_0', '0:identityN'),
|
||||
*connect('placeholder_1', '1:identityN'),
|
||||
*connect('identityN:0', 'output0'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
IdentityN_to_Identity().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder_0', 'identity0'),
|
||||
*connect('identity0', 'output0'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
Reference in New Issue
Block a user