From 9cb3c2a6beed4bf0e6caae753e9dab9aee6573ee Mon Sep 17 00:00:00 2001 From: iliya mironov Date: Tue, 17 Nov 2020 16:28:27 +0300 Subject: [PATCH] Fix graph clenup (#3159) * Fix graph clenup * Refactoring graph clean up function * Change wa comment Co-authored-by: Your Name --- .../middle/MarkSubgraphsWithCorrectLayout.py | 4 ---- model-optimizer/mo/middle/passes/eliminate.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py index 0010a9004c4..fc7b834d24f 100644 --- a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py +++ b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py @@ -139,14 +139,10 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern): else: mark_as_correct_data_layout(node) node['nchw_layout'] = True - if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up - node.out_node()['nchw_layout'] = True for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]: mark_as_correct_data_layout(node) node['nchw_layout'] = True - if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up - node.out_node()['nchw_layout'] = True @staticmethod def get_weighted_layer_type_to_in_weights_port(): diff --git a/model-optimizer/mo/middle/passes/eliminate.py b/model-optimizer/mo/middle/passes/eliminate.py index 8137a47b3fd..2aa689cf769 100644 --- a/model-optimizer/mo/middle/passes/eliminate.py +++ b/model-optimizer/mo/middle/passes/eliminate.py @@ -125,8 +125,16 @@ def mark_const_producer_nodes(graph): def eliminate_dead_nodes(graph): + from mo.graph.graph import Node nodes_to_remove = set() for node_name, node_attrs in graph.nodes(data=True): + # The Const operation node may have set an attribute 'nchw_layout' attribute to prevent shape permutation. + # During graph clean-up the operation node is removed and the attribute is lost. + # This results in permutation of the Const shape in the IR and wrong inference results. + # Here we explicitly save the 'nchw_layout' attribute in the data node to prevent permutation." + if node_attrs.get('type', None) == 'Const' and node_attrs.get('nchw_layout', False): + Node(graph, node_name).out_node()['nchw_layout'] = True + if not node_attrs['is_output_reachable'] or \ (node_attrs['is_const_producer'] and (not node_attrs['is_undead'] or node_attrs.get('force_dead_node', False))): @@ -153,13 +161,6 @@ def add_constant_operations(graph): graph.add_edges_from([(const_node.id, node.id, {'out': 0})]) -def remove_const_ops(graph): - - for node in graph.get_op_nodes(type='Const'): - graph.remove_edge(node.id, node.out_node().id) - graph.remove_node(node.id) - - def shape_inference(graph): for node in graph.pseudo_topological_sort(): if node.has_and_set('need_shape_inference'): @@ -252,5 +253,3 @@ def remove_edges_for_nodes(graph, node_attrs: dict, edge_attrs: dict): src_node, edge = nodes_edges[port] if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]): graph.remove_edge(src_node.id, node.id) - -