[ MO ] Fix for remove_op_node_with_data_node (#1934)
This commit is contained in:
parent
19570916e9
commit
c4920ef5a0
@ -40,5 +40,5 @@ class UselessMergeEraser(MiddleReplacementPattern):
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
if len(graph.in_edges(match['merge'].id)) <= 1:
|
||||
remove_op_node_with_data_node(graph, match['merge'])
|
||||
remove_op_node_with_data_node(graph, match['merge'], list(match['merge'].in_nodes().values())[0])
|
||||
log.info("Useles Merge op and data nodes was deleted op='{}'".format(match['merge'].id))
|
||||
|
@ -217,10 +217,11 @@ def merge_data_nodes(graph, survived, removed):
|
||||
|
||||
|
||||
# TODO: unit tests
|
||||
def remove_op_node_with_data_node(graph, node_to_remove):
|
||||
def remove_op_node_with_data_node(graph, node_to_remove, input_data_node=None):
|
||||
from mo.graph.graph import Node
|
||||
assert node_to_remove.kind == 'op'
|
||||
input_data_node = node_to_remove.in_node()
|
||||
if input_data_node is None:
|
||||
input_data_node = node_to_remove.in_node()
|
||||
output_node = [v for _, v in graph.out_edges(node_to_remove.id)]
|
||||
assert len(output_node) == 1, "Cannot remove node producing two or more output tensors"
|
||||
output_node = Node(graph, output_node[0])
|
||||
|
Loading…
Reference in New Issue
Block a user