Fix problems with pytorch models passed to convert_model (#17255)

* Do eval() only for torch Module

* Add test

* Support decoder in convert_model

* Enable tests
This commit is contained in:
Maxim Vafin 2023-04-27 16:33:46 +02:00 committed by GitHub
parent 2b8a6ba99a
commit 1d443c6da6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 8 deletions

View File

@ -96,9 +96,7 @@ class TorchScriptPythonDecoder (Decoder):
# We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
self.m_decoders = []
self._input_signature = None
converted_model = False
if graph_element is None:
converted_model = True
pt_module = self._get_scripted_model(pt_module, example_input, freeze)
self.graph_element = pt_module.inlined_graph
else:
@ -106,7 +104,7 @@ class TorchScriptPythonDecoder (Decoder):
self.pt_module = pt_module
self.raw_inputs = list(self.graph_element.inputs())
self.raw_outputs = list(self.graph_element.outputs())
if self._input_signature is not None and self.raw_inputs[0].debugName() == "self":
if self._input_signature is not None and "self" in self.raw_inputs[0].debugName():
self._input_signature.insert(0, "self")
if isinstance(self.graph_element, torch.Graph):
@ -137,6 +135,7 @@ class TorchScriptPythonDecoder (Decoder):
inputs = [inputs]
return inputs, input_signature
if isinstance(pt_module, torch.nn.Module):
pt_module.eval()
input_signature = None
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):

View File

@ -780,9 +780,8 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
create_pytorch_nn_module_sample_list_of_tensors,
create_pytorch_jit_script_module,
create_pytorch_jit_script_function,
# Disabled due to Ticket-109430
#create_pytorch_nn_module_layout_list,
#create_pytorch_nn_module_layout_list_case2,
create_pytorch_nn_module_layout_list,
create_pytorch_nn_module_layout_list_case2,
create_pytorch_nn_module_mean_list,
create_pytorch_nn_module_mean_list_default_no_compression,
create_pytorch_nn_module_mean_list_compressin_enabled,

View File

@ -598,3 +598,15 @@ def test_pytorch_decoder_can_convert_optional_tensor_none():
outputs = list(nc_decoder.graph_element.outputs())
assert len(outputs) == 1
assert isinstance(outputs[0].type(), torch.NoneType)
def f(x, y):
return x + y
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_scripted_function():
from openvino.tools.mo import convert_model
scripted = torch.jit.script(f)
model = convert_model(scripted)
assert model is not None

View File

@ -571,8 +571,15 @@ def check_model_object(argv):
return "tf"
if 'torch' in sys.modules:
import torch
if isinstance(model, torch.nn.Module) or isinstance(model, torch.jit.ScriptFunction):
if isinstance(model, (torch.nn.Module, torch.jit.ScriptFunction)):
return "pytorch"
try:
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
if isinstance(model, TorchScriptPythonDecoder):
return "pytorch"
except Exception as e:
pass
import io
if isinstance(model, io.BytesIO):