Fix graph clenup (#3159)

* Fix graph clenup

* Refactoring graph clean up function

* Change wa comment

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
iliya mironov 2020-11-17 16:28:27 +03:00 committed by GitHub
parent 74ff38fedb
commit 9cb3c2a6be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 13 deletions

View File

@ -139,14 +139,10 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
else:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
node.out_node()['nchw_layout'] = True
for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
node.out_node()['nchw_layout'] = True
@staticmethod
def get_weighted_layer_type_to_in_weights_port():

View File

@ -125,8 +125,16 @@ def mark_const_producer_nodes(graph):
def eliminate_dead_nodes(graph):
from mo.graph.graph import Node
nodes_to_remove = set()
for node_name, node_attrs in graph.nodes(data=True):
# The Const operation node may have set an attribute 'nchw_layout' attribute to prevent shape permutation.
# During graph clean-up the operation node is removed and the attribute is lost.
# This results in permutation of the Const shape in the IR and wrong inference results.
# Here we explicitly save the 'nchw_layout' attribute in the data node to prevent permutation."
if node_attrs.get('type', None) == 'Const' and node_attrs.get('nchw_layout', False):
Node(graph, node_name).out_node()['nchw_layout'] = True
if not node_attrs['is_output_reachable'] or \
(node_attrs['is_const_producer'] and (not node_attrs['is_undead'] or
node_attrs.get('force_dead_node', False))):
@ -153,13 +161,6 @@ def add_constant_operations(graph):
graph.add_edges_from([(const_node.id, node.id, {'out': 0})])
def remove_const_ops(graph):
for node in graph.get_op_nodes(type='Const'):
graph.remove_edge(node.id, node.out_node().id)
graph.remove_node(node.id)
def shape_inference(graph):
for node in graph.pseudo_topological_sort():
if node.has_and_set('need_shape_inference'):
@ -252,5 +253,3 @@ def remove_edges_for_nodes(graph, node_attrs: dict, edge_attrs: dict):
src_node, edge = nodes_edges[port]
if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]):
graph.remove_edge(src_node.id, node.id)