[PT FE]: allow example input list with single tensor (#19308)
This commit is contained in:
parent
128ec5452e
commit
80b8b6fff1
@ -35,6 +35,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
self.m_decoders = []
|
||||
self._input_signature = None
|
||||
self._shared_memory = shared_memory
|
||||
self._input_is_list = False
|
||||
if graph_element is None:
|
||||
try:
|
||||
pt_module = self._get_scripted_model(pt_module, example_input)
|
||||
@ -136,11 +137,16 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
|
||||
|
||||
def prepare_example_inputs_and_model(inputs, input_params, model):
|
||||
input_signature = list(input_params)
|
||||
if isinstance(inputs, dict):
|
||||
return process_dict_inputs(inputs, input_params, model)
|
||||
if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
|
||||
if "typing.List" in str(input_params[input_signature[0]].annotation):
|
||||
inputs = inputs[0].unsqueeze(0)
|
||||
self._input_is_list = True
|
||||
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
input_signature = list(input_params)
|
||||
input_signature = input_signature[:len(inputs)]
|
||||
return {"example_inputs": inputs}, input_signature, model
|
||||
|
||||
@ -164,7 +170,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
try:
|
||||
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
|
||||
except Exception as te:
|
||||
raise f"Tracing failed with exception {te}\nScripting failed with exception: {se}"
|
||||
raise Exception(f"Tracing failed with exception {te}\nScripting failed with exception: {se}")
|
||||
skip_freeze = False
|
||||
for n in scripted.inlined_graph.nodes():
|
||||
# TODO: switch off freezing for all traced models
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
import os
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Tuple, List
|
||||
import numpy
|
||||
import numpy as np
|
||||
import openvino.runtime as ov
|
||||
@ -958,6 +958,48 @@ def create_pytorch_module_with_nested_inputs6(tmp_dir):
|
||||
"compress_to_fp16": False}
|
||||
|
||||
|
||||
def create_pytorch_module_with_nested_list_and_single_input(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
def forward(self, x: List[torch.Tensor]):
|
||||
x0 = x[0]
|
||||
x0 = torch.cat([x0, torch.zeros(1, 1)], 1)
|
||||
return x0 + torch.ones((1, 1))
|
||||
|
||||
net = PTModel()
|
||||
constant_one = ov.opset10.constant(np.ones((1, 1)), dtype=np.float32)
|
||||
const_zero = ov.opset10.constant(0, dtype=np.int32)
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||
|
||||
param = ov.opset10.parameter(PartialShape([-1, -1, -1]), dtype=np.float32)
|
||||
gather = ov.opset10.gather(param, const_zero, const_zero)
|
||||
concat1 = ov.opset10.concat([gather, constant_zeros1], 1)
|
||||
add = ov.opset10.add(concat1, constant_one)
|
||||
ref_model = Model([add], [param], "test")
|
||||
return net, ref_model, {
|
||||
"example_input": [torch.ones((1, 11))],
|
||||
"compress_to_fp16": False}
|
||||
|
||||
def create_pytorch_module_with_single_input_as_list(tmp_dir):
|
||||
class PTModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x0 = x[0]
|
||||
x0 = torch.cat([x0, torch.zeros(1)], 0)
|
||||
return x0 + torch.ones(1)
|
||||
|
||||
net = PTModel()
|
||||
constant_one = ov.opset10.constant(np.ones((1,)), dtype=np.float32)
|
||||
const_zero = ov.opset10.constant(0, dtype=np.int32)
|
||||
constant_zeros1 = ov.opset10.constant(np.zeros((1, ), dtype=np.float32), dtype=np.float32)
|
||||
|
||||
param = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
|
||||
gather = ov.opset10.gather(param, const_zero, const_zero)
|
||||
concat1 = ov.opset10.concat([gather, constant_zeros1], 0)
|
||||
add = ov.opset10.add(concat1, constant_one)
|
||||
ref_model = Model([add], [param], "test")
|
||||
return net, ref_model, {
|
||||
"example_input": [torch.ones((1, 11))],
|
||||
"compress_to_fp16": False}
|
||||
|
||||
class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
test_data = [
|
||||
create_pytorch_nn_module_case1,
|
||||
@ -1006,7 +1048,9 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
create_pytorch_module_with_nested_inputs3,
|
||||
create_pytorch_module_with_nested_inputs4,
|
||||
create_pytorch_module_with_nested_inputs5,
|
||||
create_pytorch_module_with_nested_inputs6
|
||||
create_pytorch_module_with_nested_inputs6,
|
||||
create_pytorch_module_with_nested_list_and_single_input,
|
||||
create_pytorch_module_with_single_input_as_list
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("create_model", test_data)
|
||||
|
@ -78,6 +78,8 @@ def extract_input_info_from_example(args, inputs):
|
||||
input_names = None
|
||||
if not isinstance(example_inputs, (list, tuple, dict)):
|
||||
list_inputs = [list_inputs]
|
||||
if args.input_model._input_is_list:
|
||||
list_inputs[0] = list_inputs[0].unsqueeze(0)
|
||||
if args.input_model._input_signature is not None and not is_dict_input:
|
||||
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
||||
if not is_dict_input:
|
||||
@ -156,10 +158,6 @@ def prepare_torch_inputs(example_inputs):
|
||||
inputs = example_inputs
|
||||
if isinstance(inputs, list):
|
||||
inputs = [to_torch_tensor(x) for x in inputs]
|
||||
if len(inputs) == 1:
|
||||
inputs = torch.unsqueeze(inputs[0], 0)
|
||||
else:
|
||||
inputs = inputs
|
||||
elif isinstance(inputs, tuple):
|
||||
inputs = [to_torch_tensor(x) for x in inputs]
|
||||
inputs = tuple(inputs)
|
||||
|
@ -65,7 +65,7 @@ def get_value_from_list_or_dict(container, name, idx):
|
||||
|
||||
def extract_input_info_from_example(args, inputs):
|
||||
try:
|
||||
from openvino.frontend.pytorch.utils import pt_to_ov_type_map # pylint: disable=no-name-in-module,import-error
|
||||
from openvino.frontend.pytorch.utils import pt_to_ov_type_map # pylint: disable=no-name-in-module,import-error
|
||||
except Exception as e:
|
||||
log.error("PyTorch frontend loading failed")
|
||||
raise e
|
||||
@ -77,6 +77,8 @@ def extract_input_info_from_example(args, inputs):
|
||||
input_names = None
|
||||
if not isinstance(example_inputs, (list, tuple, dict)):
|
||||
list_inputs = [list_inputs]
|
||||
if args.input_model._input_is_list:
|
||||
list_inputs[0] = list_inputs[0].unsqueeze(0)
|
||||
if args.input_model._input_signature is not None and not is_dict_input:
|
||||
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
||||
if not is_dict_input:
|
||||
@ -155,10 +157,6 @@ def prepare_torch_inputs(example_inputs):
|
||||
inputs = example_inputs
|
||||
if isinstance(inputs, list):
|
||||
inputs = [to_torch_tensor(x) for x in inputs]
|
||||
if len(inputs) == 1:
|
||||
inputs = torch.unsqueeze(inputs[0], 0)
|
||||
else:
|
||||
inputs = inputs
|
||||
elif isinstance(inputs, tuple):
|
||||
inputs = [to_torch_tensor(x) for x in inputs]
|
||||
inputs = tuple(inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user