[ MO ] Fix for remove_op_node_with_data_node (#1934)

This commit is contained in:
Evgenya Stepyreva 2020-08-25 20:30:41 +03:00 committed by GitHub
parent 19570916e9
commit c4920ef5a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@ -40,5 +40,5 @@ class UselessMergeEraser(MiddleReplacementPattern):
def replace_pattern(self, graph: Graph, match: dict):
if len(graph.in_edges(match['merge'].id)) <= 1:
remove_op_node_with_data_node(graph, match['merge'])
remove_op_node_with_data_node(graph, match['merge'], list(match['merge'].in_nodes().values())[0])
log.info("Useles Merge op and data nodes was deleted op='{}'".format(match['merge'].id))

View File

@ -217,10 +217,11 @@ def merge_data_nodes(graph, survived, removed):
# TODO: unit tests
def remove_op_node_with_data_node(graph, node_to_remove):
def remove_op_node_with_data_node(graph, node_to_remove, input_data_node=None):
from mo.graph.graph import Node
assert node_to_remove.kind == 'op'
input_data_node = node_to_remove.in_node()
if input_data_node is None:
input_data_node = node_to_remove.in_node()
output_node = [v for _, v in graph.out_edges(node_to_remove.id)]
assert len(output_node) == 1, "Cannot remove node producing two or more output tensors"
output_node = Node(graph, output_node[0])