[PT FE]: support nested inputs in example_inputs and arg dicts with d… (#18492)
* [PT FE]: support nested inputs in example_inputs and arg dicts with different argtypes * accept hande lists as inputs * Update tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py * update tests and add comments in code * fix for custom types in annotations and duplicate in mo * Update tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py
This commit is contained in:
@@ -13,6 +13,19 @@ from packaging.version import parse
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
wrapper_template="""
|
||||
import torch
|
||||
from typing import *
|
||||
|
||||
class ModelWrapper(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, {input_sign}):
|
||||
return self.model({example_input})
|
||||
"""
|
||||
|
||||
|
||||
def get_type_from_py_type(value):
|
||||
if isinstance(value, float):
|
||||
@@ -142,38 +155,76 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
import torch
|
||||
import inspect
|
||||
|
||||
def prepare_example_inputs(inputs, input_signature):
|
||||
is_torch_2 = parse(torch.__version__) >= parse("2.0.0")
|
||||
if isinstance(inputs, dict):
|
||||
ordered_inputs = []
|
||||
if input_signature is not None:
|
||||
used_sign = []
|
||||
for key in input_signature:
|
||||
if key not in inputs:
|
||||
def process_dict_inputs(inputs, input_params, model):
|
||||
ordered_inputs = []
|
||||
for input_name in input_params:
|
||||
if input_name in inputs:
|
||||
ordered_inputs.append(input_name)
|
||||
|
||||
input_signature = list(input_params)
|
||||
if ordered_inputs == input_signature[:len(ordered_inputs)]:
|
||||
example_inputs = [inputs[input_name] for input_name in ordered_inputs]
|
||||
if all([isinstance(inp, torch.Tensor) for inp in example_inputs]):
|
||||
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, model
|
||||
return {"example_inputs": example_inputs}, ordered_inputs, model
|
||||
|
||||
# PyTorch has some difficulties to trace models with named unordered parameters:
|
||||
# torch < 2.0.0 supports only positional arguments for tracing
|
||||
# pytorch == 2.0.0 supports input kwargs tracing,
|
||||
# but does not support complex nested objects (e. g. tuple of tuples of tensors)
|
||||
# We will use wrapper for making them positional as workaround.
|
||||
|
||||
input_sign_str = []
|
||||
input_params_str = []
|
||||
|
||||
for input_name in ordered_inputs:
|
||||
if str(input_params[input_name].annotation).startswith("typing.Union"):
|
||||
filter_custom_args = []
|
||||
for arg in input_params[input_name].annotation.__args__:
|
||||
str_arg = str(arg)
|
||||
is_typing = str_arg.startswith("typing.")
|
||||
is_torch = "torch." in str_arg
|
||||
is_builten = str_arg in (str(int), str(float), str(type(None)))
|
||||
if not (is_typing or is_torch or is_builten):
|
||||
continue
|
||||
ordered_inputs.append(inputs[key])
|
||||
used_sign.append(key)
|
||||
input_signature = used_sign
|
||||
else:
|
||||
ordered_inputs = list(inputs.values())
|
||||
if is_torch_2:
|
||||
return {"example_kwarg_inputs": inputs}, input_signature
|
||||
else:
|
||||
inputs = ordered_inputs
|
||||
filter_custom_args.append(arg)
|
||||
input_params[input_name].annotation.__args__ = tuple(filter_custom_args)
|
||||
input_sign_str.append(str(input_params[input_name]).replace("NoneType", "None"))
|
||||
input_params_str.append(f"{input_name}={input_name}")
|
||||
|
||||
wrapper_class = wrapper_template.format(input_sign=', '.join(input_sign_str), example_input=', '.join(input_params_str))
|
||||
result = {}
|
||||
try:
|
||||
exec(wrapper_class, result)
|
||||
|
||||
wrapped_model = result["ModelWrapper"](model)
|
||||
wrapped_model.eval()
|
||||
# if wrapping failed, it is better to return original model for avoid user confusion regarding error message
|
||||
except Exception:
|
||||
wrapped_model = model
|
||||
|
||||
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
|
||||
|
||||
def prepare_example_inputs_and_model(inputs, input_params, model):
|
||||
if isinstance(inputs, dict):
|
||||
return process_dict_inputs(inputs, input_params, model)
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
|
||||
return {"example_inputs": inputs}, input_signature
|
||||
input_signature = list(input_params)
|
||||
input_signature = input_signature[:len(inputs)]
|
||||
return {"example_inputs": inputs}, input_signature, model
|
||||
|
||||
if isinstance(pt_module, torch.nn.Module):
|
||||
pt_module.eval()
|
||||
input_signature = None
|
||||
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
|
||||
input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
|
||||
# input params is dictionary contains input names and their signature values (type hints and default values if any)
|
||||
input_params = inspect.signature(pt_module.forward if hasattr(pt_module, "forward") else pt_module.__call__).parameters
|
||||
input_signature = list(input_params)
|
||||
if example_inputs is None:
|
||||
scripted = torch.jit.script(pt_module)
|
||||
else:
|
||||
input_parameters, input_signature = prepare_example_inputs(example_inputs, input_signature)
|
||||
input_parameters, input_signature, pt_module = prepare_example_inputs_and_model(example_inputs, input_params, pt_module)
|
||||
try:
|
||||
scripted = torch.jit.trace(pt_module, **input_parameters)
|
||||
except Exception:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
|
||||
from typing import Tuple
|
||||
import numpy
|
||||
import numpy as np
|
||||
import openvino.runtime as ov
|
||||
@@ -725,6 +726,150 @@ def create_pytorch_module_with_compressed_int8_constant(tmp_dir):
|
||||
ref_model = Model([conv], [param1], "test")
|
||||
return traced_model, ref_model, {"example_input": example_input}
|
||||
|
||||
def create_pytorch_module_with_nested_inputs(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
|
||||
def forward(self, z:Tuple[torch.Tensor, torch.Tensor]):
|
||||
z1, z2 = z
|
||||
zeros1 = torch.zeros((1, 1))
|
||||
zeros2 = torch.zeros((1, 5, 1))
|
||||
return torch.cat([z1, zeros1], 1), torch.cat([z2, zeros2], 2)
|
||||
|
||||
net = PTModel()
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||
constant_zeros2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
|
||||
shape1 = PartialShape([1, -1])
|
||||
shape2 = PartialShape([1, 5, -1])
|
||||
param1 = ov.opset10.parameter(shape1, dtype=np.float32)
|
||||
param2 = ov.opset10.parameter(shape2, dtype=np.float32)
|
||||
concat1 = ov.opset10.concat([param1, constant_zeros1], 1)
|
||||
concat2 = ov.opset10.concat([param2, constant_zeros2], 2)
|
||||
ref_model = Model([concat2, concat1], [param1, param2], "test")
|
||||
return net, ref_model, {"example_input": {"z": (torch.zeros((1, 10)), torch.ones((1, 5, 2)))}}
|
||||
|
||||
|
||||
def create_pytorch_module_with_nested_inputs2(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x:torch.Tensor, z:Tuple[torch.Tensor, torch.Tensor]):
|
||||
z1, z2 = z
|
||||
zeros1 = torch.zeros((1, 1))
|
||||
zeros2 = torch.zeros((1, 5, 1))
|
||||
return torch.cat([z1, zeros1], 1) + x, torch.cat([z2, zeros2], 2)
|
||||
|
||||
net = PTModel()
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||
constant_zeros2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
|
||||
shape1 = PartialShape([1, -1])
|
||||
shape2 = PartialShape([1, 5, -1])
|
||||
param0 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
|
||||
param1 = ov.opset10.parameter(shape1, dtype=np.float32)
|
||||
param2 = ov.opset10.parameter(shape2, dtype=np.float32)
|
||||
concat1 = ov.opset10.concat([param1, constant_zeros1], 1)
|
||||
concat2 = ov.opset10.concat([param2, constant_zeros2], 2)
|
||||
add = ov.opset10.add(concat1, param0)
|
||||
ref_model = Model([concat2, add], [param0, param1, param2], "test")
|
||||
return net, ref_model, {"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 5)))}}
|
||||
|
||||
def create_pytorch_module_with_nested_inputs3(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
|
||||
def forward(self, z:Tuple[torch.Tensor, torch.Tensor], x:torch.Tensor):
|
||||
z1, z2 = z
|
||||
zeros1 = torch.zeros((1, 1))
|
||||
zeros2 = torch.zeros((1, 5, 1))
|
||||
return torch.cat([z1, zeros1], 1) + x, torch.cat([z2, zeros2], 2)
|
||||
|
||||
net = PTModel()
|
||||
shape1 = PartialShape([1, -1])
|
||||
shape2 = PartialShape([1, 5, -1])
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||
constant_zeros2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
|
||||
param1 = ov.opset10.parameter(shape1, dtype=np.float32)
|
||||
param2 = ov.opset10.parameter(shape2, dtype=np.float32)
|
||||
param3 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
|
||||
concat1 = ov.opset10.concat([param1, constant_zeros1], 1)
|
||||
concat2 = ov.opset10.concat([param2, constant_zeros2], 2)
|
||||
add = ov.opset10.add(concat1, param3)
|
||||
ref_model = Model([concat2, add], [param1, param2, param3], "test")
|
||||
return net, ref_model, {"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 3)))}}
|
||||
|
||||
|
||||
def create_pytorch_module_with_nested_inputs4(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x:torch.Tensor, z:Tuple[torch.Tensor, torch.Tensor], y:torch.Tensor):
|
||||
z1, z2 = z
|
||||
zeros1 = torch.zeros((1, 1))
|
||||
zeros2 = torch.zeros((1, 5, 1))
|
||||
return torch.cat([z1, zeros1], 1) + x, torch.cat([z2, zeros2], 2) * y
|
||||
|
||||
net = PTModel()
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||
constant_zeros2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
|
||||
shape1 = PartialShape([1, -1])
|
||||
shape2 = PartialShape([1, 5, -1])
|
||||
param1 = ov.opset10.parameter(shape1, dtype=np.float32)
|
||||
param2 = ov.opset10.parameter(shape2, dtype=np.float32)
|
||||
param3 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
|
||||
param4 = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
|
||||
concat1 = ov.opset10.concat([param1, constant_zeros1], 1)
|
||||
concat2 = ov.opset10.concat([param2, constant_zeros2], 2)
|
||||
add = ov.opset10.add(concat1, param3)
|
||||
mul = ov.opset10.multiply(concat2, param4)
|
||||
ref_model = Model([mul, add], [param3, param1, param2, param4], "test")
|
||||
return net, ref_model, {"example_input": {"x": torch.ones((1, 10)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10))), "y": torch.ones((1,))}}
|
||||
|
||||
def create_pytorch_module_with_nested_inputs5(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x:torch.Tensor, z:Tuple[torch.Tensor, torch.Tensor], y:torch.Tensor):
|
||||
z1, z2 = z
|
||||
zeros1 = torch.zeros((1, 1))
|
||||
zeros2 = torch.zeros((1, 5, 1))
|
||||
return torch.cat([z1, zeros1], 1) + x, torch.cat([z2, zeros2], 2) * y
|
||||
|
||||
net = PTModel()
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||
constant_zeros2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
|
||||
shape1 = PartialShape([1, -1])
|
||||
shape2 = PartialShape([1, 5, -1])
|
||||
param0 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
|
||||
param1 = ov.opset10.parameter(shape1, dtype=np.float32)
|
||||
param2 = ov.opset10.parameter(shape2, dtype=np.float32)
|
||||
param4 = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
|
||||
concat1 = ov.opset10.concat([param1, constant_zeros1], 1)
|
||||
concat2 = ov.opset10.concat([param2, constant_zeros2], 2)
|
||||
add = ov.opset10.add(concat1, param0)
|
||||
mul = ov.opset10.multiply(concat2, param4)
|
||||
ref_model = Model([mul, add], [param0, param1, param2, param4], "test")
|
||||
return net, ref_model, {"example_input": [torch.ones((1, 10)), (torch.zeros((1, 10)), torch.ones((1, 5, 10))), torch.ones((1,))]}
|
||||
|
||||
def create_pytorch_module_with_nested_inputs6(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x:torch.Tensor, y:torch.Tensor=None, z:Tuple[torch.Tensor, torch.Tensor]=None):
|
||||
z1, z2 = z
|
||||
zeros1 = torch.zeros((1, 1))
|
||||
zeros2 = torch.zeros((1, 5, 1))
|
||||
if y is not None:
|
||||
return torch.cat([z1, zeros1], 1) * y, torch.cat([z2, zeros2], 2) * y
|
||||
return torch.cat([z1, zeros1], 1) + x, torch.cat([z2, zeros2], 2)
|
||||
|
||||
net = PTModel()
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||
constant_zeros2 = ov.opset10.constant(np.zeros((1, 5, 1), dtype=np.float32), dtype=np.float32)
|
||||
shape1 = PartialShape([1, -1])
|
||||
shape2 = PartialShape([1, 5, -1])
|
||||
param0 = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
|
||||
param1 = ov.opset10.parameter(shape1, dtype=np.float32)
|
||||
param2 = ov.opset10.parameter(shape2, dtype=np.float32)
|
||||
concat1 = ov.opset10.concat([param1, constant_zeros1], 1)
|
||||
concat2 = ov.opset10.concat([param2, constant_zeros2], 2)
|
||||
add1 = ov.opset10.add(concat1, param0)
|
||||
ref_model = Model([concat2, add1], [param0, param1, param2], "test")
|
||||
return net, ref_model, {"example_input": {"x": torch.ones((1, 11)), "z": (torch.zeros((1, 10)), torch.ones((1, 5, 10)))}}
|
||||
|
||||
|
||||
class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
test_data = [
|
||||
@@ -770,6 +915,12 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
create_pytorch_module_with_optional_inputs_case5,
|
||||
create_pytorch_nn_module_with_scalar_input,
|
||||
create_pytorch_module_with_compressed_int8_constant,
|
||||
create_pytorch_module_with_nested_inputs,
|
||||
create_pytorch_module_with_nested_inputs2,
|
||||
create_pytorch_module_with_nested_inputs3,
|
||||
create_pytorch_module_with_nested_inputs4,
|
||||
create_pytorch_module_with_nested_inputs5,
|
||||
create_pytorch_module_with_nested_inputs6
|
||||
]
|
||||
|
||||
@ pytest.mark.parametrize("create_model", test_data)
|
||||
|
||||
@@ -66,9 +66,16 @@ def extract_input_info_from_example(args, inputs):
|
||||
input_shapes = args.placeholder_shapes or {}
|
||||
is_dict_input = isinstance(example_inputs, dict)
|
||||
list_inputs = list(example_inputs.values()) if is_dict_input else example_inputs
|
||||
input_names = None if not is_dict_input else list(example_inputs)
|
||||
if not isinstance(list_inputs, (list, tuple)):
|
||||
input_names = None
|
||||
if not isinstance(example_inputs, (list, tuple, dict)):
|
||||
list_inputs = [list_inputs]
|
||||
if args.input_model._input_signature is not None and not is_dict_input:
|
||||
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
||||
if not is_dict_input:
|
||||
example_inputs = dict(zip(input_names, list_inputs))
|
||||
is_dict_input = True
|
||||
elif is_dict_input:
|
||||
input_names = list(example_inputs)
|
||||
if not data_types and input_names is None:
|
||||
data_types = []
|
||||
if not input_shapes and input_names is None:
|
||||
@@ -85,18 +92,18 @@ def extract_input_info_from_example(args, inputs):
|
||||
dtype = getattr(example_input, "dtype", type(example_input))
|
||||
example_dtype = pt_to_ov_type_map.get(str(dtype))
|
||||
user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
|
||||
if user_dtype is not None and example_dtype.to_dtype() != user_dtype:
|
||||
if user_dtype is not None and example_dtype is not None and example_dtype.to_dtype() != user_dtype:
|
||||
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
|
||||
|
||||
data_rank = getattr(example_input, "ndim", 0)
|
||||
user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
|
||||
if user_input_shape.rank.get_length() != data_rank:
|
||||
if user_input_shape.rank.is_static and user_input_shape.rank.get_length() != data_rank:
|
||||
raise Error(
|
||||
f"Requested input shape {user_input_shape.rank.get_length()} rank"
|
||||
f" is not equal to provided example_input rank {data_rank}")
|
||||
|
||||
input_shape = user_input_shape if user_input_shape is not None else PartialShape([-1] * data_rank)
|
||||
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype())
|
||||
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype() if example_dtype is not None else None)
|
||||
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
|
||||
else:
|
||||
for input_id, example_input in enumerate(list_inputs):
|
||||
@@ -106,7 +113,7 @@ def extract_input_info_from_example(args, inputs):
|
||||
input_shape = PartialShape([-1] * data_rank)
|
||||
input_name = input_names[input_id] if input_names else None
|
||||
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
|
||||
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
|
||||
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype() if ov_dtype is not None else None)
|
||||
|
||||
args.placeholder_data_types = data_types
|
||||
args.placeholder_shapes = input_shapes
|
||||
@@ -125,7 +132,7 @@ def to_torch_tensor(tensor):
|
||||
return torch.tensor(tensor.data)
|
||||
if isinstance(tensor, (float, int, bool)):
|
||||
return tensor
|
||||
if isinstance(tensor, tuple):
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
# TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
|
||||
return tuple(to_torch_tensor(x) for x in tensor)
|
||||
else:
|
||||
|
||||
@@ -66,9 +66,16 @@ def extract_input_info_from_example(args, inputs):
|
||||
input_shapes = args.placeholder_shapes or {}
|
||||
is_dict_input = isinstance(example_inputs, dict)
|
||||
list_inputs = list(example_inputs.values()) if is_dict_input else example_inputs
|
||||
input_names = None if not is_dict_input else list(example_inputs)
|
||||
if not isinstance(list_inputs, (list, tuple)):
|
||||
input_names = None
|
||||
if not isinstance(example_inputs, (list, tuple, dict)):
|
||||
list_inputs = [list_inputs]
|
||||
if args.input_model._input_signature is not None and not is_dict_input:
|
||||
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
||||
if not is_dict_input:
|
||||
example_inputs = dict(zip(input_names, list_inputs))
|
||||
is_dict_input = True
|
||||
elif is_dict_input:
|
||||
input_names = list(example_inputs)
|
||||
if not data_types and input_names is None:
|
||||
data_types = []
|
||||
if not input_shapes and input_names is None:
|
||||
@@ -85,18 +92,18 @@ def extract_input_info_from_example(args, inputs):
|
||||
dtype = getattr(example_input, "dtype", type(example_input))
|
||||
example_dtype = pt_to_ov_type_map.get(str(dtype))
|
||||
user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
|
||||
if user_dtype is not None and example_dtype.to_dtype() != user_dtype:
|
||||
if user_dtype is not None and example_dtype is not None and example_dtype.to_dtype() != user_dtype:
|
||||
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
|
||||
|
||||
data_rank = getattr(example_input, "ndim", 0)
|
||||
user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
|
||||
if user_input_shape.rank.get_length() != data_rank:
|
||||
if user_input_shape.rank.is_static and user_input_shape.rank.get_length() != data_rank:
|
||||
raise Error(
|
||||
f"Requested input shape {user_input_shape.rank.get_length()} rank"
|
||||
f" is not equal to provided example_input rank {data_rank}")
|
||||
|
||||
input_shape = user_input_shape if user_input_shape is not None else PartialShape([-1] * data_rank)
|
||||
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype())
|
||||
update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype() if example_dtype is not None else None)
|
||||
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
|
||||
else:
|
||||
for input_id, example_input in enumerate(list_inputs):
|
||||
@@ -106,7 +113,7 @@ def extract_input_info_from_example(args, inputs):
|
||||
input_shape = PartialShape([-1] * data_rank)
|
||||
input_name = input_names[input_id] if input_names else None
|
||||
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
|
||||
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
|
||||
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype() if ov_dtype is not None else None)
|
||||
|
||||
args.placeholder_data_types = data_types
|
||||
args.placeholder_shapes = input_shapes
|
||||
@@ -125,7 +132,7 @@ def to_torch_tensor(tensor):
|
||||
return torch.tensor(tensor.data)
|
||||
if isinstance(tensor, (float, int, bool)):
|
||||
return tensor
|
||||
if isinstance(tensor, tuple):
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
# TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
|
||||
return tuple(to_torch_tensor(x) for x in tensor)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user