Fix problems with pytorch models passed to convert_model (#17255)
* Do eval() only for torch Module * Add test * Support decoder in convert_model * Enable tests
This commit is contained in:
parent
2b8a6ba99a
commit
1d443c6da6
@ -96,9 +96,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
# We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
|
||||
self.m_decoders = []
|
||||
self._input_signature = None
|
||||
converted_model = False
|
||||
if graph_element is None:
|
||||
converted_model = True
|
||||
pt_module = self._get_scripted_model(pt_module, example_input, freeze)
|
||||
self.graph_element = pt_module.inlined_graph
|
||||
else:
|
||||
@ -106,7 +104,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
self.pt_module = pt_module
|
||||
self.raw_inputs = list(self.graph_element.inputs())
|
||||
self.raw_outputs = list(self.graph_element.outputs())
|
||||
if self._input_signature is not None and self.raw_inputs[0].debugName() == "self":
|
||||
if self._input_signature is not None and "self" in self.raw_inputs[0].debugName():
|
||||
self._input_signature.insert(0, "self")
|
||||
|
||||
if isinstance(self.graph_element, torch.Graph):
|
||||
@ -137,7 +135,8 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
inputs = [inputs]
|
||||
return inputs, input_signature
|
||||
|
||||
pt_module.eval()
|
||||
if isinstance(pt_module, torch.nn.Module):
|
||||
pt_module.eval()
|
||||
input_signature = None
|
||||
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
|
||||
input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
|
||||
|
@ -780,9 +780,8 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
create_pytorch_nn_module_sample_list_of_tensors,
|
||||
create_pytorch_jit_script_module,
|
||||
create_pytorch_jit_script_function,
|
||||
# Disabled due to Ticket-109430
|
||||
#create_pytorch_nn_module_layout_list,
|
||||
#create_pytorch_nn_module_layout_list_case2,
|
||||
create_pytorch_nn_module_layout_list,
|
||||
create_pytorch_nn_module_layout_list_case2,
|
||||
create_pytorch_nn_module_mean_list,
|
||||
create_pytorch_nn_module_mean_list_default_no_compression,
|
||||
create_pytorch_nn_module_mean_list_compressin_enabled,
|
||||
|
@ -598,3 +598,15 @@ def test_pytorch_decoder_can_convert_optional_tensor_none():
|
||||
outputs = list(nc_decoder.graph_element.outputs())
|
||||
assert len(outputs) == 1
|
||||
assert isinstance(outputs[0].type(), torch.NoneType)
|
||||
|
||||
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_scripted_function():
|
||||
from openvino.tools.mo import convert_model
|
||||
scripted = torch.jit.script(f)
|
||||
model = convert_model(scripted)
|
||||
assert model is not None
|
||||
|
@ -571,8 +571,15 @@ def check_model_object(argv):
|
||||
return "tf"
|
||||
if 'torch' in sys.modules:
|
||||
import torch
|
||||
if isinstance(model, torch.nn.Module) or isinstance(model, torch.jit.ScriptFunction):
|
||||
if isinstance(model, (torch.nn.Module, torch.jit.ScriptFunction)):
|
||||
return "pytorch"
|
||||
try:
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
|
||||
if isinstance(model, TorchScriptPythonDecoder):
|
||||
return "pytorch"
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
import io
|
||||
if isinstance(model, io.BytesIO):
|
||||
|
Loading…
Reference in New Issue
Block a user