Update remove_converts pass with shape inference (#10474)

This commit is contained in:
Nikita Malinin
2022-02-17 18:17:07 +03:00
committed by GitHub
parent 6e5eb87340
commit a090abbc92

View File

@@ -938,17 +938,18 @@ def find_shape_subgraph_endpoints(out_ports: List[Port], visited: set = None) ->
def remove_converts(graph: Graph):
for op in graph.get_op_nodes(type='Convert'):
source_op = op.in_port(0).get_source().node
if source_op.type == 'Const' and source_op.data_type == np.float16:
# Get access to data node after Convert operation and set Insert_Convert_operation_after
# to restore Convert operation later
op.out_node(0)['Insert_Convert_operation_after'] = True
# Mark Const and Convert operation to fold them
source_op['need_shape_inference'] = True
op.out_node(0)['old_rt_info'] = op['rt_info']
op['stop_value_propagation'] = False
op['need_shape_inference'] = True
for op in graph.get_op_nodes():
if op.type == 'Convert':
source_op = op.in_port(0).get_source().node
if source_op.type == 'Const' and source_op.data_type == np.float16:
# Get access to data node after Convert operation and set Insert_Convert_operation_after
# to restore Convert operation later
op.out_node(0)['Insert_Convert_operation_after'] = True
# Mark Const and Convert operation to fold them
source_op['need_shape_inference'] = True
op.out_node(0)['old_rt_info'] = op['rt_info']
op['stop_value_propagation'] = False
op['need_shape_inference'] = True
graph.clean_up()