Fix: Refreshing of places after subgraph extraction (#12494)
This commit is contained in:
parent
f23fd569bc
commit
ab9319ba94
@ -78,6 +78,11 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||
log.debug('Inputs are same: {}, outputs are same: {}'.format(
|
||||
inputs_equal, outputs_equal))
|
||||
|
||||
def create_target_input_shapes(new_input_places):
|
||||
new_input_place_names = [x.get_names()[0] for x in new_input_places]
|
||||
shapes = [shape for shape in argv.placeholder_shapes.values()]
|
||||
return dict(zip(new_input_place_names, shapes))
|
||||
|
||||
if not inputs_equal and not outputs_equal:
|
||||
log.debug('Using extract subgraph')
|
||||
new_input_places = [x['node'] for x in user_shapes]
|
||||
@ -86,11 +91,11 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||
input_model.extract_subgraph(new_input_places, new_output_places)
|
||||
# invalidation of existing Place objects could have happened in the operation above
|
||||
if user_shapes:
|
||||
new_input_places_name = [x.get_names()[0] for x in new_input_places]
|
||||
placeholder_shapes = create_target_input_shapes(new_input_places)
|
||||
new_output_places_name = [x.get_names()[0] for x in new_output_places]
|
||||
|
||||
user_shapes, outputs, _ = fe_user_data_repack(
|
||||
input_model, new_input_places_name, argv.placeholder_data_types,
|
||||
input_model, placeholder_shapes, argv.placeholder_data_types,
|
||||
new_output_places_name, argv.freeze_placeholder_with_value, moc_front_end.get_name())
|
||||
elif not inputs_equal:
|
||||
log.debug('Using override_all_inputs')
|
||||
@ -99,9 +104,10 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
|
||||
input_model.override_all_inputs(new_input_places)
|
||||
# invalidation of existing Place objects could have happened in the operation above
|
||||
if user_shapes:
|
||||
new_input_places_name = [x.get_names()[0] for x in new_input_places]
|
||||
placeholder_shapes = create_target_input_shapes(new_input_places)
|
||||
|
||||
user_shapes, outputs, _ = fe_user_data_repack(
|
||||
input_model, new_input_places_name, argv.placeholder_data_types,
|
||||
input_model, placeholder_shapes, argv.placeholder_data_types,
|
||||
argv.output, argv.freeze_placeholder_with_value, moc_front_end.get_name())
|
||||
elif not outputs_equal:
|
||||
log.debug('Using override_all_outputs')
|
||||
|
Loading…
Reference in New Issue
Block a user