[PT FE]: use example input for dtype and rank detection (#18083)

* [PT FE]: use example input for dtype and rank detection, support unordered input dict

* Apply suggestions from code review

* restore old behaviour for old torch versions

* move info addition after parsing
This commit is contained in:
Ekaterina Aidova
2023-06-20 17:31:30 +04:00
committed by GitHub
parent 37c538d6bd
commit a8a366de08
5 changed files with 244 additions and 62 deletions

View File

@@ -9,6 +9,7 @@ from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
import typing
from packaging.version import parse
import torch
import numpy as np
@@ -133,24 +134,27 @@ class TorchScriptPythonDecoder (Decoder):
import inspect
def prepare_example_inputs(inputs, input_signature):
if inputs is not None:
if isinstance(inputs, dict):
if input_signature is not None:
ordered_inputs = []
used_sign = []
for key in input_signature:
if key not in inputs:
continue
ordered_inputs.append(inputs[key])
used_sign.append(key)
inputs = ordered_inputs
input_signature = used_sign
else:
inputs = list(inputs.values())
input_signature = input_signature[:len(inputs)]
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
return inputs, input_signature
is_torch_2 = parse(torch.__version__) >= parse("2.0.0")
if isinstance(inputs, dict):
ordered_inputs = []
if input_signature is not None:
used_sign = []
for key in input_signature:
if key not in inputs:
continue
ordered_inputs.append(inputs[key])
used_sign.append(key)
input_signature = used_sign
else:
ordered_inputs = list(inputs.values())
if is_torch_2:
return {"example_kwarg_inputs": inputs}, input_signature
else:
inputs = ordered_inputs
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
return {"example_inputs": inputs}, input_signature
if isinstance(pt_module, torch.nn.Module):
pt_module.eval()
@@ -160,14 +164,14 @@ class TorchScriptPythonDecoder (Decoder):
if example_inputs is None:
scripted = torch.jit.script(pt_module)
else:
inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
input_parameters, input_signature = prepare_example_inputs(example_inputs, input_signature)
try:
scripted = torch.jit.trace(pt_module, inputs)
scripted = torch.jit.trace(pt_module, **input_parameters)
except Exception:
try:
scripted = torch.jit.script(pt_module)
except Exception:
scripted = torch.jit.trace(pt_module, inputs, strict=False)
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
skip_freeze = False
for n in scripted.inlined_graph.nodes():
# TODO: switch off freezing for all traced models

View File

@@ -62,6 +62,26 @@ def make_pt_model_two_inputs():
return NeuralNetwork()
def make_pt_model_with_optional_input():
    # Builds a tiny torch module whose forward has two optional tensor inputs;
    # callers are expected to supply exactly one of y / z alongside x.
    from torch import nn

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            # Elementwise ReLU followed by Sigmoid; no learnable weights.
            self.linear_relu_stack = nn.Sequential(
                nn.ReLU(),
                nn.Sigmoid(),
            )

        def forward(self, x, y=None, z=None):
            # y omitted -> sigmoid(relu(x + z)); z omitted -> sigmoid(relu(x * y)).
            # NOTE(review): if both y and z are passed, or both omitted, `logits`
            # is unbound / the arithmetic fails with None — presumably callers
            # always pass exactly one of them; confirm against the test cases.
            if y is None:
                logits = self.linear_relu_stack(x + z)
            if z is None:
                logits = self.linear_relu_stack(x * y)
            return logits

    return NeuralNetwork()
def make_ref_pt_model_one_input(shape, dtype=np.float32):
shape = PartialShape(shape)
param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype)
@@ -96,16 +116,37 @@ def make_ref_pt_model_two_inputs(shape, dtype=np.float32):
return model
def make_ref_pt_model_with_optional_inputs(shape, dtype=np.float32, z_exist=False):
    """Build the OV reference model matching make_pt_model_with_optional_input.

    *shape* is either a single shape shared by both inputs or a pair of
    per-input shapes (len(shape) == 2). *z_exist* selects add (x + z) instead
    of multiply (x * y). For non-f32 dtypes a convert back to f32 is inserted
    before the sigmoid, mirroring the traced torch model.
    """
    if len(shape) == 2:
        shapes = [PartialShape(shape[0]), PartialShape(shape[1])]
    else:
        shapes = [PartialShape(shape)] * 2
    param1 = ov.opset8.parameter(shapes[0], name="input_0", dtype=dtype)
    param2 = ov.opset8.parameter(shapes[1], name="input_1", dtype=dtype)
    combined = ov.opset8.add(param1, param2) if z_exist else ov.opset8.multiply(param1, param2)
    activated = ov.opset8.relu(combined)
    if dtype != np.float32:
        activated = ov.opset8.convert(activated, np.float32)
    sigm = ov.opset8.sigmoid(activated)
    return Model([sigm], [param1, param2], "test")
def create_pytorch_nn_module_case1(tmp_dir):
pt_model = make_pt_model_two_inputs()
ref_model = make_ref_pt_model_two_inputs([-1, 3, -1, -1])
ref_model = make_ref_pt_model_two_inputs([-1, -1, -1, -1])
sample_input1 = torch.zeros(1, 3, 10, 10)
sample_input2 = torch.zeros(1, 3, 10, 10)
sample_input = sample_input1, sample_input2
return pt_model, ref_model, {'input': [([-1, 3, -1, -1], np.float32), ([-1, 3, -1, -1], np.float32)],
'example_input': sample_input}
return pt_model, ref_model, {'example_input': sample_input}
def create_pytorch_nn_module_case2(tmp_dir):
@@ -117,7 +158,6 @@ def create_pytorch_nn_module_case2(tmp_dir):
sample_input = sample_input1, sample_input2
return pt_model, ref_model, {'input_shape': ["[?,3,?,?]", PartialShape([-1, 3, -1, -1])],
'input': [np.float32, np.float32],
'example_input': sample_input}
@@ -130,7 +170,6 @@ def create_pytorch_nn_module_case3(tmp_dir):
sample_input = tuple([sample_input1, sample_input2])
return pt_model, ref_model, {'input_shape': "[?,3,?,?],[?,3,?,?]",
'input': [np.float32, np.float32],
'example_input': sample_input}
@@ -139,10 +178,9 @@ def create_pytorch_nn_module_case4(tmp_dir):
sample_input = torch.zeros(1, 3, 10, 10)
ref_model = make_ref_pt_model_one_input(PartialShape.dynamic())
ref_model = make_ref_pt_model_one_input(PartialShape([1, 3, 20, 20]))
return pt_model, ref_model, {'input': [np.float32],
'example_input': sample_input}
return pt_model, ref_model, {'example_input': sample_input, "input_shape": [1, 3, 20, 20]}
def create_pytorch_nn_module_case5(tmp_dir):
@@ -163,6 +201,15 @@ def create_pytorch_nn_module_case6(tmp_dir):
return pt_model, ref_model, {'input': (shape, np.float32)}
def create_pytorch_nn_module_case7(tmp_dir):
pt_model = make_pt_model_one_input()
sample_input = torch.zeros(1, 3, 10, 10, dtype=torch.int32)
ref_model = make_ref_pt_model_one_input(PartialShape([1, 3, 20, 20]), dtype=np.int32)
return pt_model, ref_model, {'example_input': sample_input, "input": ([1, 3, 20, 20], np.int32)}
def create_pytorch_nn_module_torch_size(tmp_dir):
pt_model = make_pt_model_one_input()
ref_model = make_ref_pt_model_one_input([1, 3, 2, 10])
@@ -176,7 +223,7 @@ def create_pytorch_nn_module_sample_input_int32(tmp_dir):
sample_input = torch.zeros(1, 3, 10, 10, dtype=torch.int32)
ref_model = make_ref_pt_model_one_input(shape, dtype=numpy.int32)
ref_model = make_ref_pt_model_one_input(shape, dtype=np.int32)
return pt_model, ref_model, {'example_input': sample_input,
'input': (shape, np.int32)}
@@ -216,7 +263,7 @@ def create_pytorch_jit_script_function(tmp_dir):
inp_shape = PartialShape([Dimension(1, -1), Dimension(-1, 5), 10])
ref_model = make_ref_pt_model_two_inputs(inp_shape)
return scripted_fn, ref_model, {'input': [(inp_shape, np.float32), (inp_shape, np.float32)]}
return scripted_fn, ref_model, {'input': [(inp_shape), (inp_shape)]}
@@ -232,7 +279,6 @@ def create_pytorch_nn_module_layout_list(tmp_dir):
return pt_model, ref_model, {
'input_shape': [shape, shape], 'layout': ['nchw', Layout('nhwc')],
'input': [np.float32, np.float32]
}
@@ -247,8 +293,7 @@ def create_pytorch_nn_module_layout_list_case2(tmp_dir):
ref_model.inputs[1].node.layout = Layout('nhwc')
return pt_model, ref_model, {
'input_shape': [shape, shape], 'layout': ('nchw', Layout('nhwc')),
'input': [np.float32, np.float32]}
'input_shape': [shape, shape], 'layout': ('nchw', Layout('nhwc'))}
def create_pytorch_nn_module_mean_list(tmp_dir):
@@ -270,8 +315,7 @@ def create_pytorch_nn_module_mean_list(tmp_dir):
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {
'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'compress_to_fp16': False,
'input': [np.float32, np.float32]}
'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'compress_to_fp16': False}
def create_pytorch_nn_module_mean_list_default_no_compression(tmp_dir):
@@ -293,7 +337,7 @@ def create_pytorch_nn_module_mean_list_default_no_compression(tmp_dir):
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'input': [np.float32, np.float32]}
return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]]}
def create_pytorch_nn_module_mean_list_compression_enabled(tmp_dir):
@@ -316,7 +360,7 @@ def create_pytorch_nn_module_mean_list_compression_enabled(tmp_dir):
return pt_model, ref_model, {
'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]],
'compress_to_fp16': False, 'input': [np.float32, np.float32]}
'compress_to_fp16': False}
def create_pytorch_nn_module_scale_list(tmp_dir):
@@ -337,7 +381,7 @@ def create_pytorch_nn_module_scale_list(tmp_dir):
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], 'compress_to_fp16': False, 'input': [np.float32, np.float32]}
return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], 'compress_to_fp16': False}
def create_pytorch_nn_module_scale_list_default_no_compression(tmp_dir):
@@ -359,7 +403,7 @@ def create_pytorch_nn_module_scale_list_default_no_compression(tmp_dir):
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], 'input': [np.float32, np.float32]}
return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]]}
def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir):
@@ -382,7 +426,7 @@ def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir):
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], 'input': [np.float32, np.float32],
return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]],
'compress_to_fp16': True}
@@ -390,7 +434,7 @@ def create_pytorch_nn_module_shapes_list_static(tmp_dir):
pt_model = make_pt_model_two_inputs()
ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20])
return pt_model, ref_model, {'input_shape': [[1, 3, 20, 20], [1, 3, 20, 20]], 'input': [np.float32, np.float32]}
return pt_model, ref_model, {'input_shape': [[1, 3, 20, 20], [1, 3, 20, 20]]}
def create_pytorch_nn_module_shapes_list_static_via_input(tmp_dir):
@@ -415,7 +459,7 @@ def create_pytorch_nn_module_shapes_list_dynamic(tmp_dir):
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input_shape': inp_shapes, 'input': [np.float32, np.float32]}
return pt_model, ref_model, {'input_shape': inp_shapes}
def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir):
@@ -433,14 +477,14 @@ def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir):
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input': [(inp_shapes[0], np.float32), (inp_shapes[1], np.float32)]}
return pt_model, ref_model, {'input': [(inp_shapes[0],), (inp_shapes[1],)]}
def create_pytorch_nn_module_shapes_list_dynamic_single_input(tmp_dir):
pt_model = make_pt_model_one_input()
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)]]
ref_model = make_ref_pt_model_one_input(inp_shapes[0])
return pt_model, ref_model, {'input_shape': inp_shapes, 'input': np.float32}
return pt_model, ref_model, {'input_shape': inp_shapes}
def create_pytorch_nn_module_shapes_list_dynamic_single_input_via_input(tmp_dir):
@@ -454,7 +498,7 @@ def create_pytorch_nn_module_shapes_list_static_single_input(tmp_dir):
pt_model = make_pt_model_one_input()
inp_shapes = [[1, 3, 20, 20]]
ref_model = make_ref_pt_model_one_input(inp_shapes[0])
return pt_model, ref_model, {'input_shape': inp_shapes, 'input': np.float32}
return pt_model, ref_model, {'input_shape': inp_shapes}
def create_pytorch_nn_module_shapes_list_static_single_input_via_input(tmp_dir):
@@ -548,8 +592,7 @@ def create_pytorch_jit_script_module_convert_pytorch_frontend(tmp_dir):
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return scripted_model, ref_model, {
"example_input": {"x": torch.zeros((1, 3, 10, 10)), "y": torch.ones((1, 3, 10, 10))},
'input': [InputCutInfo(shape=[-1, -1, -1, -1], type="f32"), InputCutInfo(shape=[-1, -1, -1, -1], type="f32")]}
"example_input": [torch.zeros((1, 3, 10, 10)), torch.ones((1, 3, 10, 10))]}
def create_pytorch_jit_trace_module_convert_pytorch_frontend(tmp_dir):
@@ -567,8 +610,7 @@ def create_pytorch_jit_trace_module_convert_pytorch_frontend(tmp_dir):
sigm = ov.opset10.sigmoid(relu)
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return scripted_model, ref_model, {"example_input": example_input, 'input': [
InputCutInfo(shape=[-1, -1, -1, -1], type="f32"), InputCutInfo(shape=[-1, -1, -1, -1], type="f32")]}
return scripted_model, ref_model, {"example_input": example_input}
def create_pytorch_module_convert_pytorch_frontend_oob(tmp_dir):
@@ -595,6 +637,39 @@ def create_pytorch_module_convert_pytorch_frontend_oob(tmp_dir):
return net, ref_model, {}
def create_pytorch_module_with_optional_inputs_case1(tmp_dir):
    """Optional-input model traced with x and y provided (z omitted)."""
    model = make_pt_model_with_optional_input()
    example = {
        "x": torch.zeros((1, 3, 10, 10)),
        "y": torch.ones((1, 3, 10, 10)),
    }
    reference = make_ref_pt_model_with_optional_inputs([-1, -1, -1, -1])
    return model, reference, {"example_input": example}
def create_pytorch_module_with_optional_inputs_case2(tmp_dir):
    """Optional-input model traced with x and z provided (y omitted)."""
    model = make_pt_model_with_optional_input()
    example = {
        "x": torch.zeros((1, 3, 10, 10)),
        "z": torch.ones((1, 3, 10, 10)),
    }
    reference = make_ref_pt_model_with_optional_inputs([-1, -1, -1, -1], z_exist=True)
    return model, reference, {"example_input": example}
def create_pytorch_module_with_optional_inputs_case3(tmp_dir):
    """x/z example input combined with an explicit static input_shape."""
    model = make_pt_model_with_optional_input()
    example = {
        "x": torch.zeros((1, 3, 10, 10)),
        "z": torch.ones((1, 3, 10, 10)),
    }
    reference = make_ref_pt_model_with_optional_inputs([3, 3, 3, 3], z_exist=True)
    return model, reference, {"example_input": example,
                              "input_shape": [[3, 3, 3, 3], [3, 3, 3, 3]]}
def create_pytorch_module_with_optional_inputs_case4(tmp_dir):
    """Select the x/z signature purely via named 'input' entries (no example)."""
    model = make_pt_model_with_optional_input()
    reference = make_ref_pt_model_with_optional_inputs([3, 3, 3, 3], z_exist=True)
    cut_info = [("x", [3, 3, 3, 3]), ("z", [3, 3, 3, 3])]
    return model, reference, {"input": cut_info}
def create_pytorch_module_with_optional_inputs_case5(tmp_dir):
    """Select x/z by name only, with shapes supplied through input_shape."""
    model = make_pt_model_with_optional_input()
    reference = make_ref_pt_model_with_optional_inputs([1, 3, -1, -1], z_exist=True)
    return model, reference, {"input": ["x", "z"],
                              "input_shape": [[1, 3, -1, -1], [1, 3, -1, -1]]}
class TestMoConvertPyTorch(CommonMOConvertTest):
test_data = [
create_pytorch_nn_module_case1,
@@ -603,6 +678,7 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
create_pytorch_nn_module_case4,
create_pytorch_nn_module_case5,
create_pytorch_nn_module_case6,
create_pytorch_nn_module_case7,
create_pytorch_nn_module_torch_size,
create_pytorch_nn_module_sample_input_int32,
create_pytorch_nn_module_sample_input_int32_two_inputs,
@@ -630,7 +706,12 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
create_pytorch_nn_module_convert_pytorch_frontend4,
create_pytorch_jit_script_module_convert_pytorch_frontend,
create_pytorch_jit_trace_module_convert_pytorch_frontend,
create_pytorch_module_convert_pytorch_frontend_oob
create_pytorch_module_convert_pytorch_frontend_oob,
create_pytorch_module_with_optional_inputs_case1,
create_pytorch_module_with_optional_inputs_case2,
create_pytorch_module_with_optional_inputs_case3,
create_pytorch_module_with_optional_inputs_case4,
create_pytorch_module_with_optional_inputs_case5
]
@ pytest.mark.parametrize("create_model", test_data)

View File

@@ -48,7 +48,7 @@ from openvino.tools.mo.utils.utils import refer_to_faq_msg, check_values_equal
from openvino.tools.mo.utils.telemetry_utils import send_params_info, send_framework_info, send_conversion_result, \
get_tid
from openvino.tools.mo.moc_frontend.check_config import legacy_extensions_used
from openvino.tools.mo.moc_frontend.pytorch_frontend_utils import get_pytorch_decoder
from openvino.tools.mo.moc_frontend.pytorch_frontend_utils import get_pytorch_decoder, extract_input_info_from_example
from openvino.tools.mo.moc_frontend.paddle_frontend_utils import paddle_frontend_converter
from openvino.tools.mo.moc_frontend.shape_utils import parse_input_shapes
@@ -760,6 +760,9 @@ def python_api_params_parsing(argv: argparse.Namespace):
argv.placeholder_shapes = shape_list if shape_list else None
argv.placeholder_data_types = data_type_list if data_type_list else {}
if argv.framework == "pytorch" and getattr(argv, "example_input", None) is not None:
extract_input_info_from_example(argv, inputs)
def pack_params_to_args_namespace(args: dict, cli_parser: argparse.ArgumentParser):
if len(args) > 0:
@@ -838,9 +841,7 @@ def _convert(cli_parser: argparse.ArgumentParser, framework, args, python_api_us
elif 'example_inputs' in args:
raise AssertionError("'example_inputs' argument is not recognized, maybe you meant to provide 'example_input'?")
decoder = get_pytorch_decoder(args['input_model'], parse_input_shapes(args), example_inputs, args.get("input"))
args['input_model'] = decoder
args['framework'] = model_framework
decoder = get_pytorch_decoder(args['input_model'], parse_input_shapes(args), example_inputs, args)
if model_framework == "paddle":
example_inputs = None
if 'example_input' in args and args['example_input'] is not None:
@@ -950,6 +951,6 @@ def _convert(cli_parser: argparse.ArgumentParser, framework, args, python_api_us
send_conversion_result('fail')
if python_api_used:
raise e.with_traceback(None)
raise e#.with_traceback(None)
else:
return None, argv

View File

@@ -399,7 +399,7 @@ def convert_params_lists_to_dicts(input_model,
# this cycle adds each unnamed type to dictionary using name from model_inputs
for idx, node_type in enumerate(input_user_data_types):
assert isinstance(node_type, (type, Type)), "Got incorrect format of input types. " \
assert isinstance(node_type, (type, np.dtype, Type)), "Got incorrect format of input types. " \
"Expected numpy type or openvino.runtime.Type, " \
"got {}.".format(type(node_type))

View File

@@ -5,21 +5,111 @@ import logging as log
import numpy as np
from openvino.tools.mo.moc_frontend.shape_utils import get_static_shape
from openvino.tools.mo.utils.error import Error
from openvino.runtime import Tensor, Type
from openvino.runtime import Tensor, Type, PartialShape
from openvino.runtime.utils.types import get_element_type_str
from openvino.tools.mo.utils.cli_parser import input_to_input_cut_info, input_shape_to_input_cut_info
def get_pytorch_decoder(model, input_shape, example_inputs, input_info):
def get_pytorch_decoder(model, input_shape, example_inputs, args):
try:
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
except Exception as e:
log.error("PyTorch frontend loading failed")
raise e
inputs = prepare_torch_inputs(example_inputs, input_shape, input_info, allow_none=True)
inputs = prepare_torch_inputs(example_inputs, input_shape, args.get("input"), allow_none=True)
decoder = TorchScriptPythonDecoder(model, example_input=inputs)
args['input_model'] = decoder
args["framework"] = "pytorch"
args["example_input"] = inputs
return decoder
return args
def update_list_or_dict(container, name, idx, value):
    """Store *value* in *container* (list or dict) at key *name* or index *idx*.

    For a dict: when *name* is None the idx-th existing key is updated.
    For a list: overwrites an in-range index, appends when idx == len,
    and raises Error for a larger (gap-creating) index.
    """
    if isinstance(container, dict):
        key = list(container)[idx] if name is None else name
        container[key] = value
        return
    size = len(container)
    if idx < size:
        container[idx] = value
    elif idx == size:
        container.append(value)
    else:
        raise Error(f"Wrong {idx}")
def get_value_from_list_or_dict(container, name, idx):
    """Look up an entry in *container* (list or dict) by key *name* or index *idx*.

    Returns None when the entry does not exist. For a dict queried without a
    name the function always returns None.
    NOTE(review): in the unnamed-dict branch the idx-based key is computed but
    never used before returning None — confirm whether returning that entry
    was the original intent.
    """
    if isinstance(container, dict):
        if name is not None:
            return container.get(name)
        if idx < len(container):
            name = list(container)[idx]
        return None
    return container[idx] if idx < len(container) else None
def extract_input_info_from_example(args, inputs):
    """Derive placeholder shapes and element types from args.example_input.

    For every example tensor the element type is mapped through
    pt_to_ov_type_map and the shape defaults to a fully dynamic PartialShape
    of the same rank. Already-parsed user settings
    (args.placeholder_data_types / args.placeholder_shapes) are validated
    against the example data. For a dict example input the input names are
    propagated into args.input when the user did not specify inputs.

    :param args: argparse.Namespace with example_input, placeholder_shapes,
        placeholder_data_types and input attributes; updated in place.
    :param inputs: list of parsed InputCutInfo objects from the user "input"
        argument (possibly empty).
    """
    try:
        from openvino.frontend.pytorch.decoder import pt_to_ov_type_map
    except Exception as e:
        log.error("PyTorch frontend loading failed")
        raise e
    example_inputs = args.example_input
    data_types = args.placeholder_data_types or {}
    input_shapes = args.placeholder_shapes or {}
    is_dict_input = isinstance(example_inputs, dict)
    list_inputs = list(example_inputs.values()) if is_dict_input else example_inputs
    input_names = None if not is_dict_input else list(example_inputs)
    if not isinstance(list_inputs, (list, tuple)):
        list_inputs = [list_inputs]
    # Without user-provided names, accumulate info positionally in lists.
    if not data_types and input_names is None:
        data_types = []
    if not input_shapes and input_names is None:
        input_shapes = []
    if inputs:
        # Validate example data against the user-specified inputs.
        for input_id, input_info in enumerate(inputs):
            input_name = input_info.name
            if is_dict_input and input_name in example_inputs:
                example_input = example_inputs[input_name]
            else:
                example_input = list_inputs[input_id]
                if is_dict_input and input_name is None:
                    input_name = input_names[input_id]
            dtype = getattr(example_input, "dtype", type(example_input))
            example_dtype = pt_to_ov_type_map.get(str(dtype))
            user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
            # example_dtype is None for torch dtypes absent from the map;
            # only compare when both sides are known.
            if user_dtype is not None and example_dtype is not None and example_dtype.to_dtype() != user_dtype:
                raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")

            data_rank = getattr(example_input, "ndim", 0)
            user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
            # FIX: user_input_shape is None when the user gave no shape for
            # this input — the rank check must not dereference it then.
            if user_input_shape is not None and user_input_shape.rank.get_length() != data_rank:
                raise Error(
                    f"Requested input shape {user_input_shape.rank.get_length()} rank"
                    f" is not equal to provided example_input rank {data_rank}")

            input_shape = user_input_shape if user_input_shape is not None else PartialShape([-1] * data_rank)
            # NOTE(review): example_dtype.to_dtype() raises AttributeError when
            # the torch dtype is unmapped — confirm all expected dtypes are in
            # pt_to_ov_type_map before relying on this path.
            update_list_or_dict(data_types, input_name, input_id, example_dtype.to_dtype())
            update_list_or_dict(input_shapes, input_name, input_id, input_shape)
    else:
        # No user "input" specification — take everything from the example.
        for input_id, example_input in enumerate(list_inputs):
            dtype = getattr(example_input, "dtype", type(example_input))
            ov_dtype = pt_to_ov_type_map.get(str(dtype))
            data_rank = getattr(example_input, "ndim", 0)
            input_shape = PartialShape([-1] * data_rank)
            input_name = input_names[input_id] if input_names else None
            update_list_or_dict(input_shapes, input_name, input_id, input_shape)
            update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
    args.placeholder_data_types = data_types
    args.placeholder_shapes = input_shapes
    if not args.input and input_names:
        args.input_list = input_names
        args.input = ",".join(input_names)
def to_torch_tensor(tensor):
@@ -91,6 +181,7 @@ def prepare_torch_inputs(example_inputs, input_shape, input_info=None, allow_non
input_info = input_to_input_cut_info(input_info) or []
input_shape_to_input_cut_info(input_shape, input_info)
inputs = []
inputs_with_names = {}
for inp in input_info:
shape = inp.shape
if shape is None:
@@ -100,9 +191,14 @@ def prepare_torch_inputs(example_inputs, input_shape, input_info=None, allow_non
break
dtype = get_torch_dtype(inp.type)
static_shape = get_static_shape(shape, dynamic_value=1)
inputs.append(torch.zeros(static_shape, dtype=dtype))
input_tensor = torch.zeros(static_shape, dtype=dtype)
if inp.name is not None:
inputs_with_names[inp.name] = input_tensor
inputs.append(input_tensor)
if isinstance(inputs, list):
inputs = tuple(inputs)
if inputs is not None and len(inputs) == len(inputs_with_names):
inputs = inputs_with_names
else:
if not allow_none:
raise Error("Please provide input_shape or example_input for converting PyTorch model.")