[PT FE]: support nested inputs in example_inputs and arg dicts with d… (#18492)

* [PT FE]: support nested inputs in example_inputs and arg dicts with different argtypes

* accept handling of lists as inputs

* Update tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py

* update tests and add comments in code

* fix for custom types in annotations and duplicate in mo

* Update tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py
This commit is contained in:
Ekaterina Aidova
2023-07-19 19:01:22 +04:00
committed by GitHub
parent 186b1b6bfc
commit 61504bbfc2
4 changed files with 251 additions and 35 deletions

View File

@@ -13,6 +13,19 @@ from packaging.version import parse
import torch
import numpy as np
wrapper_template="""
import torch
from typing import *
class ModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, {input_sign}):
return self.model({example_input})
"""
def get_type_from_py_type(value):
if isinstance(value, float):
@@ -142,38 +155,76 @@ class TorchScriptPythonDecoder (Decoder):
import torch
import inspect
def prepare_example_inputs(inputs, input_signature):
is_torch_2 = parse(torch.__version__) >= parse("2.0.0")
if isinstance(inputs, dict):
ordered_inputs = []
if input_signature is not None:
used_sign = []
for key in input_signature:
if key not in inputs:
def process_dict_inputs(inputs, input_params, model):
ordered_inputs = []
for input_name in input_params:
if input_name in inputs:
ordered_inputs.append(input_name)
input_signature = list(input_params)
if ordered_inputs == input_signature[:len(ordered_inputs)]:
example_inputs = [inputs[input_name] for input_name in ordered_inputs]
if all([isinstance(inp, torch.Tensor) for inp in example_inputs]):
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model
return {"example_inputs": example_inputs}, ordered_inputs, model
# PyTorch has some difficulties to trace models with named unordered parameters:
# torch < 2.0.0 supports only positional arguments for tracing
# pytorch == 2.0.0 supports input kwargs tracing,
# but does not support complex nested objects (e. g. tuple of tuples of tensors)
# We will use wrapper for making them positional as workaround.
input_sign_str = []
input_params_str = []
for input_name in ordered_inputs:
if str(input_params[input_name].annotation).startswith("typing.Union"):
filter_custom_args = []
for arg in input_params[input_name].annotation.__args__:
str_arg = str(arg)
is_typing = str_arg.startswith("typing.")
is_torch = "torch." in str_arg
is_builten = str_arg in (str(int), str(float), str(type(None)))
if not (is_typing or is_torch or is_builten):
continue
ordered_inputs.append(inputs[key])
used_sign.append(key)
input_signature = used_sign
else:
ordered_inputs = list(inputs.values())
if is_torch_2:
return {"example_kwarg_inputs": inputs}, input_signature
else:
inputs = ordered_inputs
filter_custom_args.append(arg)
input_params[input_name].annotation.__args__ = tuple(filter_custom_args)
input_sign_str.append(str(input_params[input_name]).replace("NoneType", "None"))
input_params_str.append(f"{input_name}={input_name}")
wrapper_class = wrapper_template.format(input_sign=', '.join(input_sign_str), example_input=', '.join(input_params_str))
result = {}
try:
exec(wrapper_class, result)
wrapped_model = result["ModelWrapper"](model)
wrapped_model.eval()
# if wrapping failed, it is better to return original model for avoid user confusion regarding error message
except Exception:
wrapped_model = model
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
def prepare_example_inputs_and_model(inputs, input_params, model):
if isinstance(inputs, dict):
return process_dict_inputs(inputs, input_params, model)
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
return {"example_inputs": inputs}, input_signature
input_signature = list(input_params)
input_signature = input_signature[:len(inputs)]
return {"example_inputs": inputs}, input_signature, model
if isinstance(pt_module, torch.nn.Module):
pt_module.eval()
input_signature = None
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
# input params is dictionary contains input names and their signature values (type hints and default values if any)
input_params = inspect.signature(pt_module.forward if hasattr(pt_module, "forward") else pt_module.__call__).parameters
input_signature = list(input_params)
if example_inputs is None:
scripted = torch.jit.script(pt_module)
else:
input_parameters, input_signature = prepare_example_inputs(example_inputs, input_signature)
input_parameters, input_signature, pt_module = prepare_example_inputs_and_model(example_inputs, input_params, pt_module)
try:
scripted = torch.jit.trace(pt_module, **input_parameters)
except Exception:

View File

@@ -3,6 +3,7 @@
import os
from typing import Tuple
import numpy
import numpy as np
import openvino.runtime as ov
@@ -725,6 +726,150 @@ def create_pytorch_module_with_compressed_int8_constant(tmp_dir):
ref_model = Model([conv], [param1], "test")
return traced_model, ref_model, {"example_input": example_input}
def create_pytorch_module_with_nested_inputs(tmp_dir):
    """Case: the model's only input is a tuple of two tensors, given via a dict example_input."""

    class PTModel(torch.nn.Module):
        def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
            first, second = z
            pad1 = torch.zeros((1, 1))
            pad2 = torch.zeros((1, 5, 1))
            return torch.cat([first, pad1], 1), torch.cat([second, pad2], 2)

    net = PTModel()
    # Reference OV graph: each tuple element becomes a parameter concatenated with a zero constant.
    cz1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
    cz2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
    p1 = ov.opset10.parameter(PartialShape([1, -1]), dtype=np.float32)
    p2 = ov.opset10.parameter(PartialShape([1, 5, -1]), dtype=np.float32)
    cat1 = ov.opset10.concat([p1, cz1], 1)
    cat2 = ov.opset10.concat([p2, cz2], 2)
    ref_model = Model([cat2, cat1], [p1, p2], "test")
    return net, ref_model, {"example_input": {"z": (torch.zeros((1, 10)), torch.ones((1, 5, 2)))}}
def create_pytorch_module_with_nested_inputs2(tmp_dir):
    """Case: a plain tensor input followed by a tuple input, given via a dict example_input."""

    class PTModel(torch.nn.Module):
        def forward(self, x: torch.Tensor, z: Tuple[torch.Tensor, torch.Tensor]):
            first, second = z
            pad1 = torch.zeros((1, 1))
            pad2 = torch.zeros((1, 5, 1))
            return torch.cat([first, pad1], 1) + x, torch.cat([second, pad2], 2)

    net = PTModel()
    # Reference OV graph mirrors the forward above.
    cz1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
    cz2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
    p0 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
    p1 = ov.opset10.parameter(PartialShape([1, -1]), dtype=np.float32)
    p2 = ov.opset10.parameter(PartialShape([1, 5, -1]), dtype=np.float32)
    cat1 = ov.opset10.concat([p1, cz1], 1)
    cat2 = ov.opset10.concat([p2, cz2], 2)
    add = ov.opset10.add(cat1, p0)
    ref_model = Model([cat2, add], [p0, p1, p2], "test")
    return net, ref_model, {"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))}}
def create_pytorch_module_with_nested_inputs3(tmp_dir):
    """Case: the tuple input comes first in the signature, plain tensor second."""

    class PTModel(torch.nn.Module):
        def forward(self, z: Tuple[torch.Tensor, torch.Tensor], x: torch.Tensor):
            first, second = z
            pad1 = torch.zeros((1, 1))
            pad2 = torch.zeros((1, 5, 1))
            return torch.cat([first, pad1], 1) + x, torch.cat([second, pad2], 2)

    net = PTModel()
    # Reference OV graph: tuple parameters first, then the plain-tensor parameter.
    cz1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
    cz2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
    p1 = ov.opset10.parameter(PartialShape([1, -1]), dtype=np.float32)
    p2 = ov.opset10.parameter(PartialShape([1, 5, -1]), dtype=np.float32)
    p3 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
    cat1 = ov.opset10.concat([p1, cz1], 1)
    cat2 = ov.opset10.concat([p2, cz2], 2)
    add = ov.opset10.add(cat1, p3)
    ref_model = Model([cat2, add], [p1, p2, p3], "test")
    return net, ref_model, {"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))}}
def create_pytorch_module_with_nested_inputs4(tmp_dir):
    """Case: a tuple input sandwiched between two plain tensor inputs, dict example_input."""

    class PTModel(torch.nn.Module):
        def forward(self, x: torch.Tensor, z: Tuple[torch.Tensor, torch.Tensor], y: torch.Tensor):
            first, second = z
            pad1 = torch.zeros((1, 1))
            pad2 = torch.zeros((1, 5, 1))
            return torch.cat([first, pad1], 1) + x, torch.cat([second, pad2], 2) * y

    net = PTModel()
    # Reference OV graph mirrors the forward above.
    cz1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
    cz2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
    p1 = ov.opset10.parameter(PartialShape([1, -1]), dtype=np.float32)
    p2 = ov.opset10.parameter(PartialShape([1, 5, -1]), dtype=np.float32)
    p3 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
    p4 = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    cat1 = ov.opset10.concat([p1, cz1], 1)
    cat2 = ov.opset10.concat([p2, cz2], 2)
    add = ov.opset10.add(cat1, p3)
    mul = ov.opset10.multiply(cat2, p4)
    ref_model = Model([mul, add], [p3, p1, p2, p4], "test")
    return net, ref_model, {"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))), "y": torch.ones((1,))}}
def create_pytorch_module_with_nested_inputs5(tmp_dir):
    """Case: same signature as case 4, but example_input is a positional list, not a dict."""

    class PTModel(torch.nn.Module):
        def forward(self, x: torch.Tensor, z: Tuple[torch.Tensor, torch.Tensor], y: torch.Tensor):
            first, second = z
            pad1 = torch.zeros((1, 1))
            pad2 = torch.zeros((1, 5, 1))
            return torch.cat([first, pad1], 1) + x, torch.cat([second, pad2], 2) * y

    net = PTModel()
    # Reference OV graph mirrors the forward above; parameter order follows the signature.
    cz1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
    cz2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
    p0 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
    p1 = ov.opset10.parameter(PartialShape([1, -1]), dtype=np.float32)
    p2 = ov.opset10.parameter(PartialShape([1, 5, -1]), dtype=np.float32)
    p4 = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    cat1 = ov.opset10.concat([p1, cz1], 1)
    cat2 = ov.opset10.concat([p2, cz2], 2)
    add = ov.opset10.add(cat1, p0)
    mul = ov.opset10.multiply(cat2, p4)
    ref_model = Model([mul, add], [p0, p1, p2, p4], "test")
    return net, ref_model, {"example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))]}
def create_pytorch_module_with_nested_inputs6(tmp_dir):
    """Case: optional (defaulted) inputs; example_input omits 'y', exercising the x branch."""

    class PTModel(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor = None, z: Tuple[torch.Tensor, torch.Tensor] = None):
            first, second = z
            pad1 = torch.zeros((1, 1))
            pad2 = torch.zeros((1, 5, 1))
            if y is not None:
                return torch.cat([first, pad1], 1) * y, torch.cat([second, pad2], 2) * y
            return torch.cat([first, pad1], 1) + x, torch.cat([second, pad2], 2)

    net = PTModel()
    # Reference OV graph for the y-is-None branch only, since 'y' is absent from example_input.
    cz1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
    cz2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
    p0 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
    p1 = ov.opset10.parameter(PartialShape([1, -1]), dtype=np.float32)
    p2 = ov.opset10.parameter(PartialShape([1, 5, -1]), dtype=np.float32)
    cat1 = ov.opset10.concat([p1, cz1], 1)
    cat2 = ov.opset10.concat([p2, cz2], 2)
    add1 = ov.opset10.add(cat1, p0)
    ref_model = Model([cat2, add1], [p0, p1, p2], "test")
    return net, ref_model, {"example_input": {"x": torch.ones((1, 11)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10)))}}
class TestMoConvertPyTorch(CommonMOConvertTest):
test_data = [
@@ -770,6 +915,12 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
create_pytorch_module_with_optional_inputs_case5,
create_pytorch_nn_module_with_scalar_input,
create_pytorch_module_with_compressed_int8_constant,
create_pytorch_module_with_nested_inputs,
create_pytorch_module_with_nested_inputs2,
create_pytorch_module_with_nested_inputs3,
create_pytorch_module_with_nested_inputs4,
create_pytorch_module_with_nested_inputs5,
create_pytorch_module_with_nested_inputs6
]
@ pytest.mark.parametrize("create_model", test_data)

View File

@@ -66,9 +66,16 @@ def extract_input_info_from_example(args, inputs):
input_shapes = args.placeholder_shapes or {}
is_dict_input = isinstance(example_inputs, dict)
list_inputs = list(example_inputs.values()) if is_dict_input else example_inputs
input_names = None if not is_dict_input else list(example_inputs)
if not isinstance(list_inputs, (list, tuple)):
input_names = None
if not isinstance(example_inputs, (list, tuple, dict)):
list_inputs = [list_inputs]
if args.input_model._input_signature is not None and not is_dict_input:
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
if not is_dict_input:
example_inputs = dict(zip(input_names, list_inputs))
is_dict_input = True
elif is_dict_input:
input_names = list(example_inputs)
if not data_types and input_names is None:
data_types = []
if not input_shapes and input_names is None:
@@ -85,18 +92,18 @@ def extract_input_info_from_example(args, inputs):
dtype = getattr(example_input, "dtype", type(example_input))
example_dtype = pt_to_ov_type_map.get(str(dtype))
user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
if user_dtype is not None and example_dtype.to_dtype() != user_dtype:
if user_dtype is not None and example_dtype is not None and example_dtype.to_dtype() != user_dtype:
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
data_rank = getattr(example_input, "ndim", 0)
user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
if user_input_shape.rank.get_length() != data_rank:
if user_input_shape.rank.is_static and user_input_shape.rank.get_length() != data_rank:
raise Error(
f"Requested input shape {user_input_shape.rank.get_length()} rank"
f" is not equal to provided example_input rank {data_rank}")
input_shape = user_input_shape if user_input_shape is not None else PartialShape([-1] * data_rank)
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype())
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype() if example_dtype is not None else None)
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
else:
for input_id, example_input in enumerate(list_inputs):
@@ -106,7 +113,7 @@ def extract_input_info_from_example(args, inputs):
input_shape = PartialShape([-1] * data_rank)
input_name = input_names[input_id] if input_names else None
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype() if ov_dtype is not None else None)
args.placeholder_data_types = data_types
args.placeholder_shapes = input_shapes
@@ -125,7 +132,7 @@ def to_torch_tensor(tensor):
return torch.tensor(tensor.data)
if isinstance(tensor, (float, int, bool)):
return tensor
if isinstance(tensor, tuple):
if isinstance(tensor, (tuple, list)):
# TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
return tuple(to_torch_tensor(x) for x in tensor)
else:

View File

@@ -66,9 +66,16 @@ def extract_input_info_from_example(args, inputs):
input_shapes = args.placeholder_shapes or {}
is_dict_input = isinstance(example_inputs, dict)
list_inputs = list(example_inputs.values()) if is_dict_input else example_inputs
input_names = None if not is_dict_input else list(example_inputs)
if not isinstance(list_inputs, (list, tuple)):
input_names = None
if not isinstance(example_inputs, (list, tuple, dict)):
list_inputs = [list_inputs]
if args.input_model._input_signature is not None and not is_dict_input:
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
if not is_dict_input:
example_inputs = dict(zip(input_names, list_inputs))
is_dict_input = True
elif is_dict_input:
input_names = list(example_inputs)
if not data_types and input_names is None:
data_types = []
if not input_shapes and input_names is None:
@@ -85,18 +92,18 @@ def extract_input_info_from_example(args, inputs):
dtype = getattr(example_input, "dtype", type(example_input))
example_dtype = pt_to_ov_type_map.get(str(dtype))
user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
if user_dtype is not None and example_dtype.to_dtype() != user_dtype:
if user_dtype is not None and example_dtype is not None and example_dtype.to_dtype() != user_dtype:
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
data_rank = getattr(example_input, "ndim", 0)
user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
if user_input_shape.rank.get_length() != data_rank:
if user_input_shape.rank.is_static and user_input_shape.rank.get_length() != data_rank:
raise Error(
f"Requested input shape {user_input_shape.rank.get_length()} rank"
f" is not equal to provided example_input rank {data_rank}")
input_shape = user_input_shape if user_input_shape is not None else PartialShape([-1] * data_rank)
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype())
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype() if example_dtype is not None else None)
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
else:
for input_id, example_input in enumerate(list_inputs):
@@ -106,7 +113,7 @@ def extract_input_info_from_example(args, inputs):
input_shape = PartialShape([-1] * data_rank)
input_name = input_names[input_id] if input_names else None
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype() if ov_dtype is not None else None)
args.placeholder_data_types = data_types
args.placeholder_shapes = input_shapes
@@ -125,7 +132,7 @@ def to_torch_tensor(tensor):
return torch.tensor(tensor.data)
if isinstance(tensor, (float, int, bool)):
return tensor
if isinstance(tensor, tuple):
if isinstance(tensor, (tuple, list)):
# TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
return tuple(to_torch_tensor(x) for x in tensor)
else: