From a6bc78dd0f6f3d5580dcb5b64da9380213ff4528 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Tue, 12 Sep 2023 10:40:20 +0200 Subject: [PATCH] [PT FE] Separate tracing and scripting modes (#19676) * [PT FE] Separate scripting and tracing in decoder * Fix convert_model to accept decoder * Some fixes * Fix code style * Fix preprocessor tests * Fix tests * Fix tests * Fix more tests * Fix ovc tests --- .../openvino/frontend/pytorch/ts_decoder.py | 171 +++--------------- .../src/openvino/frontend/pytorch/utils.py | 133 +++++++++++++- .../test_preprocessor.py | 3 +- .../test_mo_convert_pytorch.py | 15 +- .../ovc_python_api_tests/test_pytorch.py | 8 +- .../pytorch_tests/pytorch_layer_test_class.py | 67 +++---- tests/layer_tests/pytorch_tests/test_add.py | 2 +- .../layer_tests/pytorch_tests/test_aliases.py | 13 +- tests/layer_tests/pytorch_tests/test_mul.py | 2 +- .../mo/moc_frontend/pytorch_frontend_utils.py | 5 +- .../moc_frontend/pytorch_frontend_utils.py | 5 +- 11 files changed, 213 insertions(+), 211 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index 5ac0f797efe..32e62084e89 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -7,29 +7,15 @@ from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType from openvino.runtime import op, PartialShape, Type as OVType, OVAny -from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, torch_tensor_to_ov_const +from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor from openvino.runtime import opset11 as ops import typing import torch -import numpy as np - -wrapper_template = """ -import torch -from typing import * - -class ModelWrapper(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, {input_sign}): - return self.model({example_input}) -""" class TorchScriptPythonDecoder (Decoder): - def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True): + def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True, skip_freeze=False): Decoder.__init__(self) # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted self.m_decoders = [] @@ -38,16 +24,18 @@ class TorchScriptPythonDecoder (Decoder): self._input_is_list = False if graph_element is None: try: - pt_module = self._get_scripted_model(pt_module, example_input) + pt_module = self._get_scripted_model(pt_module, example_input, skip_freeze) except Exception as e: if example_input is not None: - msg = "tracing or scripting" - help_msg = "" + msg = "tracing" + help_msg = "Please check correctness of provided 'example_input'. " + "Sometimes models can be converted in scripted mode, please try running " + "conversion without 'example_input'." else: msg = "scripting" - help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument. " + help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument." 
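The error messages in this hunk spell out how the conversion mode is now picked: passing 'example_input' selects tracing, omitting it selects scripting. A minimal sketch of both entry points, assuming a toy user module named SmallNet (the convert_model import path is the one the layer tests in this patch use):

    import torch
    from openvino.tools.ovc import convert_model

    class SmallNet(torch.nn.Module):  # illustrative module, not part of this patch
        def forward(self, x):
            return torch.relu(x) + 1

    model = SmallNet().eval()
    # 'example_input' present: the decoder traces the model (torch.jit.trace, strict=False)
    ov_traced = convert_model(model, example_input=torch.ones(1, 10))
    # 'example_input' absent: the decoder falls back to torch.jit.script
    ov_scripted = convert_model(model)
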
raise RuntimeError( - f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n {help_msg}" + f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n{help_msg} " "You can also provide TorchScript module that you obtained" " yourself, please refer to PyTorch documentation: " "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.") @@ -82,74 +70,10 @@ class TorchScriptPythonDecoder (Decoder): preserved_attributes.append(name) return preserved_attributes - def _get_scripted_model(self, pt_module, example_inputs=None): + def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False): import torch import inspect - def process_dict_inputs(inputs, input_params, model): - ordered_inputs = [] - for input_name in input_params: - if input_name in inputs: - ordered_inputs.append(input_name) - - input_signature = list(input_params) - if ordered_inputs == input_signature[:len(ordered_inputs)]: - example_inputs = [inputs[input_name] for input_name in ordered_inputs] - if all([isinstance(inp, torch.Tensor) for inp in example_inputs]): - return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model - return {"example_inputs": example_inputs}, ordered_inputs, model - - # PyTorch has some difficulties to trace models with named unordered parameters: - # torch < 2.0.0 supports only positional arguments for tracing - # pytorch == 2.0.0 supports input kwargs tracing, - # but does not support complex nested objects (e. g. tuple of tuples of tensors) - # We will use wrapper for making them positional as workaround. - - input_sign_str = [] - input_params_str = [] - - for input_name in ordered_inputs: - if str(input_params[input_name].annotation).startswith("typing.Union"): - filter_custom_args = [] - for arg in input_params[input_name].annotation.__args__: - str_arg = str(arg) - is_typing = str_arg.startswith("typing.") - is_torch = "torch." 
in str_arg - is_builten = str_arg in (str(int), str(float), str(type(None))) - if not (is_typing or is_torch or is_builten): - continue - filter_custom_args.append(arg) - input_params[input_name].annotation.__args__ = tuple(filter_custom_args) - input_sign_str.append(str(input_params[input_name]).replace("NoneType", "None")) - input_params_str.append(f"{input_name}={input_name}") - - wrapper_class = wrapper_template.format(input_sign=', '.join(input_sign_str), example_input=', '.join(input_params_str)) - result = {} - try: - exec(wrapper_class, result) - - wrapped_model = result["ModelWrapper"](model) - wrapped_model.eval() - # if wrapping failed, it is better to return original model for avoid user confusion regarding error message - except Exception: - wrapped_model = model - - return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model - - def prepare_example_inputs_and_model(inputs, input_params, model): - input_signature = list(input_params) - if isinstance(inputs, dict): - return process_dict_inputs(inputs, input_params, model) - if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor): - if "typing.List" in str(input_params[input_signature[0]].annotation): - inputs = inputs[0].unsqueeze(0) - self._input_is_list = True - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - input_signature = input_signature[:len(inputs)] - return {"example_inputs": inputs}, input_signature, model - if isinstance(pt_module, torch.nn.Module): pt_module.eval() input_signature = None @@ -160,32 +84,23 @@ class TorchScriptPythonDecoder (Decoder): if example_inputs is None: scripted = torch.jit.script(pt_module) else: - input_parameters, input_signature, pt_module = prepare_example_inputs_and_model(example_inputs, input_params, pt_module) - try: - scripted = torch.jit.trace(pt_module, **input_parameters) - except Exception: - try: - scripted = torch.jit.script(pt_module) - except Exception as se: - try: - scripted = torch.jit.trace(pt_module, **input_parameters, strict=False) - except Exception as te: - raise Exception(f"Tracing failed with exception {te}\nScripting failed with exception: {se}") - skip_freeze = False - for n in scripted.inlined_graph.nodes(): - # TODO: switch off freezing for all traced models - if "quantize" in n.kind(): - # do not freeze quantized models - skip_freeze = True - break - elif "aten::to" in n.kind(): - first_input = next(n.inputs()) - if first_input.node().kind() == "prim::Constant": - ivalue = first_input.toIValue() - if isinstance(ivalue, torch.Tensor) and ivalue.dtype in [torch.bfloat16, torch.float16]: - # do not freeze models with compressed constants - skip_freeze = True - break + input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(example_inputs, input_params, pt_module) + scripted = torch.jit.trace(pt_module, **input_parameters, strict=False) + if not skip_freeze: + for n in scripted.inlined_graph.nodes(): + # TODO: switch off freezing for all traced models + if "quantize" in n.kind(): + # do not freeze quantized models + skip_freeze = True + break + elif "aten::to" in n.kind(): + first_input = next(n.inputs()) + if first_input.node().kind() == "prim::Constant": + ivalue = first_input.toIValue() + if isinstance(ivalue, torch.Tensor) and ivalue.dtype in [torch.bfloat16, torch.float16]: + # do not freeze models with compressed constants + skip_freeze = True + break if not skip_freeze: preserved_attrs = self._get_preserved_attributes(scripted) 
f_model = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs) @@ -331,36 +246,6 @@ class TorchScriptPythonDecoder (Decoder): node.set_friendly_name(name) return node - @staticmethod - def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool): - # need to represent as Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply (scale) - qscheme = qtensor.qscheme() # torch.per_channel_affine (per_tensor) - if qscheme == torch.per_channel_affine: - int8_tensor = qtensor.int_repr() - scale = qtensor.q_per_channel_scales().numpy().astype(np.float32) # (weight.q_scale() for per_tensor) - zero_point = qtensor.q_per_channel_zero_points().numpy().astype(np.float32) # (weight.q_zero_point() for per_tensor) - axis = np.int32(qtensor.q_per_channel_axis()) - - new_shape = np.ones(len(int8_tensor.shape), dtype=np.int32) - new_shape[axis] = -1 - zero_point_bc = np.reshape(zero_point, new_shape) - scale_bc = np.reshape(scale, new_shape) - - int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory) - convert = ops.convert(int8_const, np.float32) - sub = ops.subtract(convert, zero_point_bc) - return ops.multiply(sub, scale_bc).outputs() - elif qscheme == torch.per_tensor_affine: - int8_tensor = qtensor.int_repr() - scale = np.float32(qtensor.q_scale()) - zero_point = np.float32(qtensor.q_zero_point()) - - int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory) - convert = ops.convert(int8_const, np.float32) - sub = ops.subtract(convert, zero_point) - return ops.multiply(sub, scale).outputs() - assert False, "Unsupported qscheme" - def try_decode_get_attr(self): pt_value = get_value_from_getattr(self.graph_element, self.pt_module) assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr" @@ -368,7 +253,7 @@ class TorchScriptPythonDecoder (Decoder): # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase # TODO: but can be anything. Figure a better way to distinguish weight, bias = pt_value.unpack() - res = self.convert_quantized_tensor(weight, self._shared_memory) + res = convert_quantized_tensor(weight, self._shared_memory) if isinstance(bias, torch.Tensor): res += ivalue_to_constant(bias) else: diff --git a/src/bindings/python/src/openvino/frontend/pytorch/utils.py b/src/bindings/python/src/openvino/frontend/pytorch/utils.py index 3c658119bb1..97d237fb0ef 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/utils.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/utils.py @@ -10,9 +10,10 @@ import numpy as np import ctypes from openvino.runtime import op, Type as OVType, Shape, Tensor +from openvino.runtime import opset11 as ops -def maybe_convert_max_int(value : int): +def maybe_convert_max_int(value: int): # FIXME: This is a convertion from 64-bit positive max integer value # to 32-bit positive max integer value. Find a better way to handle this. if value == torch.iinfo(torch.int64).max: @@ -20,10 +21,12 @@ def maybe_convert_max_int(value : int): else: return value + def make_constant(*args, **kwargs): return op.Constant(*args, **kwargs) -def fetch_attr(self_module, target : str): + +def fetch_attr(self_module, target: str): """ Fetch an attribute from the ``Module`` hierarchy of ``self.module``. 
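The convert_quantized_tensor helper removed from the decoder here (and re-added in utils.py below) decomposes a quantized weight into Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply(scale). The same arithmetic can be sanity-checked directly in PyTorch; a short sketch for the per-tensor affine case, with arbitrary scale and zero point chosen for illustration:

    import torch

    w = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.05, zero_point=3, dtype=torch.quint8)
    # (int_repr - zero_point) * scale is the computation the emitted OV subgraph performs
    manual = (w.int_repr().float() - w.q_zero_point()) * w.q_scale()
    assert torch.allclose(manual, w.dequantize())
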
@@ -37,7 +40,8 @@ def fetch_attr(self_module, target : str): attr_itr = self_module for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): - raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") attr_itr = getattr(attr_itr, atom) return attr_itr @@ -84,6 +88,7 @@ def ivalue_to_constant(ivalue, shared_memory=True): return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs() return None + def get_value_from_getattr(getattr_node, self_module): assert getattr_node.kind() == "prim::GetAttr", "Got node of kind not equal to prim::GetAttr" # GetAttr nodes can be nested @@ -98,10 +103,12 @@ def get_value_from_getattr(getattr_node, self_module): while len(stack) > 0: node = stack.pop() attr_name = node.s("name") - assert hasattr(module, attr_name), f"No attribute with name \"{attr_name}\" found in module." + assert hasattr( + module, attr_name), f"No attribute with name \"{attr_name}\" found in module." module = getattr(module, attr_name) return module + pt_to_ov_type_map = { "float": OVType.f32, "int": OVType.i32, @@ -131,3 +138,121 @@ ov_to_c_type_map = { OVType.i32: ctypes.c_int, OVType.i64: ctypes.c_int64, } + + +wrapper_template = """ +import torch +from typing import * + +class ModelWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, {input_sign}): + return self.model({example_input}) +""" + + +def process_dict_inputs(inputs, input_params, model): + ordered_inputs = [] + for input_name in input_params: + if input_name in inputs: + ordered_inputs.append(input_name) + + input_signature = list(input_params) + if ordered_inputs == input_signature[:len(ordered_inputs)]: + example_inputs = [inputs[input_name] for input_name in ordered_inputs] + if all([isinstance(inp, torch.Tensor) for inp in example_inputs]): + return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model + return {"example_inputs": example_inputs}, ordered_inputs, model + + # PyTorch has some difficulties to trace models with named unordered parameters: + # torch < 2.0.0 supports only positional arguments for tracing + # pytorch == 2.0.0 supports input kwargs tracing, + # but does not support complex nested objects (e. g. tuple of tuples of tensors) + # We will use wrapper for making them positional as workaround. + + input_sign_str = [] + input_params_str = [] + + for input_name in ordered_inputs: + if str(input_params[input_name].annotation).startswith("typing.Union"): + filter_custom_args = [] + for arg in input_params[input_name].annotation.__args__: + str_arg = str(arg) + is_typing = str_arg.startswith("typing.") + is_torch = "torch." 
in str_arg + is_builten = str_arg in (str(int), str(float), str(type(None))) + if not (is_typing or is_torch or is_builten): + continue + filter_custom_args.append(arg) + input_params[input_name].annotation.__args__ = tuple( + filter_custom_args) + input_sign_str.append( + str(input_params[input_name]).replace("NoneType", "None")) + input_params_str.append(f"{input_name}={input_name}") + + wrapper_class = wrapper_template.format(input_sign=', '.join( + input_sign_str), example_input=', '.join(input_params_str)) + result = {} + try: + exec(wrapper_class, result) + + wrapped_model = result["ModelWrapper"](model) + wrapped_model.eval() + # if wrapping failed, it is better to return original model for avoid user confusion regarding error message + except Exception: + wrapped_model = model + + return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model + + +def prepare_example_inputs_and_model(inputs, input_params, model): + input_is_list = False + input_signature = list(input_params) + if isinstance(inputs, dict): + examples, ordered, wrapped = process_dict_inputs(inputs, input_params, model) + return examples, ordered, wrapped, input_is_list + if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor): + if "typing.List" in str(input_params[input_signature[0]].annotation): + inputs = inputs[0].unsqueeze(0) + input_is_list = True + + if isinstance(inputs, torch.Tensor): + inputs = [inputs] + input_signature = input_signature[:len(inputs)] + return {"example_inputs": inputs}, input_signature, model, input_is_list + + +def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool): + # represents torch quantized tensor as + # Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply(scale) + qscheme = qtensor.qscheme() + if qscheme == torch.per_channel_affine: + int8_tensor = qtensor.int_repr() + scale = qtensor.q_per_channel_scales().numpy().astype(np.float32) + zero_point = qtensor.q_per_channel_zero_points().numpy().astype(np.float32) + axis = np.int32(qtensor.q_per_channel_axis()) + + new_shape = np.ones(len(int8_tensor.shape), dtype=np.int32) + new_shape[axis] = -1 + zero_point_bc = np.reshape(zero_point, new_shape) + scale_bc = np.reshape(scale, new_shape) + + int8_const = torch_tensor_to_ov_const( + int8_tensor, shared_memory=shared_memory) + convert = ops.convert(int8_const, np.float32) + sub = ops.subtract(convert, zero_point_bc) + return ops.multiply(sub, scale_bc).outputs() + elif qscheme == torch.per_tensor_affine: + int8_tensor = qtensor.int_repr() + scale = np.float32(qtensor.q_scale()) + zero_point = np.float32(qtensor.q_zero_point()) + + int8_const = torch_tensor_to_ov_const( + int8_tensor, shared_memory=shared_memory) + convert = ops.convert(int8_const, np.float32) + sub = ops.subtract(convert, zero_point) + return ops.multiply(sub, scale).outputs() + assert False, "Unsupported qscheme" diff --git a/src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py b/src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py index 59f2b458ce7..a1cdc41f610 100644 --- a/src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py +++ b/src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py @@ -32,8 +32,7 @@ class Convnet(torch.nn.Module): def _infer_pipelines(test_input, preprocess_pipeline, input_channels=3): torch_model = Convnet(input_channels) - example_input = Tensor(np.expand_dims(test_input, axis=0).astype(np.float32)) - ov_model = 
convert_model(torch_model, example_input=example_input) + ov_model = convert_model(torch_model) core = Core() ov_model = PreprocessConverter.from_torchvision( diff --git a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py index 9a863a12d70..6eab63bf682 100644 --- a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py +++ b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py @@ -841,7 +841,7 @@ def create_pytorch_module_with_nested_inputs2(tmp_dir): add = ov.opset10.add(concat1, param0) ref_model = Model([concat2, add], [param0, param1, param2], "test") return net, ref_model, { - "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))}, + "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 5)))}, "compress_to_fp16": False} @@ -867,7 +867,7 @@ def create_pytorch_module_with_nested_inputs3(tmp_dir): add = ov.opset10.add(concat1, param3) ref_model = Model([concat2, add], [param1, param2, param3], "test") return net, ref_model, { - "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))}, + "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 3)))}, "compress_to_fp16": False} @@ -895,7 +895,7 @@ def create_pytorch_module_with_nested_inputs4(tmp_dir): mul = ov.opset10.multiply(concat2, param4) ref_model = Model([mul, add], [param3, param1, param2, param4], "test") return net, ref_model, { - "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))), + "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 10))), "y": torch.ones((1,))}, "compress_to_fp16": False} @@ -924,7 +924,7 @@ def create_pytorch_module_with_nested_inputs5(tmp_dir): mul = ov.opset10.multiply(concat2, param4) ref_model = Model([mul, add], [param0, param1, param2, param4], "test") return net, ref_model, { - "example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))], + "example_input": [torch.ones((1, 10)), (torch.zeros((1, 9)), torch.ones((1, 5, 10))), torch.ones((1,))], "compress_to_fp16": False} @@ -1268,9 +1268,4 @@ class TestPrecisionSensitive(): fw_res = fw_model(*torch_inp_tensors) ov_res = core.compile_model(ir_test)(example_inputs) - if precision == 'FP32': - custom_eps = 1e-4 - else: - custom_eps = 1e-3 - - npt.assert_allclose(ov_res[0], fw_res.numpy(), atol=custom_eps) + npt.assert_allclose(ov_res[0], fw_res.numpy(), atol=1e-3, rtol=1e-3) diff --git a/tests/layer_tests/ovc_python_api_tests/test_pytorch.py b/tests/layer_tests/ovc_python_api_tests/test_pytorch.py index 302144d59c3..268f69d13f0 100644 --- a/tests/layer_tests/ovc_python_api_tests/test_pytorch.py +++ b/tests/layer_tests/ovc_python_api_tests/test_pytorch.py @@ -843,7 +843,7 @@ def create_pytorch_module_with_nested_inputs2(tmp_dir): add = ov.opset10.add(concat1, param0) ref_model = Model([concat2, add], [param0, param1, param2], "test") return net, ref_model, { - "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))}, + "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 5)))}, "compress_to_fp16": False} @@ -869,7 +869,7 @@ def create_pytorch_module_with_nested_inputs3(tmp_dir): add = ov.opset10.add(concat1, param3) ref_model = Model([concat2, add], [param1, param2, param3], "test") return net, ref_model, { - 
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))}, + "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 3)))}, "compress_to_fp16": False} @@ -897,7 +897,7 @@ def create_pytorch_module_with_nested_inputs4(tmp_dir): mul = ov.opset10.multiply(concat2, param4) ref_model = Model([mul, add], [param3, param1, param2, param4], "test") return net, ref_model, { - "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))), + "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 10))), "y": torch.ones((1,))}, "compress_to_fp16": False} @@ -926,7 +926,7 @@ def create_pytorch_module_with_nested_inputs5(tmp_dir): mul = ov.opset10.multiply(concat2, param4) ref_model = Model([mul, add], [param0, param1, param2, param4], "test") return net, ref_model, { - "example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))], + "example_input": [torch.ones((1, 10)), (torch.zeros((1, 9)), torch.ones((1, 5, 10))), torch.ones((1,))], "compress_to_fp16": False} diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py index 283a90942b4..0f5638ea8c8 100644 --- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py +++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py @@ -77,18 +77,16 @@ class PytorchLayerTest: self.torch_compile_backend_test(model, torch_inputs, custom_eps) else: with torch.no_grad(): - model.eval() trace_model = kwargs.get('trace_model', False) freeze_model = kwargs.get('freeze_model', True) - model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model) - graph = model.inlined_graph + smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model) - if kind is not None and not isinstance(kind, (tuple, list)): - kind = [kind] - if kind is not None: - for op in kind: - assert self._check_kind_exist( - graph, op), f"Operation {op} type doesn't exist in provided graph" + if kind is not None and not isinstance(kind, (tuple, list)): + kind = [kind] + if kind is not None: + for op in kind: + assert self._check_kind_exist( + smodel.inlined_graph, op), f"Operation {op} type doesn't exist in provided graph" # OV infer: core = Core() compiled = core.compile_model(converted_model, ie_device) @@ -99,7 +97,7 @@ class PytorchLayerTest: return # Framework infer: - fw_res = model(*deepcopy(torch_inputs)) + fw_res = smodel(*deepcopy(torch_inputs)) if not isinstance(fw_res, (tuple)): fw_res = (fw_res,) @@ -162,47 +160,36 @@ class PytorchLayerTest: def _prepare_input(self): raise RuntimeError("Please provide inputs generation function") - def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs): - import torch + def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model): from openvino.tools.ovc import convert_model - kwargs = {"example_input": example_input if len( - example_input) > 1 else example_input[0], "compress_to_fp16": False} - with torch.no_grad(): - if trace_model: - model = torch.jit.trace(model, example_input) - else: - model = torch.jit.script(model) - model = torch.jit.freeze(model) - print(model) - if not dynamic_shapes: - input_shapes = [inp.shape for inp in ov_inputs] - kwargs["input_shape"] = 
input_shapes - om = convert_model(model, **kwargs) + kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]} + if trace_model: + decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model) + else: + decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model) + smodel = decoder.pt_module + print(smodel.inlined_graph) + if not dynamic_shapes: + input_shapes = [inp.shape for inp in ov_inputs] + kwargs["input"] = input_shapes + om = convert_model(decoder, **kwargs) self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes) - return model, om + return smodel, om def convert_directly_via_frontend(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model): - import torch - fe_manager = FrontEndManager() fe = fe_manager.load_by_framework('pytorch') - model.eval() - with torch.no_grad(): - if trace_model: - model = torch.jit.trace(model, example_input) - else: - model = torch.jit.script(model) - if freeze_model: - _model = torch.jit.freeze(model) + if trace_model: + decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model) else: - _model = model - print(_model.inlined_graph) - decoder = TorchScriptPythonDecoder(_model) + decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model) + smodel = decoder.pt_module + print(smodel.inlined_graph) im = fe.load(decoder) om = fe.convert(im) self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes) - return model, om + return smodel, om def _resolve_input_shape_dtype(self, om, ov_inputs, dynamic_shapes): params = list(om.inputs) diff --git a/tests/layer_tests/pytorch_tests/test_add.py b/tests/layer_tests/pytorch_tests/test_add.py index c13cfbcd363..8c3026a9c2c 100644 --- a/tests/layer_tests/pytorch_tests/test_add.py +++ b/tests/layer_tests/pytorch_tests/test_add.py @@ -111,7 +111,7 @@ class TestAddTypes(PytorchLayerTest): self.rhs_type = rhs_type self.rhs_shape = rhs_shape self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape), - ie_device, precision, ir_version) + ie_device, precision, ir_version, freeze_model=False, trace_model=True) class TestAddLists(PytorchLayerTest): diff --git a/tests/layer_tests/pytorch_tests/test_aliases.py b/tests/layer_tests/pytorch_tests/test_aliases.py index 1919aab5fbc..78f323b4a2d 100644 --- a/tests/layer_tests/pytorch_tests/test_aliases.py +++ b/tests/layer_tests/pytorch_tests/test_aliases.py @@ -28,11 +28,16 @@ class TestAliases(PytorchLayerTest): @pytest.mark.nightly @pytest.mark.precommit def test_alias(self, ie_device, precision, ir_version): - self._test(aten_alias(), None, [ - "aten::slice", "aten::select", "aten::copy_"], ie_device, precision, ir_version) + self._test(aten_alias(), None, ["aten::slice", + "aten::select", + "aten::copy_"], + ie_device, precision, ir_version) @pytest.mark.nightly @pytest.mark.precommit def test_loop_alias(self, ie_device, precision, ir_version): - self._test(aten_loop_alias(), None, [ - "aten::slice", "aten::select", "aten::copy_", "prim::Loop"], ie_device, precision, ir_version) + self._test(aten_loop_alias(), None, ["aten::slice", + "aten::select", + "aten::copy_", + "prim::Loop"], + ie_device, precision, ir_version, freeze_model=False) diff --git a/tests/layer_tests/pytorch_tests/test_mul.py b/tests/layer_tests/pytorch_tests/test_mul.py index 02a17e8c38d..8e958f09569 100644 --- a/tests/layer_tests/pytorch_tests/test_mul.py +++ b/tests/layer_tests/pytorch_tests/test_mul.py @@ -100,4 +100,4 @@ class 
TestMulTypes(PytorchLayerTest): self.rhs_type = rhs_type self.rhs_shape = rhs_shape self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape), - ie_device, precision, ir_version) + ie_device, precision, ir_version, freeze_model=False, trace_model=True) diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py index 5b11f8c6998..214fbbc4ff7 100644 --- a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py +++ b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py @@ -32,7 +32,10 @@ def get_pytorch_decoder(model, input_shape, example_inputs, args): raise RuntimeError( "NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.") inputs = prepare_torch_inputs(example_inputs) - decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True)) + if not isinstance(model, TorchScriptPythonDecoder): + decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True)) + else: + decoder = model args['input_model'] = decoder args["framework"] = "pytorch" args["example_input"] = inputs diff --git a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py index 2703af43d8b..8baf75354f9 100644 --- a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +++ b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py @@ -32,7 +32,10 @@ def get_pytorch_decoder(model, example_inputs, args): raise RuntimeError( "NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.") inputs = prepare_torch_inputs(example_inputs) - decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True)) + if not isinstance(model, TorchScriptPythonDecoder): + decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True)) + else: + decoder = model args['input_model'] = decoder args["example_input"] = inputs
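
With the pass-through above, a ready-made decoder is accepted directly, so a caller can construct it with explicit options (for example skip_freeze) and hand it to convert_model, which is what the updated layer tests do. A minimal sketch, assuming an illustrative module named SmallNet:

    import torch
    from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
    from openvino.tools.ovc import convert_model

    class SmallNet(torch.nn.Module):  # illustrative module, not part of this patch
        def forward(self, x):
            return x * 2 + 1

    example = torch.ones(1, 3)
    # skip_freeze=True keeps the traced graph unfrozen, mirroring freeze_model=False in the tests
    decoder = TorchScriptPythonDecoder(SmallNet().eval(), example_input=example, skip_freeze=True)
    # get_pytorch_decoder detects the prepared decoder and passes it through unchanged
    ov_model = convert_model(decoder, example_input=example)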