[PT FE]: improve integration into mo.convert_model (#16243)
This commit is contained in:
@@ -9,7 +9,6 @@ from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
|
||||
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
|
||||
|
||||
import typing
|
||||
import warnings
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
@@ -92,18 +91,75 @@ pt_to_ov_type_map = {
|
||||
|
||||
|
||||
class TorchScriptPythonDecoder (Decoder):
|
||||
def __init__(self, pt_module, graph_element=None):
|
||||
def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True):
|
||||
Decoder.__init__(self)
|
||||
# We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
|
||||
self.m_decoders = []
|
||||
self._input_signature = None
|
||||
converted_model = False
|
||||
if graph_element is None:
|
||||
assert hasattr(pt_module, "inlined_graph"), "graph_element must have inlined_graph"
|
||||
converted_model = True
|
||||
pt_module = self._get_scripted_model(pt_module, example_input, freeze)
|
||||
self.graph_element = pt_module.inlined_graph
|
||||
else:
|
||||
self.graph_element = graph_element
|
||||
self.pt_module = pt_module
|
||||
self.raw_inputs = list(self.graph_element.inputs())
|
||||
self.raw_inputs = [inp for inp in self.graph_element.inputs()]
|
||||
self.raw_outputs = list(self.graph_element.outputs())
|
||||
if self._input_signature is not None and self.raw_inputs[0].debugName() == "self":
|
||||
self._input_signature.insert(0, "self")
|
||||
|
||||
def _get_scripted_model(self, pt_module, example_inputs=None, freeze=True):
|
||||
import torch
|
||||
import inspect
|
||||
|
||||
def prepare_example_inputs(inputs, input_signature):
|
||||
if inputs is not None:
|
||||
if isinstance(inputs, dict):
|
||||
if input_signature is not None:
|
||||
ordered_inputs = []
|
||||
used_sign = []
|
||||
for key in input_signature:
|
||||
if key not in inputs:
|
||||
continue
|
||||
ordered_inputs.append(inputs[key])
|
||||
used_sign.append(key)
|
||||
inputs = ordered_inputs
|
||||
input_signature = used_sign
|
||||
else:
|
||||
inputs = list(inputs.values())
|
||||
input_signature = input_signature[:len(inputs)]
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
return inputs, input_signature
|
||||
|
||||
pt_module.eval()
|
||||
input_signature = None
|
||||
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
|
||||
input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
|
||||
try:
|
||||
scripted = torch.jit.script(pt_module)
|
||||
except Exception as scripting_err:
|
||||
if example_inputs is not None:
|
||||
inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
|
||||
try:
|
||||
scripted = torch.jit.trace(pt_module, inputs)
|
||||
except Exception as tracing_e:
|
||||
raise tracing_e
|
||||
else:
|
||||
raise scripting_err
|
||||
else:
|
||||
scripted = pt_module
|
||||
if freeze:
|
||||
try:
|
||||
f_model = torch.jit.freeze(scripted)
|
||||
except Exception:
|
||||
# usually freezing failed when model already frozen for inference
|
||||
f_model = scripted
|
||||
else:
|
||||
f_model = scripted
|
||||
self._input_signature = input_signature
|
||||
return f_model
|
||||
|
||||
def inputs(self) -> list:
|
||||
return [x.unique() for x in self.raw_inputs]
|
||||
@@ -114,6 +170,11 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
def get_input_debug_name(self, index: int) -> str:
|
||||
return self._raw_input(index).debugName()
|
||||
|
||||
def get_input_signature_name(self, index: int) -> str:
|
||||
if self._input_signature is not None:
|
||||
return self._input_signature[index]
|
||||
return self.get_input_debug_name(index)
|
||||
|
||||
def get_input_shape(self, index: int):
|
||||
raw_input = self._raw_input(index)
|
||||
return self.get_shape_for_value(raw_input)
|
||||
|
||||
@@ -26,6 +26,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
|
||||
PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_input_debug_name, index);
|
||||
}
|
||||
|
||||
const std::string& get_input_signature_name(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_input_signature_name, index);
|
||||
}
|
||||
|
||||
ov::PartialShape get_input_shape(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(ov::PartialShape, TorchDecoder, get_input_shape, index);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user