diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py index 3984f9b4beb..8a4e767c5b9 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/pipeline.py @@ -78,6 +78,11 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): log.debug('Inputs are same: {}, outputs are same: {}'.format( inputs_equal, outputs_equal)) + def create_target_input_shapes(new_input_places): + new_input_place_names = [x.get_names()[0] for x in new_input_places] + shapes = [shape for shape in argv.placeholder_shapes.values()] + return dict(zip(new_input_place_names, shapes)) + if not inputs_equal and not outputs_equal: log.debug('Using extract subgraph') new_input_places = [x['node'] for x in user_shapes] @@ -86,11 +91,11 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): input_model.extract_subgraph(new_input_places, new_output_places) # invalidation of existing Place objects could have happened in the operation above if user_shapes: - new_input_places_name = [x.get_names()[0] for x in new_input_places] + placeholder_shapes = create_target_input_shapes(new_input_places) new_output_places_name = [x.get_names()[0] for x in new_output_places] user_shapes, outputs, _ = fe_user_data_repack( - input_model, new_input_places_name, argv.placeholder_data_types, + input_model, placeholder_shapes, argv.placeholder_data_types, new_output_places_name, argv.freeze_placeholder_with_value, moc_front_end.get_name()) elif not inputs_equal: log.debug('Using override_all_inputs') @@ -99,9 +104,10 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd): input_model.override_all_inputs(new_input_places) # invalidation of existing Place objects could have happened in the operation above if user_shapes: - new_input_places_name = [x.get_names()[0] for x in new_input_places] + placeholder_shapes = create_target_input_shapes(new_input_places) + user_shapes, outputs, _ = fe_user_data_repack( - input_model, new_input_places_name, argv.placeholder_data_types, + input_model, placeholder_shapes, argv.placeholder_data_types, argv.output, argv.freeze_placeholder_with_value, moc_front_end.get_name()) elif not outputs_equal: log.debug('Using override_all_outputs')