[PT FE]: allow example input list with single tensor (#19308)
This commit is contained in:
parent
128ec5452e
commit
80b8b6fff1
@ -35,6 +35,7 @@ class TorchScriptPythonDecoder (Decoder):
|
|||||||
self.m_decoders = []
|
self.m_decoders = []
|
||||||
self._input_signature = None
|
self._input_signature = None
|
||||||
self._shared_memory = shared_memory
|
self._shared_memory = shared_memory
|
||||||
|
self._input_is_list = False
|
||||||
if graph_element is None:
|
if graph_element is None:
|
||||||
try:
|
try:
|
||||||
pt_module = self._get_scripted_model(pt_module, example_input)
|
pt_module = self._get_scripted_model(pt_module, example_input)
|
||||||
@ -136,11 +137,16 @@ class TorchScriptPythonDecoder (Decoder):
|
|||||||
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
|
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
|
||||||
|
|
||||||
def prepare_example_inputs_and_model(inputs, input_params, model):
|
def prepare_example_inputs_and_model(inputs, input_params, model):
|
||||||
|
input_signature = list(input_params)
|
||||||
if isinstance(inputs, dict):
|
if isinstance(inputs, dict):
|
||||||
return process_dict_inputs(inputs, input_params, model)
|
return process_dict_inputs(inputs, input_params, model)
|
||||||
|
if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
|
||||||
|
if "typing.List" in str(input_params[input_signature[0]].annotation):
|
||||||
|
inputs = inputs[0].unsqueeze(0)
|
||||||
|
self._input_is_list = True
|
||||||
|
|
||||||
if isinstance(inputs, torch.Tensor):
|
if isinstance(inputs, torch.Tensor):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
input_signature = list(input_params)
|
|
||||||
input_signature = input_signature[:len(inputs)]
|
input_signature = input_signature[:len(inputs)]
|
||||||
return {"example_inputs": inputs}, input_signature, model
|
return {"example_inputs": inputs}, input_signature, model
|
||||||
|
|
||||||
@ -164,7 +170,7 @@ class TorchScriptPythonDecoder (Decoder):
|
|||||||
try:
|
try:
|
||||||
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
|
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
|
||||||
except Exception as te:
|
except Exception as te:
|
||||||
raise f"Tracing failed with exception {te}\nScripting failed with exception: {se}"
|
raise Exception(f"Tracing failed with exception {te}\nScripting failed with exception: {se}")
|
||||||
skip_freeze = False
|
skip_freeze = False
|
||||||
for n in scripted.inlined_graph.nodes():
|
for n in scripted.inlined_graph.nodes():
|
||||||
# TODO: switch off freezing for all traced models
|
# TODO: switch off freezing for all traced models
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Tuple, List
|
||||||
import numpy
|
import numpy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import openvino.runtime as ov
|
import openvino.runtime as ov
|
||||||
@ -958,6 +958,48 @@ def create_pytorch_module_with_nested_inputs6(tmp_dir):
|
|||||||
"compress_to_fp16": False}
|
"compress_to_fp16": False}
|
||||||
|
|
||||||
|
|
||||||
|
def create_pytorch_module_with_nested_list_and_single_input(tmp_dir):
|
||||||
|
class PTModel(torch.nn.Module):
|
||||||
|
def forward(self, x: List[torch.Tensor]):
|
||||||
|
x0 = x[0]
|
||||||
|
x0 = torch.cat([x0, torch.zeros(1, 1)], 1)
|
||||||
|
return x0 + torch.ones((1, 1))
|
||||||
|
|
||||||
|
net = PTModel()
|
||||||
|
constant_one = ov.opset10.constant(np.ones((1, 1)), dtype=np.float32)
|
||||||
|
const_zero = ov.opset10.constant(0, dtype=np.int32)
|
||||||
|
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
|
||||||
|
|
||||||
|
param = ov.opset10.parameter(PartialShape([-1, -1, -1]), dtype=np.float32)
|
||||||
|
gather = ov.opset10.gather(param, const_zero, const_zero)
|
||||||
|
concat1 = ov.opset10.concat([gather, constant_zeros1], 1)
|
||||||
|
add = ov.opset10.add(concat1, constant_one)
|
||||||
|
ref_model = Model([add], [param], "test")
|
||||||
|
return net, ref_model, {
|
||||||
|
"example_input": [torch.ones((1, 11))],
|
||||||
|
"compress_to_fp16": False}
|
||||||
|
|
||||||
|
def create_pytorch_module_with_single_input_as_list(tmp_dir):
|
||||||
|
class PTModel(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
x0 = x[0]
|
||||||
|
x0 = torch.cat([x0, torch.zeros(1)], 0)
|
||||||
|
return x0 + torch.ones(1)
|
||||||
|
|
||||||
|
net = PTModel()
|
||||||
|
constant_one = ov.opset10.constant(np.ones((1,)), dtype=np.float32)
|
||||||
|
const_zero = ov.opset10.constant(0, dtype=np.int32)
|
||||||
|
constant_zeros1 = ov.opset10.constant(np.zeros((1, ), dtype=np.float32), dtype=np.float32)
|
||||||
|
|
||||||
|
param = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
|
||||||
|
gather = ov.opset10.gather(param, const_zero, const_zero)
|
||||||
|
concat1 = ov.opset10.concat([gather, constant_zeros1], 0)
|
||||||
|
add = ov.opset10.add(concat1, constant_one)
|
||||||
|
ref_model = Model([add], [param], "test")
|
||||||
|
return net, ref_model, {
|
||||||
|
"example_input": [torch.ones((1, 11))],
|
||||||
|
"compress_to_fp16": False}
|
||||||
|
|
||||||
class TestMoConvertPyTorch(CommonMOConvertTest):
|
class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||||
test_data = [
|
test_data = [
|
||||||
create_pytorch_nn_module_case1,
|
create_pytorch_nn_module_case1,
|
||||||
@ -1006,7 +1048,9 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
|
|||||||
create_pytorch_module_with_nested_inputs3,
|
create_pytorch_module_with_nested_inputs3,
|
||||||
create_pytorch_module_with_nested_inputs4,
|
create_pytorch_module_with_nested_inputs4,
|
||||||
create_pytorch_module_with_nested_inputs5,
|
create_pytorch_module_with_nested_inputs5,
|
||||||
create_pytorch_module_with_nested_inputs6
|
create_pytorch_module_with_nested_inputs6,
|
||||||
|
create_pytorch_module_with_nested_list_and_single_input,
|
||||||
|
create_pytorch_module_with_single_input_as_list
|
||||||
]
|
]
|
||||||
|
|
||||||
@pytest.mark.parametrize("create_model", test_data)
|
@pytest.mark.parametrize("create_model", test_data)
|
||||||
|
@ -78,6 +78,8 @@ def extract_input_info_from_example(args, inputs):
|
|||||||
input_names = None
|
input_names = None
|
||||||
if not isinstance(example_inputs, (list, tuple, dict)):
|
if not isinstance(example_inputs, (list, tuple, dict)):
|
||||||
list_inputs = [list_inputs]
|
list_inputs = [list_inputs]
|
||||||
|
if args.input_model._input_is_list:
|
||||||
|
list_inputs[0] = list_inputs[0].unsqueeze(0)
|
||||||
if args.input_model._input_signature is not None and not is_dict_input:
|
if args.input_model._input_signature is not None and not is_dict_input:
|
||||||
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
||||||
if not is_dict_input:
|
if not is_dict_input:
|
||||||
@ -156,10 +158,6 @@ def prepare_torch_inputs(example_inputs):
|
|||||||
inputs = example_inputs
|
inputs = example_inputs
|
||||||
if isinstance(inputs, list):
|
if isinstance(inputs, list):
|
||||||
inputs = [to_torch_tensor(x) for x in inputs]
|
inputs = [to_torch_tensor(x) for x in inputs]
|
||||||
if len(inputs) == 1:
|
|
||||||
inputs = torch.unsqueeze(inputs[0], 0)
|
|
||||||
else:
|
|
||||||
inputs = inputs
|
|
||||||
elif isinstance(inputs, tuple):
|
elif isinstance(inputs, tuple):
|
||||||
inputs = [to_torch_tensor(x) for x in inputs]
|
inputs = [to_torch_tensor(x) for x in inputs]
|
||||||
inputs = tuple(inputs)
|
inputs = tuple(inputs)
|
||||||
|
@ -77,6 +77,8 @@ def extract_input_info_from_example(args, inputs):
|
|||||||
input_names = None
|
input_names = None
|
||||||
if not isinstance(example_inputs, (list, tuple, dict)):
|
if not isinstance(example_inputs, (list, tuple, dict)):
|
||||||
list_inputs = [list_inputs]
|
list_inputs = [list_inputs]
|
||||||
|
if args.input_model._input_is_list:
|
||||||
|
list_inputs[0] = list_inputs[0].unsqueeze(0)
|
||||||
if args.input_model._input_signature is not None and not is_dict_input:
|
if args.input_model._input_signature is not None and not is_dict_input:
|
||||||
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
|
||||||
if not is_dict_input:
|
if not is_dict_input:
|
||||||
@ -155,10 +157,6 @@ def prepare_torch_inputs(example_inputs):
|
|||||||
inputs = example_inputs
|
inputs = example_inputs
|
||||||
if isinstance(inputs, list):
|
if isinstance(inputs, list):
|
||||||
inputs = [to_torch_tensor(x) for x in inputs]
|
inputs = [to_torch_tensor(x) for x in inputs]
|
||||||
if len(inputs) == 1:
|
|
||||||
inputs = torch.unsqueeze(inputs[0], 0)
|
|
||||||
else:
|
|
||||||
inputs = inputs
|
|
||||||
elif isinstance(inputs, tuple):
|
elif isinstance(inputs, tuple):
|
||||||
inputs = [to_torch_tensor(x) for x in inputs]
|
inputs = [to_torch_tensor(x) for x in inputs]
|
||||||
inputs = tuple(inputs)
|
inputs = tuple(inputs)
|
||||||
|
Loading…
Reference in New Issue
Block a user