Fix graph clenup (#3159)
* Fix graph clenup * Refactoring graph clean up function * Change wa comment Co-authored-by: Your Name <you@example.com>
This commit is contained in:
parent
74ff38fedb
commit
9cb3c2a6be
@ -139,14 +139,10 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
else:
|
||||
mark_as_correct_data_layout(node)
|
||||
node['nchw_layout'] = True
|
||||
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
|
||||
node.out_node()['nchw_layout'] = True
|
||||
|
||||
for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]:
|
||||
mark_as_correct_data_layout(node)
|
||||
node['nchw_layout'] = True
|
||||
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
|
||||
node.out_node()['nchw_layout'] = True
|
||||
|
||||
@staticmethod
|
||||
def get_weighted_layer_type_to_in_weights_port():
|
||||
|
@ -125,8 +125,16 @@ def mark_const_producer_nodes(graph):
|
||||
|
||||
|
||||
def eliminate_dead_nodes(graph):
|
||||
from mo.graph.graph import Node
|
||||
nodes_to_remove = set()
|
||||
for node_name, node_attrs in graph.nodes(data=True):
|
||||
# The Const operation node may have set an attribute 'nchw_layout' attribute to prevent shape permutation.
|
||||
# During graph clean-up the operation node is removed and the attribute is lost.
|
||||
# This results in permutation of the Const shape in the IR and wrong inference results.
|
||||
# Here we explicitly save the 'nchw_layout' attribute in the data node to prevent permutation."
|
||||
if node_attrs.get('type', None) == 'Const' and node_attrs.get('nchw_layout', False):
|
||||
Node(graph, node_name).out_node()['nchw_layout'] = True
|
||||
|
||||
if not node_attrs['is_output_reachable'] or \
|
||||
(node_attrs['is_const_producer'] and (not node_attrs['is_undead'] or
|
||||
node_attrs.get('force_dead_node', False))):
|
||||
@ -153,13 +161,6 @@ def add_constant_operations(graph):
|
||||
graph.add_edges_from([(const_node.id, node.id, {'out': 0})])
|
||||
|
||||
|
||||
def remove_const_ops(graph):
|
||||
|
||||
for node in graph.get_op_nodes(type='Const'):
|
||||
graph.remove_edge(node.id, node.out_node().id)
|
||||
graph.remove_node(node.id)
|
||||
|
||||
|
||||
def shape_inference(graph):
|
||||
for node in graph.pseudo_topological_sort():
|
||||
if node.has_and_set('need_shape_inference'):
|
||||
@ -252,5 +253,3 @@ def remove_edges_for_nodes(graph, node_attrs: dict, edge_attrs: dict):
|
||||
src_node, edge = nodes_edges[port]
|
||||
if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]):
|
||||
graph.remove_edge(src_node.id, node.id)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user