Fix: Refreshing of places after subgraph extraction (#12494)

This commit is contained in:
Artur Kulikowski 2022-08-11 14:55:06 +02:00 committed by GitHub
parent f23fd569bc
commit ab9319ba94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -78,6 +78,11 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
log.debug('Inputs are same: {}, outputs are same: {}'.format(
inputs_equal, outputs_equal))
def create_target_input_shapes(new_input_places):
new_input_place_names = [x.get_names()[0] for x in new_input_places]
shapes = [shape for shape in argv.placeholder_shapes.values()]
return dict(zip(new_input_place_names, shapes))
if not inputs_equal and not outputs_equal:
log.debug('Using extract subgraph')
new_input_places = [x['node'] for x in user_shapes]
@ -86,11 +91,11 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
input_model.extract_subgraph(new_input_places, new_output_places)
# invalidation of existing Place objects could have happened in the operation above
if user_shapes:
new_input_places_name = [x.get_names()[0] for x in new_input_places]
placeholder_shapes = create_target_input_shapes(new_input_places)
new_output_places_name = [x.get_names()[0] for x in new_output_places]
user_shapes, outputs, _ = fe_user_data_repack(
input_model, new_input_places_name, argv.placeholder_data_types,
input_model, placeholder_shapes, argv.placeholder_data_types,
new_output_places_name, argv.freeze_placeholder_with_value, moc_front_end.get_name())
elif not inputs_equal:
log.debug('Using override_all_inputs')
@ -99,9 +104,10 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
input_model.override_all_inputs(new_input_places)
# invalidation of existing Place objects could have happened in the operation above
if user_shapes:
new_input_places_name = [x.get_names()[0] for x in new_input_places]
placeholder_shapes = create_target_input_shapes(new_input_places)
user_shapes, outputs, _ = fe_user_data_repack(
input_model, new_input_places_name, argv.placeholder_data_types,
input_model, placeholder_shapes, argv.placeholder_data_types,
argv.output, argv.freeze_placeholder_with_value, moc_front_end.get_name())
elif not outputs_equal:
log.debug('Using override_all_outputs')