[PT FE] Separate tracing and scripting modes (#19676)

* [PT FE] Separate scripting and tracing in decoder

* Fix convert_model to accept decoder

* Some fixes

* Fix code style

* Fix preprocessor tests

* Fix tests

* Fix tests

* Fix more tests

* Fix ovc tests
Maxim Vafin 2023-09-12 10:40:20 +02:00 committed by GitHub
parent 514f9864af
commit a6bc78dd0f
11 changed files with 213 additions and 211 deletions

View File

@@ -7,29 +7,15 @@
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
 from openvino.runtime import op, PartialShape, Type as OVType, OVAny
-from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, torch_tensor_to_ov_const
+from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor
 from openvino.runtime import opset11 as ops
 import typing
 import torch
-import numpy as np
-
-wrapper_template = """
-import torch
-from typing import *
-
-class ModelWrapper(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, {input_sign}):
-        return self.model({example_input})
-"""
 
 class TorchScriptPythonDecoder (Decoder):
-    def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True):
+    def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True, skip_freeze=False):
         Decoder.__init__(self)
         # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
         self.m_decoders = []
@@ -38,16 +24,18 @@ class TorchScriptPythonDecoder (Decoder):
         self._input_is_list = False
         if graph_element is None:
             try:
-                pt_module = self._get_scripted_model(pt_module, example_input)
+                pt_module = self._get_scripted_model(pt_module, example_input, skip_freeze)
             except Exception as e:
                 if example_input is not None:
-                    msg = "tracing or scripting"
-                    help_msg = ""
+                    msg = "tracing"
+                    help_msg = "Please check correctness of provided 'example_input'. "
+                    "Sometimes models can be converted in scripted mode, please try running "
+                    "conversion without 'example_input'."
                 else:
                     msg = "scripting"
-                    help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument. "
+                    help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument."
                 raise RuntimeError(
-                    f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n {help_msg}"
+                    f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n{help_msg} "
                     "You can also provide TorchScript module that you obtained"
                     " yourself, please refer to PyTorch documentation: "
                     "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.")
@@ -82,74 +70,10 @@ class TorchScriptPythonDecoder (Decoder):
                 preserved_attributes.append(name)
         return preserved_attributes
 
-    def _get_scripted_model(self, pt_module, example_inputs=None):
+    def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False):
         import torch
         import inspect
 
-        def process_dict_inputs(inputs, input_params, model):
-            ordered_inputs = []
-            for input_name in input_params:
-                if input_name in inputs:
-                    ordered_inputs.append(input_name)
-            input_signature = list(input_params)
-            if ordered_inputs == input_signature[:len(ordered_inputs)]:
-                example_inputs = [inputs[input_name] for input_name in ordered_inputs]
-                if all([isinstance(inp, torch.Tensor) for inp in example_inputs]):
-                    return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model
-                return {"example_inputs": example_inputs}, ordered_inputs, model
-            # PyTorch has some difficulties to trace models with named unordered parameters:
-            # torch < 2.0.0 supports only positional arguments for tracing
-            # pytorch == 2.0.0 supports input kwargs tracing,
-            # but does not support complex nested objects (e. g. tuple of tuples of tensors)
-            # We will use wrapper for making them positional as workaround.
-            input_sign_str = []
-            input_params_str = []
-            for input_name in ordered_inputs:
-                if str(input_params[input_name].annotation).startswith("typing.Union"):
-                    filter_custom_args = []
-                    for arg in input_params[input_name].annotation.__args__:
-                        str_arg = str(arg)
-                        is_typing = str_arg.startswith("typing.")
-                        is_torch = "torch." in str_arg
-                        is_builten = str_arg in (str(int), str(float), str(type(None)))
-                        if not (is_typing or is_torch or is_builten):
-                            continue
-                        filter_custom_args.append(arg)
-                    input_params[input_name].annotation.__args__ = tuple(filter_custom_args)
-                input_sign_str.append(str(input_params[input_name]).replace("NoneType", "None"))
-                input_params_str.append(f"{input_name}={input_name}")
-            wrapper_class = wrapper_template.format(input_sign=', '.join(input_sign_str), example_input=', '.join(input_params_str))
-            result = {}
-            try:
-                exec(wrapper_class, result)
-                wrapped_model = result["ModelWrapper"](model)
-                wrapped_model.eval()
-            # if wrapping failed, it is better to return original model for avoid user confusion regarding error message
-            except Exception:
-                wrapped_model = model
-            return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
-
-        def prepare_example_inputs_and_model(inputs, input_params, model):
-            input_signature = list(input_params)
-            if isinstance(inputs, dict):
-                return process_dict_inputs(inputs, input_params, model)
-            if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
-                if "typing.List" in str(input_params[input_signature[0]].annotation):
-                    inputs = inputs[0].unsqueeze(0)
-                    self._input_is_list = True
-            if isinstance(inputs, torch.Tensor):
-                inputs = [inputs]
-            input_signature = input_signature[:len(inputs)]
-            return {"example_inputs": inputs}, input_signature, model
-
         if isinstance(pt_module, torch.nn.Module):
             pt_module.eval()
         input_signature = None
@@ -160,32 +84,23 @@ class TorchScriptPythonDecoder (Decoder):
             if example_inputs is None:
                 scripted = torch.jit.script(pt_module)
             else:
-                input_parameters, input_signature, pt_module = prepare_example_inputs_and_model(example_inputs, input_params, pt_module)
-                try:
-                    scripted = torch.jit.trace(pt_module, **input_parameters)
-                except Exception:
-                    try:
-                        scripted = torch.jit.script(pt_module)
-                    except Exception as se:
-                        try:
-                            scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
-                        except Exception as te:
-                            raise Exception(f"Tracing failed with exception {te}\nScripting failed with exception: {se}")
-            skip_freeze = False
-            for n in scripted.inlined_graph.nodes():
-                # TODO: switch off freezing for all traced models
-                if "quantize" in n.kind():
-                    # do not freeze quantized models
-                    skip_freeze = True
-                    break
-                elif "aten::to" in n.kind():
-                    first_input = next(n.inputs())
-                    if first_input.node().kind() == "prim::Constant":
-                        ivalue = first_input.toIValue()
-                        if isinstance(ivalue, torch.Tensor) and ivalue.dtype in [torch.bfloat16, torch.float16]:
-                            # do not freeze models with compressed constants
-                            skip_freeze = True
-                            break
+                input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(example_inputs, input_params, pt_module)
+                scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
+            if not skip_freeze:
+                for n in scripted.inlined_graph.nodes():
+                    # TODO: switch off freezing for all traced models
+                    if "quantize" in n.kind():
+                        # do not freeze quantized models
+                        skip_freeze = True
+                        break
+                    elif "aten::to" in n.kind():
+                        first_input = next(n.inputs())
+                        if first_input.node().kind() == "prim::Constant":
+                            ivalue = first_input.toIValue()
+                            if isinstance(ivalue, torch.Tensor) and ivalue.dtype in [torch.bfloat16, torch.float16]:
+                                # do not freeze models with compressed constants
+                                skip_freeze = True
+                                break
             if not skip_freeze:
                 preserved_attrs = self._get_preserved_attributes(scripted)
                 f_model = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs)
@@ -331,36 +246,6 @@ class TorchScriptPythonDecoder (Decoder):
             node.set_friendly_name(name)
         return node
 
-    @staticmethod
-    def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
-        # need to represent as Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply (scale)
-        qscheme = qtensor.qscheme()  # torch.per_channel_affine (per_tensor)
-        if qscheme == torch.per_channel_affine:
-            int8_tensor = qtensor.int_repr()
-            scale = qtensor.q_per_channel_scales().numpy().astype(np.float32)  # (weight.q_scale() for per_tensor)
-            zero_point = qtensor.q_per_channel_zero_points().numpy().astype(np.float32)  # (weight.q_zero_point() for per_tensor)
-            axis = np.int32(qtensor.q_per_channel_axis())
-            new_shape = np.ones(len(int8_tensor.shape), dtype=np.int32)
-            new_shape[axis] = -1
-            zero_point_bc = np.reshape(zero_point, new_shape)
-            scale_bc = np.reshape(scale, new_shape)
-            int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
-            convert = ops.convert(int8_const, np.float32)
-            sub = ops.subtract(convert, zero_point_bc)
-            return ops.multiply(sub, scale_bc).outputs()
-        elif qscheme == torch.per_tensor_affine:
-            int8_tensor = qtensor.int_repr()
-            scale = np.float32(qtensor.q_scale())
-            zero_point = np.float32(qtensor.q_zero_point())
-            int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
-            convert = ops.convert(int8_const, np.float32)
-            sub = ops.subtract(convert, zero_point)
-            return ops.multiply(sub, scale).outputs()
-        assert False, "Unsupported qscheme"
-
     def try_decode_get_attr(self):
         pt_value = get_value_from_getattr(self.graph_element, self.pt_module)
         assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
@@ -368,7 +253,7 @@ class TorchScriptPythonDecoder (Decoder):
             # We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase
             # TODO: but can be anything. Figure a better way to distinguish
             weight, bias = pt_value.unpack()
-            res = self.convert_quantized_tensor(weight, self._shared_memory)
+            res = convert_quantized_tensor(weight, self._shared_memory)
             if isinstance(bias, torch.Tensor):
                 res += ivalue_to_constant(bias)
             else:
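
For reference, a minimal standalone sketch (not part of the diff) of how the separated modes behave after this change, assuming the decoder is importable as openvino.frontend.pytorch.ts_decoder.TorchScriptPythonDecoder and using a hypothetical ToyModel: providing 'example_input' selects tracing only, omitting it selects scripting only, and 'skip_freeze=True' keeps the TorchScript module unfrozen.

import torch
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder  # assumed import path

class ToyModel(torch.nn.Module):  # hypothetical model, not from this PR
    def forward(self, x):
        return x * 2 + 1

model = ToyModel()

# Tracing mode: 'example_input' is provided, so only torch.jit.trace is attempted.
traced = TorchScriptPythonDecoder(model, example_input=[torch.ones(2, 3)])

# Scripting mode: no 'example_input', so only torch.jit.script is attempted.
scripted = TorchScriptPythonDecoder(model)

# skip_freeze=True keeps the TorchScript graph unfrozen (used by the layer tests below).
unfrozen = TorchScriptPythonDecoder(model, skip_freeze=True)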

View File

@@ -10,9 +10,10 @@ import numpy as np
 import ctypes
 from openvino.runtime import op, Type as OVType, Shape, Tensor
+from openvino.runtime import opset11 as ops
 
-def maybe_convert_max_int(value : int):
+def maybe_convert_max_int(value: int):
     # FIXME: This is a convertion from 64-bit positive max integer value
     # to 32-bit positive max integer value. Find a better way to handle this.
     if value == torch.iinfo(torch.int64).max:
@@ -20,10 +21,12 @@ def maybe_convert_max_int(value : int):
     else:
         return value
 
 def make_constant(*args, **kwargs):
     return op.Constant(*args, **kwargs)
 
-def fetch_attr(self_module, target : str):
+
+def fetch_attr(self_module, target: str):
     """
     Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
@@ -37,7 +40,8 @@ def fetch_attr(self_module, target : str):
     attr_itr = self_module
     for i, atom in enumerate(target_atoms):
         if not hasattr(attr_itr, atom):
-            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+            raise RuntimeError(
+                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
         attr_itr = getattr(attr_itr, atom)
     return attr_itr
@@ -84,6 +88,7 @@ def ivalue_to_constant(ivalue, shared_memory=True):
         return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs()
     return None
 
+
 def get_value_from_getattr(getattr_node, self_module):
     assert getattr_node.kind() == "prim::GetAttr", "Got node of kind not equal to prim::GetAttr"
     # GetAttr nodes can be nested
@@ -98,10 +103,12 @@ def get_value_from_getattr(getattr_node, self_module):
     while len(stack) > 0:
         node = stack.pop()
         attr_name = node.s("name")
-        assert hasattr(module, attr_name), f"No attribute with name \"{attr_name}\" found in module."
+        assert hasattr(
+            module, attr_name), f"No attribute with name \"{attr_name}\" found in module."
         module = getattr(module, attr_name)
     return module
 
+
 pt_to_ov_type_map = {
     "float": OVType.f32,
     "int": OVType.i32,
@@ -131,3 +138,121 @@ ov_to_c_type_map = {
     OVType.i32: ctypes.c_int,
     OVType.i64: ctypes.c_int64,
 }
+
+
+wrapper_template = """
+import torch
+from typing import *
+
+class ModelWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, {input_sign}):
+        return self.model({example_input})
+"""
+
+
+def process_dict_inputs(inputs, input_params, model):
+    ordered_inputs = []
+    for input_name in input_params:
+        if input_name in inputs:
+            ordered_inputs.append(input_name)
+    input_signature = list(input_params)
+    if ordered_inputs == input_signature[:len(ordered_inputs)]:
+        example_inputs = [inputs[input_name] for input_name in ordered_inputs]
+        if all([isinstance(inp, torch.Tensor) for inp in example_inputs]):
+            return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model
+        return {"example_inputs": example_inputs}, ordered_inputs, model
+    # PyTorch has some difficulties to trace models with named unordered parameters:
+    # torch < 2.0.0 supports only positional arguments for tracing
+    # pytorch == 2.0.0 supports input kwargs tracing,
+    # but does not support complex nested objects (e. g. tuple of tuples of tensors)
+    # We will use wrapper for making them positional as workaround.
+    input_sign_str = []
+    input_params_str = []
+    for input_name in ordered_inputs:
+        if str(input_params[input_name].annotation).startswith("typing.Union"):
+            filter_custom_args = []
+            for arg in input_params[input_name].annotation.__args__:
+                str_arg = str(arg)
+                is_typing = str_arg.startswith("typing.")
+                is_torch = "torch." in str_arg
+                is_builten = str_arg in (str(int), str(float), str(type(None)))
+                if not (is_typing or is_torch or is_builten):
+                    continue
+                filter_custom_args.append(arg)
+            input_params[input_name].annotation.__args__ = tuple(
+                filter_custom_args)
+        input_sign_str.append(
+            str(input_params[input_name]).replace("NoneType", "None"))
+        input_params_str.append(f"{input_name}={input_name}")
+    wrapper_class = wrapper_template.format(input_sign=', '.join(
+        input_sign_str), example_input=', '.join(input_params_str))
+    result = {}
+    try:
+        exec(wrapper_class, result)
+        wrapped_model = result["ModelWrapper"](model)
+        wrapped_model.eval()
+    # if wrapping failed, it is better to return original model for avoid user confusion regarding error message
+    except Exception:
+        wrapped_model = model
+    return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
+
+
+def prepare_example_inputs_and_model(inputs, input_params, model):
+    input_is_list = False
+    input_signature = list(input_params)
+    if isinstance(inputs, dict):
+        examples, ordered, wrapped = process_dict_inputs(inputs, input_params, model)
+        return examples, ordered, wrapped, input_is_list
+    if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
+        if "typing.List" in str(input_params[input_signature[0]].annotation):
+            inputs = inputs[0].unsqueeze(0)
+            input_is_list = True
+    if isinstance(inputs, torch.Tensor):
+        inputs = [inputs]
+    input_signature = input_signature[:len(inputs)]
+    return {"example_inputs": inputs}, input_signature, model, input_is_list
+
+
+def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
+    # represents torch quantized tensor as
+    # Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply(scale)
+    qscheme = qtensor.qscheme()
+    if qscheme == torch.per_channel_affine:
+        int8_tensor = qtensor.int_repr()
+        scale = qtensor.q_per_channel_scales().numpy().astype(np.float32)
+        zero_point = qtensor.q_per_channel_zero_points().numpy().astype(np.float32)
+        axis = np.int32(qtensor.q_per_channel_axis())
+        new_shape = np.ones(len(int8_tensor.shape), dtype=np.int32)
+        new_shape[axis] = -1
+        zero_point_bc = np.reshape(zero_point, new_shape)
+        scale_bc = np.reshape(scale, new_shape)
+        int8_const = torch_tensor_to_ov_const(
+            int8_tensor, shared_memory=shared_memory)
+        convert = ops.convert(int8_const, np.float32)
+        sub = ops.subtract(convert, zero_point_bc)
+        return ops.multiply(sub, scale_bc).outputs()
+    elif qscheme == torch.per_tensor_affine:
+        int8_tensor = qtensor.int_repr()
+        scale = np.float32(qtensor.q_scale())
+        zero_point = np.float32(qtensor.q_zero_point())
+        int8_const = torch_tensor_to_ov_const(
+            int8_tensor, shared_memory=shared_memory)
+        convert = ops.convert(int8_const, np.float32)
+        sub = ops.subtract(convert, zero_point)
+        return ops.multiply(sub, scale).outputs()
+    assert False, "Unsupported qscheme"
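
As a sanity check on the decomposition convert_quantized_tensor builds (Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply(scale)), here is a small sketch in plain PyTorch showing that (int_repr - zero_point) * scale reproduces torch's own dequantization for the per-tensor case; the tensor and quantization parameters below are hypothetical.

import torch

x = torch.randn(4, 4)
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# Same arithmetic the OV subgraph performs on the stored integer data.
manual = (q.int_repr().float() - float(q.q_zero_point())) * q.q_scale()
assert torch.allclose(manual, q.dequantize())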

View File

@@ -32,8 +32,7 @@ class Convnet(torch.nn.Module):
 def _infer_pipelines(test_input, preprocess_pipeline, input_channels=3):
     torch_model = Convnet(input_channels)
-    example_input = Tensor(np.expand_dims(test_input, axis=0).astype(np.float32))
-    ov_model = convert_model(torch_model, example_input=example_input)
+    ov_model = convert_model(torch_model)
     core = Core()
     ov_model = PreprocessConverter.from_torchvision(

View File

@@ -841,7 +841,7 @@ def create_pytorch_module_with_nested_inputs2(tmp_dir):
     add = ov.opset10.add(concat1, param0)
     ref_model = Model([concat2, add], [param0, param1, param2], "test")
     return net, ref_model, {
-        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))},
+        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 5)))},
         "compress_to_fp16": False}
@@ -867,7 +867,7 @@ def create_pytorch_module_with_nested_inputs3(tmp_dir):
     add = ov.opset10.add(concat1, param3)
     ref_model = Model([concat2, add], [param1, param2, param3], "test")
     return net, ref_model, {
-        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))},
+        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 3)))},
         "compress_to_fp16": False}
@@ -895,7 +895,7 @@ def create_pytorch_module_with_nested_inputs4(tmp_dir):
     mul = ov.opset10.multiply(concat2, param4)
     ref_model = Model([mul, add], [param3, param1, param2, param4], "test")
     return net, ref_model, {
-        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))),
+        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 10))),
                           "y": torch.ones((1,))},
         "compress_to_fp16": False}
@@ -924,7 +924,7 @@ def create_pytorch_module_with_nested_inputs5(tmp_dir):
     mul = ov.opset10.multiply(concat2, param4)
     ref_model = Model([mul, add], [param0, param1, param2, param4], "test")
     return net, ref_model, {
-        "example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))],
+        "example_input": [torch.ones((1, 10)), (torch.zeros((1, 9)), torch.ones((1, 5, 10))), torch.ones((1,))],
         "compress_to_fp16": False}
@@ -1268,9 +1268,4 @@ class TestPrecisionSensitive():
         fw_res = fw_model(*torch_inp_tensors)
         ov_res = core.compile_model(ir_test)(example_inputs)
-        if precision == 'FP32':
-            custom_eps = 1e-4
-        else:
-            custom_eps = 1e-3
-        npt.assert_allclose(ov_res[0], fw_res.numpy(), atol=custom_eps)
+        npt.assert_allclose(ov_res[0], fw_res.numpy(), atol=1e-3, rtol=1e-3)

View File

@@ -843,7 +843,7 @@ def create_pytorch_module_with_nested_inputs2(tmp_dir):
     add = ov.opset10.add(concat1, param0)
     ref_model = Model([concat2, add], [param0, param1, param2], "test")
     return net, ref_model, {
-        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))},
+        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 5)))},
         "compress_to_fp16": False}
@@ -869,7 +869,7 @@ def create_pytorch_module_with_nested_inputs3(tmp_dir):
     add = ov.opset10.add(concat1, param3)
     ref_model = Model([concat2, add], [param1, param2, param3], "test")
     return net, ref_model, {
-        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))},
+        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 3)))},
         "compress_to_fp16": False}
@@ -897,7 +897,7 @@ def create_pytorch_module_with_nested_inputs4(tmp_dir):
     mul = ov.opset10.multiply(concat2, param4)
     ref_model = Model([mul, add], [param3, param1, param2, param4], "test")
     return net, ref_model, {
-        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))),
+        "example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 10))),
                           "y": torch.ones((1,))},
         "compress_to_fp16": False}
@@ -926,7 +926,7 @@ def create_pytorch_module_with_nested_inputs5(tmp_dir):
     mul = ov.opset10.multiply(concat2, param4)
     ref_model = Model([mul, add], [param0, param1, param2, param4], "test")
     return net, ref_model, {
-        "example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))],
+        "example_input": [torch.ones((1, 10)), (torch.zeros((1, 9)), torch.ones((1, 5, 10))), torch.ones((1,))],
         "compress_to_fp16": False}

View File

@@ -77,18 +77,16 @@ class PytorchLayerTest:
             self.torch_compile_backend_test(model, torch_inputs, custom_eps)
         else:
             with torch.no_grad():
-                model.eval()
                 trace_model = kwargs.get('trace_model', False)
                 freeze_model = kwargs.get('freeze_model', True)
-                model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
-                graph = model.inlined_graph
+                smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
                 if kind is not None and not isinstance(kind, (tuple, list)):
                     kind = [kind]
                 if kind is not None:
                     for op in kind:
                         assert self._check_kind_exist(
-                            graph, op), f"Operation {op} type doesn't exist in provided graph"
+                            smodel.inlined_graph, op), f"Operation {op} type doesn't exist in provided graph"
 
             # OV infer:
             core = Core()
             compiled = core.compile_model(converted_model, ie_device)
@@ -99,7 +97,7 @@ class PytorchLayerTest:
                 return
 
             # Framework infer:
-            fw_res = model(*deepcopy(torch_inputs))
+            fw_res = smodel(*deepcopy(torch_inputs))
 
             if not isinstance(fw_res, (tuple)):
                 fw_res = (fw_res,)
@@ -162,47 +160,36 @@ class PytorchLayerTest:
     def _prepare_input(self):
        raise RuntimeError("Please provide inputs generation function")
 
-    def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs):
-        import torch
+    def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
         from openvino.tools.ovc import convert_model
-        kwargs = {"example_input": example_input if len(
-            example_input) > 1 else example_input[0], "compress_to_fp16": False}
-        with torch.no_grad():
-            if trace_model:
-                model = torch.jit.trace(model, example_input)
-            else:
-                model = torch.jit.script(model)
-            model = torch.jit.freeze(model)
-            print(model)
-            if not dynamic_shapes:
-                input_shapes = [inp.shape for inp in ov_inputs]
-                kwargs["input_shape"] = input_shapes
-            om = convert_model(model, **kwargs)
+        kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]}
+        if trace_model:
+            decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
+        else:
+            decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
+        smodel = decoder.pt_module
+        print(smodel.inlined_graph)
+        if not dynamic_shapes:
+            input_shapes = [inp.shape for inp in ov_inputs]
+            kwargs["input"] = input_shapes
+        om = convert_model(decoder, **kwargs)
         self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
-        return model, om
+        return smodel, om
 
     def convert_directly_via_frontend(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
-        import torch
-
         fe_manager = FrontEndManager()
         fe = fe_manager.load_by_framework('pytorch')
-        model.eval()
-        with torch.no_grad():
-            if trace_model:
-                model = torch.jit.trace(model, example_input)
-            else:
-                model = torch.jit.script(model)
-        if freeze_model:
-            _model = torch.jit.freeze(model)
+        if trace_model:
+            decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
         else:
-            _model = model
-        print(_model.inlined_graph)
-        decoder = TorchScriptPythonDecoder(_model)
+            decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
+        smodel = decoder.pt_module
+        print(smodel.inlined_graph)
         im = fe.load(decoder)
         om = fe.convert(im)
         self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
-        return model, om
+        return smodel, om
 
     def _resolve_input_shape_dtype(self, om, ov_inputs, dynamic_shapes):
         params = list(om.inputs)
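
For orientation, a minimal sketch (not part of the diff) of the decoder-based path the test harness now uses, mirroring convert_directly_via_frontend above; the ToyModel, import path, and input shape are hypothetical.

import torch
from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder  # assumed import path

class ToyModel(torch.nn.Module):  # hypothetical
    def forward(self, x):
        return torch.relu(x)

decoder = TorchScriptPythonDecoder(ToyModel(), example_input=[torch.ones(1, 3)])
fe = FrontEndManager().load_by_framework("pytorch")
input_model = fe.load(decoder)      # wrap the decoder as a frontend InputModel
ov_model = fe.convert(input_model)  # produce an openvino.runtime.Model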

View File

@@ -111,7 +111,7 @@ class TestAddTypes(PytorchLayerTest):
         self.rhs_type = rhs_type
         self.rhs_shape = rhs_shape
         self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
-                   ie_device, precision, ir_version)
+                   ie_device, precision, ir_version, freeze_model=False, trace_model=True)
 
 class TestAddLists(PytorchLayerTest):

View File

@@ -28,11 +28,16 @@ class TestAliases(PytorchLayerTest):
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_alias(self, ie_device, precision, ir_version):
-        self._test(aten_alias(), None, [
-            "aten::slice", "aten::select", "aten::copy_"], ie_device, precision, ir_version)
+        self._test(aten_alias(), None, ["aten::slice",
+                                        "aten::select",
+                                        "aten::copy_"],
+                   ie_device, precision, ir_version)
 
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_loop_alias(self, ie_device, precision, ir_version):
-        self._test(aten_loop_alias(), None, [
-            "aten::slice", "aten::select", "aten::copy_", "prim::Loop"], ie_device, precision, ir_version)
+        self._test(aten_loop_alias(), None, ["aten::slice",
+                                             "aten::select",
+                                             "aten::copy_",
+                                             "prim::Loop"],
+                   ie_device, precision, ir_version, freeze_model=False)

View File

@@ -100,4 +100,4 @@ class TestMulTypes(PytorchLayerTest):
         self.rhs_type = rhs_type
         self.rhs_shape = rhs_shape
         self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
-                   ie_device, precision, ir_version)
+                   ie_device, precision, ir_version, freeze_model=False, trace_model=True)

View File

@@ -32,7 +32,10 @@ def get_pytorch_decoder(model, input_shape, example_inputs, args):
         raise RuntimeError(
             "NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.")
     inputs = prepare_torch_inputs(example_inputs)
-    decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
+    if not isinstance(model, TorchScriptPythonDecoder):
+        decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
+    else:
+        decoder = model
     args['input_model'] = decoder
     args["framework"] = "pytorch"
     args["example_input"] = inputs

View File

@@ -32,7 +32,10 @@ def get_pytorch_decoder(model, example_inputs, args):
         raise RuntimeError(
             "NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.")
     inputs = prepare_torch_inputs(example_inputs)
-    decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
+    if not isinstance(model, TorchScriptPythonDecoder):
+        decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
+    else:
+        decoder = model
     args['input_model'] = decoder
     args["example_input"] = inputs