[PT FE]: allow example input list with single tensor (#19308)

Ekaterina Aidova 2023-08-23 11:08:39 +03:00 committed by GitHub
parent 128ec5452e
commit 80b8b6fff1
4 changed files with 59 additions and 13 deletions
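The change lets model conversion accept example_input given as a one-element list of a tensor for a model whose forward takes a List[torch.Tensor]. A minimal usage sketch (assuming the openvino.convert_model entry point of recent releases; the toy module below is illustrative, not part of this commit):

import torch
from typing import List

import openvino as ov


class ListModel(torch.nn.Module):
    # forward consumes a list input and uses only its first element
    def forward(self, x: List[torch.Tensor]):
        return x[0] + torch.ones((1, 1))


# A single-tensor example list is now stacked internally for tracing and the
# original list-ness is recorded on the decoder (_input_is_list).
ov_model = ov.convert_model(ListModel(), example_input=[torch.ones((1, 11))])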


@@ -35,6 +35,7 @@ class TorchScriptPythonDecoder (Decoder):
        self.m_decoders = []
        self._input_signature = None
        self._shared_memory = shared_memory
+        self._input_is_list = False
        if graph_element is None:
            try:
                pt_module = self._get_scripted_model(pt_module, example_input)
@@ -136,11 +137,16 @@ class TorchScriptPythonDecoder (Decoder):
            return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model

        def prepare_example_inputs_and_model(inputs, input_params, model):
+            input_signature = list(input_params)
            if isinstance(inputs, dict):
                return process_dict_inputs(inputs, input_params, model)
+            if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
+                if "typing.List" in str(input_params[input_signature[0]].annotation):
+                    inputs = inputs[0].unsqueeze(0)
+                    self._input_is_list = True
            if isinstance(inputs, torch.Tensor):
                inputs = [inputs]
-            input_signature = list(input_params)
            input_signature = input_signature[:len(inputs)]
            return {"example_inputs": inputs}, input_signature, model
@@ -164,7 +170,7 @@ class TorchScriptPythonDecoder (Decoder):
            try:
                scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
            except Exception as te:
-                raise f"Tracing failed with exception {te}\nScripting failed with exception: {se}"
+                raise Exception(f"Tracing failed with exception {te}\nScripting failed with exception: {se}")
        skip_freeze = False
        for n in scripted.inlined_graph.nodes():
            # TODO: switch off freezing for all traced models
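In prepare_example_inputs_and_model above, a one-element example list is stacked into a single tensor (unsqueeze along dim 0) when the first forward parameter is annotated as typing.List, and the fact that the user passed a list is remembered in self._input_is_list. A rough standalone sketch of that check (a hypothetical helper built on inspect, not the decoder code itself):

import inspect
from typing import List

import torch


def normalize_single_tensor_list(model, inputs):
    # mimic the decoder's check: stack a one-tensor example list when the
    # forward signature really declares a typing.List parameter
    params = inspect.signature(model.forward).parameters
    first_name = list(params)[0]
    input_is_list = False
    if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
        if "typing.List" in str(params[first_name].annotation):
            # (1, 11) -> (1, 1, 11): the extra leading dim models the list axis
            inputs = inputs[0].unsqueeze(0)
            input_is_list = True
    return inputs, input_is_list


class ListModel(torch.nn.Module):
    def forward(self, x: List[torch.Tensor]):
        return x[0]


stacked, flag = normalize_single_tensor_list(ListModel(), [torch.ones((1, 11))])
print(stacked.shape, flag)  # torch.Size([1, 1, 11]) True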


@@ -3,7 +3,7 @@
import os
-from typing import Tuple
+from typing import Tuple, List
import numpy
import numpy as np
import openvino.runtime as ov
@@ -958,6 +958,48 @@ def create_pytorch_module_with_nested_inputs6(tmp_dir):
        "compress_to_fp16": False}


+def create_pytorch_module_with_nested_list_and_single_input(tmp_dir):
+    class PTModel(torch.nn.Module):
+        def forward(self, x: List[torch.Tensor]):
+            x0 = x[0]
+            x0 = torch.cat([x0, torch.zeros(1, 1)], 1)
+            return x0 + torch.ones((1, 1))
+
+    net = PTModel()
+    constant_one = ov.opset10.constant(np.ones((1, 1)), dtype=np.float32)
+    const_zero = ov.opset10.constant(0, dtype=np.int32)
+    constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
+    param = ov.opset10.parameter(PartialShape([-1, -1, -1]), dtype=np.float32)
+    gather = ov.opset10.gather(param, const_zero, const_zero)
+    concat1 = ov.opset10.concat([gather, constant_zeros1], 1)
+    add = ov.opset10.add(concat1, constant_one)
+    ref_model = Model([add], [param], "test")
+    return net, ref_model, {
+        "example_input": [torch.ones((1, 11))],
+        "compress_to_fp16": False}
+
+
+def create_pytorch_module_with_single_input_as_list(tmp_dir):
+    class PTModel(torch.nn.Module):
+        def forward(self, x):
+            x0 = x[0]
+            x0 = torch.cat([x0, torch.zeros(1)], 0)
+            return x0 + torch.ones(1)
+
+    net = PTModel()
+    constant_one = ov.opset10.constant(np.ones((1,)), dtype=np.float32)
+    const_zero = ov.opset10.constant(0, dtype=np.int32)
+    constant_zeros1 = ov.opset10.constant(np.zeros((1, ), dtype=np.float32), dtype=np.float32)
+    param = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
+    gather = ov.opset10.gather(param, const_zero, const_zero)
+    concat1 = ov.opset10.concat([gather, constant_zeros1], 0)
+    add = ov.opset10.add(concat1, constant_one)
+    ref_model = Model([add], [param], "test")
+    return net, ref_model, {
+        "example_input": [torch.ones((1, 11))],
+        "compress_to_fp16": False}
+
+
class TestMoConvertPyTorch(CommonMOConvertTest):
    test_data = [
        create_pytorch_nn_module_case1,
@@ -1006,7 +1048,9 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
        create_pytorch_module_with_nested_inputs3,
        create_pytorch_module_with_nested_inputs4,
        create_pytorch_module_with_nested_inputs5,
-        create_pytorch_module_with_nested_inputs6
+        create_pytorch_module_with_nested_inputs6,
+        create_pytorch_module_with_nested_list_and_single_input,
+        create_pytorch_module_with_single_input_as_list
    ]

    @pytest.mark.parametrize("create_model", test_data)
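Both new cases pass example_input=[torch.ones((1, 11))] to a module that indexes its list input, so the reference graphs start with a gather over axis 0, mirroring the x[0] access. A hedged sketch of running the first case end to end outside the test harness (assumes openvino.convert_model, ov.compile_model and an available CPU device):

from typing import List

import numpy as np
import openvino as ov
import torch


class PTModel(torch.nn.Module):
    def forward(self, x: List[torch.Tensor]):
        x0 = x[0]
        x0 = torch.cat([x0, torch.zeros(1, 1)], 1)
        return x0 + torch.ones((1, 1))


ov_model = ov.convert_model(PTModel(), example_input=[torch.ones((1, 11))])
compiled = ov.compile_model(ov_model, "CPU")
# the converted model keeps the list axis, so the input is a stacked (1, 1, 11) array
result = compiled(np.ones((1, 1, 11), dtype=np.float32))[0]
print(result.shape)  # (1, 12): concat with zeros(1, 1), then add ones(1, 1)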


@@ -78,6 +78,8 @@ def extract_input_info_from_example(args, inputs):
    input_names = None
    if not isinstance(example_inputs, (list, tuple, dict)):
        list_inputs = [list_inputs]
+    if args.input_model._input_is_list:
+        list_inputs[0] = list_inputs[0].unsqueeze(0)
    if args.input_model._input_signature is not None and not is_dict_input:
        input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
    if not is_dict_input:
@@ -156,10 +158,6 @@ def prepare_torch_inputs(example_inputs):
        inputs = example_inputs
        if isinstance(inputs, list):
            inputs = [to_torch_tensor(x) for x in inputs]
-            if len(inputs) == 1:
-                inputs = torch.unsqueeze(inputs[0], 0)
-            else:
-                inputs = inputs
        elif isinstance(inputs, tuple):
            inputs = [to_torch_tensor(x) for x in inputs]
            inputs = tuple(inputs)


@@ -77,6 +77,8 @@ def extract_input_info_from_example(args, inputs):
    input_names = None
    if not isinstance(example_inputs, (list, tuple, dict)):
        list_inputs = [list_inputs]
+    if args.input_model._input_is_list:
+        list_inputs[0] = list_inputs[0].unsqueeze(0)
    if args.input_model._input_signature is not None and not is_dict_input:
        input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
    if not is_dict_input:
@@ -155,10 +157,6 @@ def prepare_torch_inputs(example_inputs):
        inputs = example_inputs
        if isinstance(inputs, list):
            inputs = [to_torch_tensor(x) for x in inputs]
-            if len(inputs) == 1:
-                inputs = torch.unsqueeze(inputs[0], 0)
-            else:
-                inputs = inputs
        elif isinstance(inputs, tuple):
            inputs = [to_torch_tensor(x) for x in inputs]
            inputs = tuple(inputs)
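In both copies of the frontend utilities the unconditional unsqueeze of a one-element example list is removed from prepare_torch_inputs; the list axis is now added in extract_input_info_from_example only when the decoder set _input_is_list. A toy before/after sketch of that difference (hypothetical helper names, not the frontend code):

import torch


def old_prepare(inputs):
    # previous behaviour: a one-element list was always stacked, even when
    # the model expects a plain tensor argument
    if isinstance(inputs, list) and len(inputs) == 1:
        return torch.unsqueeze(inputs[0], 0)
    return inputs


def new_prepare(inputs, input_is_list):
    # new behaviour: stacking happens only when the decoder flagged that the
    # forward signature really takes a list
    if input_is_list:
        return inputs[0].unsqueeze(0)
    return inputs


example = [torch.ones((1, 11))]
print(old_prepare(example).shape)            # torch.Size([1, 1, 11]) regardless of the signature
print(new_prepare(example, False)[0].shape)  # torch.Size([1, 11]) for a plain tensor parameter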