[PT FE] Separate tracing and scripting modes (#19676)

* [PT FE] Separate scripting and tracing in decoder

* Fix convert_model to accept decoder

* Some fixes

* Fix code style

* Fix preprocessor tests

* Fix tests

* Fix tests

* Fix more tests

* Fix ovc tests
This commit is contained in:
Maxim Vafin 2023-09-12 10:40:20 +02:00 committed by GitHub
parent 514f9864af
commit a6bc78dd0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 213 additions and 211 deletions

View File

@ -7,29 +7,15 @@
from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.runtime import op, PartialShape, Type as OVType, OVAny
from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, torch_tensor_to_ov_const
from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor
from openvino.runtime import opset11 as ops
import typing
import torch
import numpy as np
wrapper_template = """
import torch
from typing import *
class ModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, {input_sign}):
return self.model({example_input})
"""
class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True):
def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True, skip_freeze=False):
Decoder.__init__(self)
# We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
self.m_decoders = []
@ -38,16 +24,18 @@ class TorchScriptPythonDecoder (Decoder):
self._input_is_list = False
if graph_element is None:
try:
pt_module = self._get_scripted_model(pt_module, example_input)
pt_module = self._get_scripted_model(pt_module, example_input, skip_freeze)
except Exception as e:
if example_input is not None:
msg = "tracing or scripting"
help_msg = ""
msg = "tracing"
help_msg = "Please check correctness of provided 'example_input'. "
"Sometimes models can be converted in scripted mode, please try running "
"conversion without 'example_input'."
else:
msg = "scripting"
help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument. "
help_msg = "\nTracing sometimes provide better results, please provide valid 'example_input' argument."
raise RuntimeError(
f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n {help_msg}"
f"Couldn't get TorchScript module by {msg}. With exception:\n{e}\n{help_msg} "
"You can also provide TorchScript module that you obtained"
" yourself, please refer to PyTorch documentation: "
"https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.")
@ -82,74 +70,10 @@ class TorchScriptPythonDecoder (Decoder):
preserved_attributes.append(name)
return preserved_attributes
def _get_scripted_model(self, pt_module, example_inputs=None):
def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False):
import torch
import inspect
def process_dict_inputs(inputs, input_params, model):
ordered_inputs = []
for input_name in input_params:
if input_name in inputs:
ordered_inputs.append(input_name)
input_signature = list(input_params)
if ordered_inputs == input_signature[:len(ordered_inputs)]:
example_inputs = [inputs[input_name] for input_name in ordered_inputs]
if all([isinstance(inp, torch.Tensor) for inp in example_inputs]):
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model
return {"example_inputs": example_inputs}, ordered_inputs, model
# PyTorch has some difficulties to trace models with named unordered parameters:
# torch < 2.0.0 supports only positional arguments for tracing
# pytorch == 2.0.0 supports input kwargs tracing,
# but does not support complex nested objects (e. g. tuple of tuples of tensors)
# We will use wrapper for making them positional as workaround.
input_sign_str = []
input_params_str = []
for input_name in ordered_inputs:
if str(input_params[input_name].annotation).startswith("typing.Union"):
filter_custom_args = []
for arg in input_params[input_name].annotation.__args__:
str_arg = str(arg)
is_typing = str_arg.startswith("typing.")
is_torch = "torch." in str_arg
is_builten = str_arg in (str(int), str(float), str(type(None)))
if not (is_typing or is_torch or is_builten):
continue
filter_custom_args.append(arg)
input_params[input_name].annotation.__args__ = tuple(filter_custom_args)
input_sign_str.append(str(input_params[input_name]).replace("NoneType", "None"))
input_params_str.append(f"{input_name}={input_name}")
wrapper_class = wrapper_template.format(input_sign=', '.join(input_sign_str), example_input=', '.join(input_params_str))
result = {}
try:
exec(wrapper_class, result)
wrapped_model = result["ModelWrapper"](model)
wrapped_model.eval()
# if wrapping failed, it is better to return original model for avoid user confusion regarding error message
except Exception:
wrapped_model = model
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
def prepare_example_inputs_and_model(inputs, input_params, model):
input_signature = list(input_params)
if isinstance(inputs, dict):
return process_dict_inputs(inputs, input_params, model)
if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
if "typing.List" in str(input_params[input_signature[0]].annotation):
inputs = inputs[0].unsqueeze(0)
self._input_is_list = True
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
input_signature = input_signature[:len(inputs)]
return {"example_inputs": inputs}, input_signature, model
if isinstance(pt_module, torch.nn.Module):
pt_module.eval()
input_signature = None
@ -160,32 +84,23 @@ class TorchScriptPythonDecoder (Decoder):
if example_inputs is None:
scripted = torch.jit.script(pt_module)
else:
input_parameters, input_signature, pt_module = prepare_example_inputs_and_model(example_inputs, input_params, pt_module)
try:
scripted = torch.jit.trace(pt_module, **input_parameters)
except Exception:
try:
scripted = torch.jit.script(pt_module)
except Exception as se:
try:
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
except Exception as te:
raise Exception(f"Tracing failed with exception {te}\nScripting failed with exception: {se}")
skip_freeze = False
for n in scripted.inlined_graph.nodes():
# TODO: switch off freezing for all traced models
if "quantize" in n.kind():
# do not freeze quantized models
skip_freeze = True
break
elif "aten::to" in n.kind():
first_input = next(n.inputs())
if first_input.node().kind() == "prim::Constant":
ivalue = first_input.toIValue()
if isinstance(ivalue, torch.Tensor) and ivalue.dtype in [torch.bfloat16, torch.float16]:
# do not freeze models with compressed constants
skip_freeze = True
break
input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(example_inputs, input_params, pt_module)
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
if not skip_freeze:
for n in scripted.inlined_graph.nodes():
# TODO: switch off freezing for all traced models
if "quantize" in n.kind():
# do not freeze quantized models
skip_freeze = True
break
elif "aten::to" in n.kind():
first_input = next(n.inputs())
if first_input.node().kind() == "prim::Constant":
ivalue = first_input.toIValue()
if isinstance(ivalue, torch.Tensor) and ivalue.dtype in [torch.bfloat16, torch.float16]:
# do not freeze models with compressed constants
skip_freeze = True
break
if not skip_freeze:
preserved_attrs = self._get_preserved_attributes(scripted)
f_model = torch.jit.freeze(scripted, preserved_attrs=preserved_attrs)
@ -331,36 +246,6 @@ class TorchScriptPythonDecoder (Decoder):
node.set_friendly_name(name)
return node
@staticmethod
def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
# need to represent as Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply (scale)
qscheme = qtensor.qscheme() # torch.per_channel_affine (per_tensor)
if qscheme == torch.per_channel_affine:
int8_tensor = qtensor.int_repr()
scale = qtensor.q_per_channel_scales().numpy().astype(np.float32) # (weight.q_scale() for per_tensor)
zero_point = qtensor.q_per_channel_zero_points().numpy().astype(np.float32) # (weight.q_zero_point() for per_tensor)
axis = np.int32(qtensor.q_per_channel_axis())
new_shape = np.ones(len(int8_tensor.shape), dtype=np.int32)
new_shape[axis] = -1
zero_point_bc = np.reshape(zero_point, new_shape)
scale_bc = np.reshape(scale, new_shape)
int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
convert = ops.convert(int8_const, np.float32)
sub = ops.subtract(convert, zero_point_bc)
return ops.multiply(sub, scale_bc).outputs()
elif qscheme == torch.per_tensor_affine:
int8_tensor = qtensor.int_repr()
scale = np.float32(qtensor.q_scale())
zero_point = np.float32(qtensor.q_zero_point())
int8_const = torch_tensor_to_ov_const(int8_tensor, shared_memory=shared_memory)
convert = ops.convert(int8_const, np.float32)
sub = ops.subtract(convert, zero_point)
return ops.multiply(sub, scale).outputs()
assert False, "Unsupported qscheme"
def try_decode_get_attr(self):
pt_value = get_value_from_getattr(self.graph_element, self.pt_module)
assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
@ -368,7 +253,7 @@ class TorchScriptPythonDecoder (Decoder):
# We assume this is __torch__.torch.classes.quantized.Conv2dPackedParamsBase or __torch__.torch.classes.quantized.LinearPackedParamsBase
# TODO: but can be anything. Figure a better way to distinguish
weight, bias = pt_value.unpack()
res = self.convert_quantized_tensor(weight, self._shared_memory)
res = convert_quantized_tensor(weight, self._shared_memory)
if isinstance(bias, torch.Tensor):
res += ivalue_to_constant(bias)
else:

View File

@ -10,9 +10,10 @@ import numpy as np
import ctypes
from openvino.runtime import op, Type as OVType, Shape, Tensor
from openvino.runtime import opset11 as ops
def maybe_convert_max_int(value : int):
def maybe_convert_max_int(value: int):
# FIXME: This is a convertion from 64-bit positive max integer value
# to 32-bit positive max integer value. Find a better way to handle this.
if value == torch.iinfo(torch.int64).max:
@ -20,10 +21,12 @@ def maybe_convert_max_int(value : int):
else:
return value
def make_constant(*args, **kwargs):
return op.Constant(*args, **kwargs)
def fetch_attr(self_module, target : str):
def fetch_attr(self_module, target: str):
"""
Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
@ -37,7 +40,8 @@ def fetch_attr(self_module, target : str):
attr_itr = self_module
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
@ -84,6 +88,7 @@ def ivalue_to_constant(ivalue, shared_memory=True):
return torch_tensor_to_ov_const(ivalue, shared_memory=shared_memory).outputs()
return None
def get_value_from_getattr(getattr_node, self_module):
assert getattr_node.kind() == "prim::GetAttr", "Got node of kind not equal to prim::GetAttr"
# GetAttr nodes can be nested
@ -98,10 +103,12 @@ def get_value_from_getattr(getattr_node, self_module):
while len(stack) > 0:
node = stack.pop()
attr_name = node.s("name")
assert hasattr(module, attr_name), f"No attribute with name \"{attr_name}\" found in module."
assert hasattr(
module, attr_name), f"No attribute with name \"{attr_name}\" found in module."
module = getattr(module, attr_name)
return module
pt_to_ov_type_map = {
"float": OVType.f32,
"int": OVType.i32,
@ -131,3 +138,121 @@ ov_to_c_type_map = {
OVType.i32: ctypes.c_int,
OVType.i64: ctypes.c_int64,
}
wrapper_template = """
import torch
from typing import *
class ModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, {input_sign}):
return self.model({example_input})
"""
def process_dict_inputs(inputs, input_params, model):
ordered_inputs = []
for input_name in input_params:
if input_name in inputs:
ordered_inputs.append(input_name)
input_signature = list(input_params)
if ordered_inputs == input_signature[:len(ordered_inputs)]:
example_inputs = [inputs[input_name] for input_name in ordered_inputs]
if all([isinstance(inp, torch.Tensor) for inp in example_inputs]):
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model
return {"example_inputs": example_inputs}, ordered_inputs, model
# PyTorch has some difficulties to trace models with named unordered parameters:
# torch < 2.0.0 supports only positional arguments for tracing
# pytorch == 2.0.0 supports input kwargs tracing,
# but does not support complex nested objects (e. g. tuple of tuples of tensors)
# We will use wrapper for making them positional as workaround.
input_sign_str = []
input_params_str = []
for input_name in ordered_inputs:
if str(input_params[input_name].annotation).startswith("typing.Union"):
filter_custom_args = []
for arg in input_params[input_name].annotation.__args__:
str_arg = str(arg)
is_typing = str_arg.startswith("typing.")
is_torch = "torch." in str_arg
is_builten = str_arg in (str(int), str(float), str(type(None)))
if not (is_typing or is_torch or is_builten):
continue
filter_custom_args.append(arg)
input_params[input_name].annotation.__args__ = tuple(
filter_custom_args)
input_sign_str.append(
str(input_params[input_name]).replace("NoneType", "None"))
input_params_str.append(f"{input_name}={input_name}")
wrapper_class = wrapper_template.format(input_sign=', '.join(
input_sign_str), example_input=', '.join(input_params_str))
result = {}
try:
exec(wrapper_class, result)
wrapped_model = result["ModelWrapper"](model)
wrapped_model.eval()
# if wrapping failed, it is better to return original model for avoid user confusion regarding error message
except Exception:
wrapped_model = model
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
def prepare_example_inputs_and_model(inputs, input_params, model):
input_is_list = False
input_signature = list(input_params)
if isinstance(inputs, dict):
examples, ordered, wrapped = process_dict_inputs(inputs, input_params, model)
return examples, ordered, wrapped, input_is_list
if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
if "typing.List" in str(input_params[input_signature[0]].annotation):
inputs = inputs[0].unsqueeze(0)
input_is_list = True
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
input_signature = input_signature[:len(inputs)]
return {"example_inputs": inputs}, input_signature, model, input_is_list
def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
# represents torch quantized tensor as
# Constant(u8) -> Convert(f32) -> Subtract(zero_point) -> Multiply(scale)
qscheme = qtensor.qscheme()
if qscheme == torch.per_channel_affine:
int8_tensor = qtensor.int_repr()
scale = qtensor.q_per_channel_scales().numpy().astype(np.float32)
zero_point = qtensor.q_per_channel_zero_points().numpy().astype(np.float32)
axis = np.int32(qtensor.q_per_channel_axis())
new_shape = np.ones(len(int8_tensor.shape), dtype=np.int32)
new_shape[axis] = -1
zero_point_bc = np.reshape(zero_point, new_shape)
scale_bc = np.reshape(scale, new_shape)
int8_const = torch_tensor_to_ov_const(
int8_tensor, shared_memory=shared_memory)
convert = ops.convert(int8_const, np.float32)
sub = ops.subtract(convert, zero_point_bc)
return ops.multiply(sub, scale_bc).outputs()
elif qscheme == torch.per_tensor_affine:
int8_tensor = qtensor.int_repr()
scale = np.float32(qtensor.q_scale())
zero_point = np.float32(qtensor.q_zero_point())
int8_const = torch_tensor_to_ov_const(
int8_tensor, shared_memory=shared_memory)
convert = ops.convert(int8_const, np.float32)
sub = ops.subtract(convert, zero_point)
return ops.multiply(sub, scale).outputs()
assert False, "Unsupported qscheme"

View File

@ -32,8 +32,7 @@ class Convnet(torch.nn.Module):
def _infer_pipelines(test_input, preprocess_pipeline, input_channels=3):
torch_model = Convnet(input_channels)
example_input = Tensor(np.expand_dims(test_input, axis=0).astype(np.float32))
ov_model = convert_model(torch_model, example_input=example_input)
ov_model = convert_model(torch_model)
core = Core()
ov_model = PreprocessConverter.from_torchvision(

View File

@ -841,7 +841,7 @@ def create_pytorch_module_with_nested_inputs2(tmp_dir):
add = ov.opset10.add(concat1, param0)
ref_model = Model([concat2, add], [param0, param1, param2], "test")
return net, ref_model, {
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))},
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 5)))},
"compress_to_fp16": False}
@ -867,7 +867,7 @@ def create_pytorch_module_with_nested_inputs3(tmp_dir):
add = ov.opset10.add(concat1, param3)
ref_model = Model([concat2, add], [param1, param2, param3], "test")
return net, ref_model, {
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))},
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 3)))},
"compress_to_fp16": False}
@ -895,7 +895,7 @@ def create_pytorch_module_with_nested_inputs4(tmp_dir):
mul = ov.opset10.multiply(concat2, param4)
ref_model = Model([mul, add], [param3, param1, param2, param4], "test")
return net, ref_model, {
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))),
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 10))),
"y": torch.ones((1,))},
"compress_to_fp16": False}
@ -924,7 +924,7 @@ def create_pytorch_module_with_nested_inputs5(tmp_dir):
mul = ov.opset10.multiply(concat2, param4)
ref_model = Model([mul, add], [param0, param1, param2, param4], "test")
return net, ref_model, {
"example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))],
"example_input": [torch.ones((1, 10)), (torch.zeros((1, 9)), torch.ones((1, 5, 10))), torch.ones((1,))],
"compress_to_fp16": False}
@ -1268,9 +1268,4 @@ class TestPrecisionSensitive():
fw_res = fw_model(*torch_inp_tensors)
ov_res = core.compile_model(ir_test)(example_inputs)
if precision == 'FP32':
custom_eps = 1e-4
else:
custom_eps = 1e-3
npt.assert_allclose(ov_res[0], fw_res.numpy(), atol=custom_eps)
npt.assert_allclose(ov_res[0], fw_res.numpy(), atol=1e-3, rtol=1e-3)

View File

@ -843,7 +843,7 @@ def create_pytorch_module_with_nested_inputs2(tmp_dir):
add = ov.opset10.add(concat1, param0)
ref_model = Model([concat2, add], [param0, param1, param2], "test")
return net, ref_model, {
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))},
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 5)))},
"compress_to_fp16": False}
@ -869,7 +869,7 @@ def create_pytorch_module_with_nested_inputs3(tmp_dir):
add = ov.opset10.add(concat1, param3)
ref_model = Model([concat2, add], [param1, param2, param3], "test")
return net, ref_model, {
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))},
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 3)))},
"compress_to_fp16": False}
@ -897,7 +897,7 @@ def create_pytorch_module_with_nested_inputs4(tmp_dir):
mul = ov.opset10.multiply(concat2, param4)
ref_model = Model([mul, add], [param3, param1, param2, param4], "test")
return net, ref_model, {
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))),
"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 9)), torch.ones((1, 5, 10))),
"y": torch.ones((1,))},
"compress_to_fp16": False}
@ -926,7 +926,7 @@ def create_pytorch_module_with_nested_inputs5(tmp_dir):
mul = ov.opset10.multiply(concat2, param4)
ref_model = Model([mul, add], [param0, param1, param2, param4], "test")
return net, ref_model, {
"example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))],
"example_input": [torch.ones((1, 10)), (torch.zeros((1, 9)), torch.ones((1, 5, 10))), torch.ones((1,))],
"compress_to_fp16": False}

View File

@ -77,18 +77,16 @@ class PytorchLayerTest:
self.torch_compile_backend_test(model, torch_inputs, custom_eps)
else:
with torch.no_grad():
model.eval()
trace_model = kwargs.get('trace_model', False)
freeze_model = kwargs.get('freeze_model', True)
model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
graph = model.inlined_graph
smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
if kind is not None and not isinstance(kind, (tuple, list)):
kind = [kind]
if kind is not None:
for op in kind:
assert self._check_kind_exist(
graph, op), f"Operation {op} type doesn't exist in provided graph"
if kind is not None and not isinstance(kind, (tuple, list)):
kind = [kind]
if kind is not None:
for op in kind:
assert self._check_kind_exist(
smodel.inlined_graph, op), f"Operation {op} type doesn't exist in provided graph"
# OV infer:
core = Core()
compiled = core.compile_model(converted_model, ie_device)
@ -99,7 +97,7 @@ class PytorchLayerTest:
return
# Framework infer:
fw_res = model(*deepcopy(torch_inputs))
fw_res = smodel(*deepcopy(torch_inputs))
if not isinstance(fw_res, (tuple)):
fw_res = (fw_res,)
@ -162,47 +160,36 @@ class PytorchLayerTest:
def _prepare_input(self):
raise RuntimeError("Please provide inputs generation function")
def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs):
import torch
def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
from openvino.tools.ovc import convert_model
kwargs = {"example_input": example_input if len(
example_input) > 1 else example_input[0], "compress_to_fp16": False}
with torch.no_grad():
if trace_model:
model = torch.jit.trace(model, example_input)
else:
model = torch.jit.script(model)
model = torch.jit.freeze(model)
print(model)
if not dynamic_shapes:
input_shapes = [inp.shape for inp in ov_inputs]
kwargs["input_shape"] = input_shapes
om = convert_model(model, **kwargs)
kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]}
if trace_model:
decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
else:
decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
smodel = decoder.pt_module
print(smodel.inlined_graph)
if not dynamic_shapes:
input_shapes = [inp.shape for inp in ov_inputs]
kwargs["input"] = input_shapes
om = convert_model(decoder, **kwargs)
self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
return model, om
return smodel, om
def convert_directly_via_frontend(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
import torch
fe_manager = FrontEndManager()
fe = fe_manager.load_by_framework('pytorch')
model.eval()
with torch.no_grad():
if trace_model:
model = torch.jit.trace(model, example_input)
else:
model = torch.jit.script(model)
if freeze_model:
_model = torch.jit.freeze(model)
if trace_model:
decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
else:
_model = model
print(_model.inlined_graph)
decoder = TorchScriptPythonDecoder(_model)
decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
smodel = decoder.pt_module
print(smodel.inlined_graph)
im = fe.load(decoder)
om = fe.convert(im)
self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
return model, om
return smodel, om
def _resolve_input_shape_dtype(self, om, ov_inputs, dynamic_shapes):
params = list(om.inputs)

View File

@ -111,7 +111,7 @@ class TestAddTypes(PytorchLayerTest):
self.rhs_type = rhs_type
self.rhs_shape = rhs_shape
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
ie_device, precision, ir_version)
ie_device, precision, ir_version, freeze_model=False, trace_model=True)
class TestAddLists(PytorchLayerTest):

View File

@ -28,11 +28,16 @@ class TestAliases(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_alias(self, ie_device, precision, ir_version):
self._test(aten_alias(), None, [
"aten::slice", "aten::select", "aten::copy_"], ie_device, precision, ir_version)
self._test(aten_alias(), None, ["aten::slice",
"aten::select",
"aten::copy_"],
ie_device, precision, ir_version)
@pytest.mark.nightly
@pytest.mark.precommit
def test_loop_alias(self, ie_device, precision, ir_version):
self._test(aten_loop_alias(), None, [
"aten::slice", "aten::select", "aten::copy_", "prim::Loop"], ie_device, precision, ir_version)
self._test(aten_loop_alias(), None, ["aten::slice",
"aten::select",
"aten::copy_",
"prim::Loop"],
ie_device, precision, ir_version, freeze_model=False)

View File

@ -100,4 +100,4 @@ class TestMulTypes(PytorchLayerTest):
self.rhs_type = rhs_type
self.rhs_shape = rhs_shape
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
ie_device, precision, ir_version)
ie_device, precision, ir_version, freeze_model=False, trace_model=True)

View File

@ -32,7 +32,10 @@ def get_pytorch_decoder(model, input_shape, example_inputs, args):
raise RuntimeError(
"NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.")
inputs = prepare_torch_inputs(example_inputs)
decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
if not isinstance(model, TorchScriptPythonDecoder):
decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
else:
decoder = model
args['input_model'] = decoder
args["framework"] = "pytorch"
args["example_input"] = inputs

View File

@ -32,7 +32,10 @@ def get_pytorch_decoder(model, example_inputs, args):
raise RuntimeError(
"NNCF models produced by nncf<2.6 are not supported directly. Please upgrade nncf or export to ONNX first.")
inputs = prepare_torch_inputs(example_inputs)
decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
if not isinstance(model, TorchScriptPythonDecoder):
decoder = TorchScriptPythonDecoder(model, example_input=inputs, shared_memory=args.get("share_weights", True))
else:
decoder = model
args['input_model'] = decoder
args["example_input"] = inputs