[PT FE]: allow example input list with single tensor (#19308)

This commit is contained in:
Ekaterina Aidova 2023-08-23 11:08:39 +03:00 committed by GitHub
parent 128ec5452e
commit 80b8b6fff1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 13 deletions

View File

@ -35,6 +35,7 @@ class TorchScriptPythonDecoder (Decoder):
self.m_decoders = [] self.m_decoders = []
self._input_signature = None self._input_signature = None
self._shared_memory = shared_memory self._shared_memory = shared_memory
self._input_is_list = False
if graph_element is None: if graph_element is None:
try: try:
pt_module = self._get_scripted_model(pt_module, example_input) pt_module = self._get_scripted_model(pt_module, example_input)
@ -136,11 +137,16 @@ class TorchScriptPythonDecoder (Decoder):
return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model
def prepare_example_inputs_and_model(inputs, input_params, model): def prepare_example_inputs_and_model(inputs, input_params, model):
input_signature = list(input_params)
if isinstance(inputs, dict): if isinstance(inputs, dict):
return process_dict_inputs(inputs, input_params, model) return process_dict_inputs(inputs, input_params, model)
if isinstance(inputs, list) and len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
if "typing.List" in str(input_params[input_signature[0]].annotation):
inputs = inputs[0].unsqueeze(0)
self._input_is_list = True
if isinstance(inputs, torch.Tensor): if isinstance(inputs, torch.Tensor):
inputs = [inputs] inputs = [inputs]
input_signature = list(input_params)
input_signature = input_signature[:len(inputs)] input_signature = input_signature[:len(inputs)]
return {"example_inputs": inputs}, input_signature, model return {"example_inputs": inputs}, input_signature, model
@ -164,7 +170,7 @@ class TorchScriptPythonDecoder (Decoder):
try: try:
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False) scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
except Exception as te: except Exception as te:
raise f"Tracing failed with exception {te}\nScripting failed with exception: {se}" raise Exception(f"Tracing failed with exception {te}\nScripting failed with exception: {se}")
skip_freeze = False skip_freeze = False
for n in scripted.inlined_graph.nodes(): for n in scripted.inlined_graph.nodes():
# TODO: switch off freezing for all traced models # TODO: switch off freezing for all traced models

View File

@ -3,7 +3,7 @@
import os import os
from typing import Tuple from typing import Tuple, List
import numpy import numpy
import numpy as np import numpy as np
import openvino.runtime as ov import openvino.runtime as ov
@ -958,6 +958,48 @@ def create_pytorch_module_with_nested_inputs6(tmp_dir):
"compress_to_fp16": False} "compress_to_fp16": False}
def create_pytorch_module_with_nested_list_and_single_input(tmp_dir):
class PTModel(torch.nn.Module):
def forward(self, x: List[torch.Tensor]):
x0 = x[0]
x0 = torch.cat([x0, torch.zeros(1, 1)], 1)
return x0 + torch.ones((1, 1))
net = PTModel()
constant_one = ov.opset10.constant(np.ones((1, 1)), dtype=np.float32)
const_zero = ov.opset10.constant(0, dtype=np.int32)
constant_zeros1 = ov.opset10.constant(np.zeros((1, 1), dtype=np.float32), dtype=np.float32)
param = ov.opset10.parameter(PartialShape([-1, -1, -1]), dtype=np.float32)
gather = ov.opset10.gather(param, const_zero, const_zero)
concat1 = ov.opset10.concat([gather, constant_zeros1], 1)
add = ov.opset10.add(concat1, constant_one)
ref_model = Model([add], [param], "test")
return net, ref_model, {
"example_input": [torch.ones((1, 11))],
"compress_to_fp16": False}
def create_pytorch_module_with_single_input_as_list(tmp_dir):
class PTModel(torch.nn.Module):
def forward(self, x):
x0 = x[0]
x0 = torch.cat([x0, torch.zeros(1)], 0)
return x0 + torch.ones(1)
net = PTModel()
constant_one = ov.opset10.constant(np.ones((1,)), dtype=np.float32)
const_zero = ov.opset10.constant(0, dtype=np.int32)
constant_zeros1 = ov.opset10.constant(np.zeros((1, ), dtype=np.float32), dtype=np.float32)
param = ov.opset10.parameter(PartialShape([-1, -1]), dtype=np.float32)
gather = ov.opset10.gather(param, const_zero, const_zero)
concat1 = ov.opset10.concat([gather, constant_zeros1], 0)
add = ov.opset10.add(concat1, constant_one)
ref_model = Model([add], [param], "test")
return net, ref_model, {
"example_input": [torch.ones((1, 11))],
"compress_to_fp16": False}
class TestMoConvertPyTorch(CommonMOConvertTest): class TestMoConvertPyTorch(CommonMOConvertTest):
test_data = [ test_data = [
create_pytorch_nn_module_case1, create_pytorch_nn_module_case1,
@ -1006,7 +1048,9 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
create_pytorch_module_with_nested_inputs3, create_pytorch_module_with_nested_inputs3,
create_pytorch_module_with_nested_inputs4, create_pytorch_module_with_nested_inputs4,
create_pytorch_module_with_nested_inputs5, create_pytorch_module_with_nested_inputs5,
create_pytorch_module_with_nested_inputs6 create_pytorch_module_with_nested_inputs6,
create_pytorch_module_with_nested_list_and_single_input,
create_pytorch_module_with_single_input_as_list
] ]
@pytest.mark.parametrize("create_model", test_data) @pytest.mark.parametrize("create_model", test_data)

View File

@ -78,6 +78,8 @@ def extract_input_info_from_example(args, inputs):
input_names = None input_names = None
if not isinstance(example_inputs, (list, tuple, dict)): if not isinstance(example_inputs, (list, tuple, dict)):
list_inputs = [list_inputs] list_inputs = [list_inputs]
if args.input_model._input_is_list:
list_inputs[0] = list_inputs[0].unsqueeze(0)
if args.input_model._input_signature is not None and not is_dict_input: if args.input_model._input_signature is not None and not is_dict_input:
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
if not is_dict_input: if not is_dict_input:
@ -156,10 +158,6 @@ def prepare_torch_inputs(example_inputs):
inputs = example_inputs inputs = example_inputs
if isinstance(inputs, list): if isinstance(inputs, list):
inputs = [to_torch_tensor(x) for x in inputs] inputs = [to_torch_tensor(x) for x in inputs]
if len(inputs) == 1:
inputs = torch.unsqueeze(inputs[0], 0)
else:
inputs = inputs
elif isinstance(inputs, tuple): elif isinstance(inputs, tuple):
inputs = [to_torch_tensor(x) for x in inputs] inputs = [to_torch_tensor(x) for x in inputs]
inputs = tuple(inputs) inputs = tuple(inputs)

View File

@ -77,6 +77,8 @@ def extract_input_info_from_example(args, inputs):
input_names = None input_names = None
if not isinstance(example_inputs, (list, tuple, dict)): if not isinstance(example_inputs, (list, tuple, dict)):
list_inputs = [list_inputs] list_inputs = [list_inputs]
if args.input_model._input_is_list:
list_inputs[0] = list_inputs[0].unsqueeze(0)
if args.input_model._input_signature is not None and not is_dict_input: if args.input_model._input_signature is not None and not is_dict_input:
input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature input_names = args.input_model._input_signature[1:] if args.input_model._input_signature[0] == "self" else args.input_model._input_signature
if not is_dict_input: if not is_dict_input:
@ -155,10 +157,6 @@ def prepare_torch_inputs(example_inputs):
inputs = example_inputs inputs = example_inputs
if isinstance(inputs, list): if isinstance(inputs, list):
inputs = [to_torch_tensor(x) for x in inputs] inputs = [to_torch_tensor(x) for x in inputs]
if len(inputs) == 1:
inputs = torch.unsqueeze(inputs[0], 0)
else:
inputs = inputs
elif isinstance(inputs, tuple): elif isinstance(inputs, tuple):
inputs = [to_torch_tensor(x) for x in inputs] inputs = [to_torch_tensor(x) for x in inputs]
inputs = tuple(inputs) inputs = tuple(inputs)