Update remove_converts pass with shape inference (#10474)
This commit is contained in:
@@ -938,17 +938,18 @@ def find_shape_subgraph_endpoints(out_ports: List[Port], visited: set = None) ->
|
||||
|
||||
|
||||
def remove_converts(graph: Graph):
|
||||
for op in graph.get_op_nodes(type='Convert'):
|
||||
source_op = op.in_port(0).get_source().node
|
||||
if source_op.type == 'Const' and source_op.data_type == np.float16:
|
||||
# Get access to data node after Convert operation and set Insert_Convert_operation_after
|
||||
# to restore Convert operation later
|
||||
op.out_node(0)['Insert_Convert_operation_after'] = True
|
||||
# Mark Const and Convert operation to fold them
|
||||
source_op['need_shape_inference'] = True
|
||||
op.out_node(0)['old_rt_info'] = op['rt_info']
|
||||
op['stop_value_propagation'] = False
|
||||
op['need_shape_inference'] = True
|
||||
for op in graph.get_op_nodes():
|
||||
if op.type == 'Convert':
|
||||
source_op = op.in_port(0).get_source().node
|
||||
if source_op.type == 'Const' and source_op.data_type == np.float16:
|
||||
# Get access to data node after Convert operation and set Insert_Convert_operation_after
|
||||
# to restore Convert operation later
|
||||
op.out_node(0)['Insert_Convert_operation_after'] = True
|
||||
# Mark Const and Convert operation to fold them
|
||||
source_op['need_shape_inference'] = True
|
||||
op.out_node(0)['old_rt_info'] = op['rt_info']
|
||||
op['stop_value_propagation'] = False
|
||||
op['need_shape_inference'] = True
|
||||
graph.clean_up()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user