From c44df9907b1577f237c813f2b13ccad1b0d7f388 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Wed, 9 Aug 2023 18:01:35 +0200 Subject: [PATCH] [MO] Do not create example inputs based on input or input_shape (#18975) --- .../test_mo_convert_pytorch.py | 110 +++++------------- .../mo/moc_frontend/pytorch_frontend_utils.py | 59 +--------- .../moc_frontend/pytorch_frontend_utils.py | 58 +-------- 3 files changed, 36 insertions(+), 191 deletions(-) diff --git a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py index 5e6cdc765c9..ae48f1ce1ae 100644 --- a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py +++ b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py @@ -74,6 +74,7 @@ def make_pt_model_with_optional_input(): ) def forward(self, x, y=None, z=None): + logits = None if y is None: logits = self.linear_relu_stack(x + z) if z is None: @@ -87,7 +88,7 @@ def make_ref_pt_model_one_input(shape, dtype=np.float32): shape = PartialShape(shape) param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype) relu = ov.opset8.relu(param1) - if dtype != np.float32: + if dtype not in [np.float32, Type.dynamic]: relu = ov.opset8.convert(relu, np.float32) sigm = ov.opset8.sigmoid(relu) @@ -106,9 +107,13 @@ def make_ref_pt_model_two_inputs(shape, dtype=np.float32): shape = PartialShape(shape) param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype) param2 = ov.opset8.parameter(shape, name="input_1", dtype=dtype) - mul = ov.opset8.multiply(param1, param2) + if dtype == Type.dynamic: + cl = ov.opset8.convert_like(param2, param1) + mul = ov.opset8.multiply(param1, cl) + else: + mul = ov.opset8.multiply(param1, param2) relu = ov.opset8.relu(mul) - if dtype != np.float32: + if dtype not in [np.float32, Type.dynamic]: relu = ov.opset8.convert(relu, np.float32) sigm = ov.opset8.sigmoid(relu) @@ -277,7 +282,7 @@ def create_pytorch_jit_script_function(tmp_dir): return torch.sigmoid(torch.relu(x * y)) inp_shape = PartialShape([Dimension(1, -1), Dimension(-1, 5), 10]) - ref_model = make_ref_pt_model_two_inputs(inp_shape) + ref_model = make_ref_pt_model_two_inputs(inp_shape, dtype=Type.dynamic) return scripted_fn, ref_model, {'input': [(inp_shape), (inp_shape)]} @@ -292,7 +297,7 @@ def create_pytorch_nn_module_layout_list(tmp_dir): ref_model.inputs[1].node.layout = Layout('nhwc') return pt_model, ref_model, { - 'input_shape': [shape, shape], 'layout': ['nchw', Layout('nhwc')], 'use_convert_model_from_mo': True + 'input': [(shape, np.float32), (shape, np.float32)], 'layout': ['nchw', Layout('nhwc')], 'use_convert_model_from_mo': True } @@ -307,30 +312,7 @@ def create_pytorch_nn_module_layout_list_case2(tmp_dir): ref_model.inputs[1].node.layout = Layout('nhwc') return pt_model, ref_model, { - 'input_shape': [shape, shape], 'layout': ('nchw', Layout('nhwc')), 'use_convert_model_from_mo': True} - - -def create_pytorch_nn_module_mean_list(tmp_dir): - pt_model = make_pt_model_two_inputs() - shape = [1, 10, 10, 3] - - shape = PartialShape(shape) - param1 = ov.opset8.parameter(shape) - param2 = ov.opset8.parameter(shape) - const1 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32) - const2 = ov.opset8.constant([[[[-0.0, -0.0, -0.0]]]], dtype=np.float32) - add1 = ov.opset8.add(param1, const1) - add2 = ov.opset8.add(param2, const2) - mul = ov.opset8.multiply(add1, add2) - relu = ov.opset8.relu(mul) - sigm = ov.opset8.sigmoid(relu) - - parameter_list = [param1, param2] - ref_model = Model([sigm], parameter_list, "test") - - return pt_model, ref_model, { - 'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'compress_to_fp16': False, - 'use_convert_model_from_mo': True} + 'input': [(shape, np.float32), (shape, np.float32)], 'layout': ('nchw', Layout('nhwc')), 'use_convert_model_from_mo': True} def create_pytorch_nn_module_mean_list_compression_disabled(tmp_dir): @@ -351,7 +333,7 @@ def create_pytorch_nn_module_mean_list_compression_disabled(tmp_dir): parameter_list = [param1, param2] ref_model = Model([sigm], parameter_list, "test") - return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], + return pt_model, ref_model, {'input': [(shape, np.float32), (shape, np.float32)], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'compress_to_fp16': False, 'use_convert_model_from_mo': True} @@ -375,7 +357,7 @@ def create_pytorch_nn_module_mean_list_compression_default(tmp_dir): parameter_list = [param1, param2] ref_model = Model([sigm], parameter_list, "test") - return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], + return pt_model, ref_model, {'input': [(shape, np.float32), (shape, np.float32)], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'use_convert_model_from_mo': True} @@ -403,32 +385,10 @@ def create_pytorch_nn_module_mean_list_compression_enabled(tmp_dir): ref_model = Model([sigm], parameter_list, "test") return pt_model, ref_model, { - 'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], + 'input': [(shape, np.float32), (shape, np.float32)], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'compress_to_fp16': True, 'use_convert_model_from_mo': True} -def create_pytorch_nn_module_scale_list(tmp_dir): - pt_model = make_pt_model_two_inputs() - shape = [1, 10, 10, 3] - - shape = PartialShape(shape) - param1 = ov.opset8.parameter(shape) - param2 = ov.opset8.parameter(shape) - const1 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float32) - const2 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float32) - sub1 = ov.opset8.multiply(param1, const1) - sub2 = ov.opset8.multiply(param2, const2) - mul = ov.opset8.multiply(sub1, sub2) - relu = ov.opset8.relu(mul) - sigm = ov.opset8.sigmoid(relu) - - parameter_list = [param1, param2] - ref_model = Model([sigm], parameter_list, "test") - - return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], 'compress_to_fp16': False, - 'use_convert_model_from_mo': True} - - def create_pytorch_nn_module_scale_list_compression_disabled(tmp_dir): pt_model = make_pt_model_two_inputs() shape = [1, 10, 10, 3] @@ -447,7 +407,8 @@ def create_pytorch_nn_module_scale_list_compression_disabled(tmp_dir): parameter_list = [param1, param2] ref_model = Model([sigm], parameter_list, "test") - return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], + return pt_model, ref_model, {'input': [(shape, np.float32), (shape, np.float32)], + 'scale_values': [[1, 1, 1], [1, 1, 1]], 'compress_to_fp16': False, 'use_convert_model_from_mo': True} @@ -471,7 +432,8 @@ def create_pytorch_nn_module_scale_list_compression_default(tmp_dir): parameter_list = [param1, param2] ref_model = Model([sigm], parameter_list, "test") - return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], + return pt_model, ref_model, {'input': [(shape, np.float32), (shape, np.float32)], + 'scale_values': [[1, 1, 1], [1, 1, 1]], 'use_convert_model_from_mo': True} @@ -497,13 +459,14 @@ def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir): parameter_list = [param1, param2] ref_model = Model([sigm], parameter_list, "test") - return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], + return pt_model, ref_model, {'input': [(shape, np.float32), (shape, np.float32)], + 'scale_values': [[1, 1, 1], [1, 1, 1]], 'compress_to_fp16': True, 'use_convert_model_from_mo': True} def create_pytorch_nn_module_shapes_list_static(tmp_dir): pt_model = make_pt_model_two_inputs() - ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20]) + ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20], dtype=Type.dynamic) return pt_model, ref_model, {'input': [[1, 3, 20, 20], [1, 3, 20, 20]]} @@ -521,10 +484,11 @@ def create_pytorch_nn_module_shapes_list_dynamic(tmp_dir): [-1, 3, 20, Dimension(-1, 20)]] param1 = ov.opset8.parameter(PartialShape( - inp_shapes[0]), name="x", dtype=np.float32) + inp_shapes[0]), name="x", dtype=Type.dynamic) param2 = ov.opset8.parameter(PartialShape( - inp_shapes[1]), name="y", dtype=np.float32) - mul = ov.opset8.multiply(param1, param2) + inp_shapes[1]), name="y", dtype=Type.dynamic) + cl = ov.opset8.convert_like(param2, param1) + mul = ov.opset8.multiply(param1, cl) relu = ov.opset8.relu(mul) sigm = ov.opset8.sigmoid(relu) @@ -548,13 +512,13 @@ def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir): parameter_list = [param1, param2] ref_model = Model([sigm], parameter_list, "test") - return pt_model, ref_model, {'input': [(inp_shapes[0],), (inp_shapes[1],)]} + return pt_model, ref_model, {'input': [(inp_shapes[0], Type.f32), (inp_shapes[1], Type.f32)]} def create_pytorch_nn_module_shapes_list_dynamic_single_input(tmp_dir): pt_model = make_pt_model_one_input() inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)]] - ref_model = make_ref_pt_model_one_input(inp_shapes[0]) + ref_model = make_ref_pt_model_one_input(inp_shapes[0], dtype=Type.dynamic) return pt_model, ref_model, {'input': inp_shapes} @@ -568,7 +532,7 @@ def create_pytorch_nn_module_shapes_list_dynamic_single_input_via_input(tmp_dir) def create_pytorch_nn_module_shapes_list_static_single_input(tmp_dir): pt_model = make_pt_model_one_input() inp_shapes = [[1, 3, 20, 20]] - ref_model = make_ref_pt_model_one_input(inp_shapes[0]) + ref_model = make_ref_pt_model_one_input(inp_shapes[0], dtype=Type.dynamic) return pt_model, ref_model, {'input': inp_shapes} @@ -735,20 +699,6 @@ def create_pytorch_module_with_optional_inputs_case3(tmp_dir): return net, ref_model, {"example_input": example_input, "input": [[3, 3, 3, 3], [3, 3, 3, 3]]} -def create_pytorch_module_with_optional_inputs_case4(tmp_dir): - net = make_pt_model_with_optional_input() - ref_model = make_ref_pt_model_with_optional_inputs( - [3, 3, 3, 3], z_exist=True) - return net, ref_model, {"input": [("x", [3, 3, 3, 3]), ("z", [3, 3, 3, 3])]} - - -def create_pytorch_module_with_optional_inputs_case5(tmp_dir): - net = make_pt_model_with_optional_input() - ref_model = make_ref_pt_model_with_optional_inputs( - [1, 3, -1, -1], z_exist=True) - return net, ref_model, {"input": [("x",[1, 3, -1, -1]), ("z", [1, 3, -1, -1])]} - - def create_pytorch_module_with_compressed_int8_constant_compress_to_fp16_default(tmp_dir): import torch import torch.nn.functional as F @@ -1013,11 +963,9 @@ class TestMoConvertPyTorch(CommonMOConvertTest): create_pytorch_jit_script_function, create_pytorch_nn_module_layout_list, create_pytorch_nn_module_layout_list_case2, - create_pytorch_nn_module_mean_list, create_pytorch_nn_module_mean_list_compression_default, create_pytorch_nn_module_mean_list_compression_disabled, create_pytorch_nn_module_mean_list_compression_enabled, - create_pytorch_nn_module_scale_list, create_pytorch_nn_module_scale_list_compression_default, create_pytorch_nn_module_scale_list_compression_disabled, create_pytorch_nn_module_scale_list_compression_enabled, @@ -1039,8 +987,6 @@ class TestMoConvertPyTorch(CommonMOConvertTest): create_pytorch_module_with_optional_inputs_case1, create_pytorch_module_with_optional_inputs_case2, create_pytorch_module_with_optional_inputs_case3, - create_pytorch_module_with_optional_inputs_case4, - create_pytorch_module_with_optional_inputs_case5, create_pytorch_nn_module_with_scalar_input, create_pytorch_module_with_compressed_int8_constant, create_pytorch_module_with_compressed_int8_constant_compress_to_fp16_default, diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py index b42fa131225..7cb46d92300 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py @@ -30,7 +30,7 @@ def get_pytorch_decoder(model, input_shape, example_inputs, args): "NNCF models produced by nncf<2.6 are not supported directly. Please export to ONNX first.") except: pass - inputs = prepare_torch_inputs(example_inputs, input_shape, args.get("input"), allow_none=True) + inputs = prepare_torch_inputs(example_inputs) decoder = TorchScriptPythonDecoder(model, example_input=inputs) args['input_model'] = decoder args["framework"] = "pytorch" @@ -151,36 +151,7 @@ def to_torch_tensor(tensor): "Got {}".format(type(tensor))) -def get_torch_dtype(dtype): - import torch - ov_str_to_torch = { - "boolean": torch.bool, - "f16": torch.float16, - "f32": torch.float32, - "f64": torch.float64, - "i8": torch.int8, - "i16": torch.int16, - "i32": torch.int32, - "i64": torch.int64, - "u8": torch.uint8, - } - if dtype is None: - return torch.float - if isinstance(dtype, torch.dtype): - return dtype - if isinstance(dtype, (type, np.dtype)): - dtype = get_element_type_str(dtype) - if isinstance(dtype, Type): - dtype = dtype.get_type_name() - if isinstance(dtype, str): - str_dtype = ov_str_to_torch.get(dtype) - if str_dtype is None: - raise Error(f"Unexpected data type '{dtype}' for input") - return str_dtype - raise Error(f"Unexpected data type for input. Supported torch.dtype, numpy.dtype, ov.Type and str. Got {type(dtype)}") - - -def prepare_torch_inputs(example_inputs, input_shape, input_info=None, allow_none=False): +def prepare_torch_inputs(example_inputs): import torch inputs = None if example_inputs is not None: @@ -201,29 +172,7 @@ def prepare_torch_inputs(example_inputs, input_shape, input_info=None, allow_non inputs[name] = to_torch_tensor(tensor) else: inputs = to_torch_tensor(inputs) - elif input_info is not None or input_shape is not None: - input_info = input_to_input_cut_info(input_info) or [] - input_shape_to_input_cut_info(input_shape, input_info) - inputs = [] - inputs_with_names = {} - for inp in input_info: - shape = inp.shape - if shape is None: - if not allow_none: - raise Error("Please provide input_shape or example_input for all inputs converting PyTorch model.") - inputs = None - break - dtype = get_torch_dtype(inp.type) - static_shape = get_static_shape(shape, dynamic_value=1) - input_tensor = torch.zeros(static_shape, dtype=dtype) # pylint: disable=no-member - if inp.name is not None: - inputs_with_names[inp.name] = input_tensor - inputs.append(input_tensor) - if isinstance(inputs, list): - inputs = tuple(inputs) - if inputs is not None and len(inputs) == len(inputs_with_names): - inputs = inputs_with_names else: - if not allow_none: - raise Error("Please provide input_shape or example_input for converting PyTorch model.") + # No example_input were provided, decoder will use scripting + return None return inputs diff --git a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py index 3bb6c928f3a..89c5ce11ae5 100644 --- a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +++ b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py @@ -30,7 +30,7 @@ def get_pytorch_decoder(model, example_inputs, args): "NNCF models produced by nncf<2.6 are not supported directly. Please export to ONNX first.") except: pass - inputs = prepare_torch_inputs(example_inputs, args.get("input"), allow_none=True) + inputs = prepare_torch_inputs(example_inputs) decoder = TorchScriptPythonDecoder(model, example_input=inputs) args['input_model'] = decoder args["example_input"] = inputs @@ -150,36 +150,7 @@ def to_torch_tensor(tensor): "Got {}".format(type(tensor))) -def get_torch_dtype(dtype): - import torch - ov_str_to_torch = { - "boolean": torch.bool, - "f16": torch.float16, - "f32": torch.float32, - "f64": torch.float64, - "i8": torch.int8, - "i16": torch.int16, - "i32": torch.int32, - "i64": torch.int64, - "u8": torch.uint8, - } - if dtype is None: - return torch.float - if isinstance(dtype, torch.dtype): - return dtype - if isinstance(dtype, (type, np.dtype)): - dtype = get_element_type_str(dtype) - if isinstance(dtype, Type): - dtype = dtype.get_type_name() - if isinstance(dtype, str): - str_dtype = ov_str_to_torch.get(dtype) - if str_dtype is None: - raise Error(f"Unexpected data type '{dtype}' for input") - return str_dtype - raise Error(f"Unexpected data type for input. Supported torch.dtype, numpy.dtype, ov.Type and str. Got {type(dtype)}") - - -def prepare_torch_inputs(example_inputs, input_info=None, allow_none=False): +def prepare_torch_inputs(example_inputs): import torch inputs = None if example_inputs is not None: @@ -200,28 +171,7 @@ def prepare_torch_inputs(example_inputs, input_info=None, allow_none=False): inputs[name] = to_torch_tensor(tensor) else: inputs = to_torch_tensor(inputs) - elif input_info is not None: - input_info = input_to_input_cut_info(input_info) or [] - inputs = [] - inputs_with_names = {} - for inp in input_info: - shape = inp.shape - if shape is None: - if not allow_none: - raise Error("Please provide shape in `input` or `example_input` for all inputs converting PyTorch model.") - inputs = None - break - dtype = get_torch_dtype(inp.type) - static_shape = get_static_shape(shape, dynamic_value=1) - input_tensor = torch.zeros(static_shape, dtype=dtype) # pylint: disable=no-member - if inp.name is not None: - inputs_with_names[inp.name] = input_tensor - inputs.append(input_tensor) - if isinstance(inputs, list): - inputs = tuple(inputs) - if inputs is not None and len(inputs) == len(inputs_with_names): - inputs = inputs_with_names else: - if not allow_none: - raise Error("Please provide shapes `input` or `example_input` for converting PyTorch model.") + # No example_input were provided, decoder will use scripting + return None return inputs