[PT FE]: improve integration into mo.convert_model (#16243)

This commit is contained in:
Ekaterina Aidova
2023-03-24 19:55:07 +04:00
committed by GitHub
parent 953a166a62
commit 179403ddc9
22 changed files with 330 additions and 247 deletions

View File

@@ -9,7 +9,6 @@ from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
import typing
import warnings
import torch
import numpy as np
@@ -92,18 +91,75 @@ pt_to_ov_type_map = {
class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module, graph_element=None):
def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True):
Decoder.__init__(self)
# We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
self.m_decoders = []
self._input_signature = None
converted_model = False
if graph_element is None:
assert hasattr(pt_module, "inlined_graph"), "graph_element must have inlined_graph"
converted_model = True
pt_module = self._get_scripted_model(pt_module, example_input, freeze)
self.graph_element = pt_module.inlined_graph
else:
self.graph_element = graph_element
self.pt_module = pt_module
self.raw_inputs = list(self.graph_element.inputs())
self.raw_inputs = [inp for inp in self.graph_element.inputs()]
self.raw_outputs = list(self.graph_element.outputs())
if self._input_signature is not None and self.raw_inputs[0].debugName() == "self":
self._input_signature.insert(0, "self")
def _get_scripted_model(self, pt_module, example_inputs=None, freeze=True):
import torch
import inspect
def prepare_example_inputs(inputs, input_signature):
if inputs is not None:
if isinstance(inputs, dict):
if input_signature is not None:
ordered_inputs = []
used_sign = []
for key in input_signature:
if key not in inputs:
continue
ordered_inputs.append(inputs[key])
used_sign.append(key)
inputs = ordered_inputs
input_signature = used_sign
else:
inputs = list(inputs.values())
input_signature = input_signature[:len(inputs)]
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
return inputs, input_signature
pt_module.eval()
input_signature = None
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
try:
scripted = torch.jit.script(pt_module)
except Exception as scripting_err:
if example_inputs is not None:
inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
try:
scripted = torch.jit.trace(pt_module, inputs)
except Exception as tracing_e:
raise tracing_e
else:
raise scripting_err
else:
scripted = pt_module
if freeze:
try:
f_model = torch.jit.freeze(scripted)
except Exception:
# usually freezing failed when model already frozen for inference
f_model = scripted
else:
f_model = scripted
self._input_signature = input_signature
return f_model
def inputs(self) -> list:
return [x.unique() for x in self.raw_inputs]
@@ -114,6 +170,11 @@ class TorchScriptPythonDecoder (Decoder):
def get_input_debug_name(self, index: int) -> str:
return self._raw_input(index).debugName()
def get_input_signature_name(self, index: int) -> str:
if self._input_signature is not None:
return self._input_signature[index]
return self.get_input_debug_name(index)
def get_input_shape(self, index: int):
raw_input = self._raw_input(index)
return self.get_shape_for_value(raw_input)

View File

@@ -26,6 +26,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_input_debug_name, index);
}
const std::string& get_input_signature_name(size_t index) const override {
PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_input_signature_name, index);
}
ov::PartialShape get_input_shape(size_t index) const override {
PYBIND11_OVERRIDE_PURE(ov::PartialShape, TorchDecoder, get_input_shape, index);
}