[PT FE]: improve integration into mo.convert_model (#16243)
Parent: 953a166a62
Commit: 179403ddc9
@@ -9,7 +9,6 @@ from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape

import typing
import warnings
import torch
import numpy as np
@@ -92,18 +91,75 @@ pt_to_ov_type_map = {

class TorchScriptPythonDecoder (Decoder):
    def __init__(self, pt_module, graph_element=None):
    def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True):
        Decoder.__init__(self)
        # We store every decoder created by this decoder so that none of them is deleted until the first decoder is deleted
        self.m_decoders = []
        self._input_signature = None
        converted_model = False
        if graph_element is None:
            assert hasattr(pt_module, "inlined_graph"), "graph_element must have inlined_graph"
            converted_model = True
            pt_module = self._get_scripted_model(pt_module, example_input, freeze)
            self.graph_element = pt_module.inlined_graph
        else:
            self.graph_element = graph_element
        self.pt_module = pt_module
        self.raw_inputs = list(self.graph_element.inputs())
        self.raw_inputs = [inp for inp in self.graph_element.inputs()]
        self.raw_outputs = list(self.graph_element.outputs())
        if self._input_signature is not None and self.raw_inputs[0].debugName() == "self":
            self._input_signature.insert(0, "self")

    def _get_scripted_model(self, pt_module, example_inputs=None, freeze=True):
        import torch
        import inspect

        def prepare_example_inputs(inputs, input_signature):
            if inputs is not None:
                if isinstance(inputs, dict):
                    if input_signature is not None:
                        ordered_inputs = []
                        used_sign = []
                        for key in input_signature:
                            if key not in inputs:
                                continue
                            ordered_inputs.append(inputs[key])
                            used_sign.append(key)
                        inputs = ordered_inputs
                        input_signature = used_sign
                    else:
                        inputs = list(inputs.values())
                        input_signature = input_signature[:len(inputs)]
                if isinstance(inputs, torch.Tensor):
                    inputs = [inputs]
            return inputs, input_signature

        pt_module.eval()
        input_signature = None
        if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
            input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
            try:
                scripted = torch.jit.script(pt_module)
            except Exception as scripting_err:
                if example_inputs is not None:
                    inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
                    try:
                        scripted = torch.jit.trace(pt_module, inputs)
                    except Exception as tracing_e:
                        raise tracing_e
                else:
                    raise scripting_err
        else:
            scripted = pt_module
        if freeze:
            try:
                f_model = torch.jit.freeze(scripted)
            except Exception:
                # freezing usually fails when the model is already frozen for inference
                f_model = scripted
        else:
            f_model = scripted
        self._input_signature = input_signature
        return f_model

    def inputs(self) -> list:
        return [x.unique() for x in self.raw_inputs]
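Taken together, the new constructor makes the decoder self-sufficient: a plain torch.nn.Module can be handed over, and scripting, the example_input-based tracing fallback, and optional freezing all happen inside __init__. A minimal sketch of that entry point (the Sequential model and tensor shapes are illustrative only):

import torch
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
# No manual torch.jit.script()/torch.jit.freeze() calls are needed any more;
# the decoder scripts the module, falls back to tracing with example_input
# if scripting fails, and freezes the result by default (freeze=True).
decoder = TorchScriptPythonDecoder(model, example_input=torch.zeros(1, 8))
print(decoder.inputs())  # unique ids of the inlined graph inputs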
@@ -114,6 +170,11 @@ class TorchScriptPythonDecoder (Decoder):
    def get_input_debug_name(self, index: int) -> str:
        return self._raw_input(index).debugName()

    def get_input_signature_name(self, index: int) -> str:
        if self._input_signature is not None:
            return self._input_signature[index]
        return self.get_input_debug_name(index)

    def get_input_shape(self, index: int):
        raw_input = self._raw_input(index)
        return self.get_shape_for_value(raw_input)
@@ -26,6 +26,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
        PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_input_debug_name, index);
    }

    const std::string& get_input_signature_name(size_t index) const override {
        PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_input_signature_name, index);
    }

    ov::PartialShape get_input_shape(size_t index) const override {
        PYBIND11_OVERRIDE_PURE(ov::PartialShape, TorchDecoder, get_input_shape, index);
    }
@@ -34,6 +34,9 @@ public:
    // Return debug name of the input tensor
    virtual const std::string& get_input_debug_name(size_t index) const = 0;

    // Return signature name of the input tensor
    virtual const std::string& get_input_signature_name(size_t index) const = 0;

    // Return the shape if the input has torch::Tensor type in the original model; otherwise return the [] shape of a scalar
    virtual PartialShape get_input_shape(size_t index) const = 0;
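A side effect of the new pure-virtual method, sketched below: the pybind trampoline above routes get_input_signature_name into Python, so a hand-written Python decoder (anything passed to fe.load in place of TorchScriptPythonDecoder) now has to implement it alongside get_input_debug_name. The class name and fallback body are illustrative assumptions, not part of the diff:

# Assumes Decoder is the same pybind11 base class that
# TorchScriptPythonDecoder subclasses in decoder.py above.
class MyDecoder(Decoder):
    def get_input_debug_name(self, index: int) -> str:
        return f"in_{index}"

    def get_input_signature_name(self, index: int) -> str:
        # Illustrative fallback: reuse the debug name when no forward()
        # signature is available for the graph.
        return self.get_input_debug_name(index)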
@@ -45,7 +45,9 @@ std::vector<ov::frontend::Place::Ptr> InputModel::get_inputs() const {
    for (const auto& input_idx : m_model_decoder->inputs()) {
        auto place_it = m_name_to_place.find(std::to_string(input_idx));
        FRONT_END_GENERAL_CHECK(place_it != m_name_to_place.end(), "Couldn't find Place for input.");
        res.push_back(place_it->second);
        if (input_idx != 0) {
            res.push_back(place_it->second);
        }
    }
    return res;
}
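The added check excludes tensor index 0, which in a TorchScript graph is the implicit `self` of the module, so it is no longer reported as a model input. A quick way to observe this from Python (the ReLU module is illustrative; the FrontEndManager usage mirrors the layer tests further below):

import torch
from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder

fe = FrontEndManager().load_by_framework('pytorch')
decoder = TorchScriptPythonDecoder(torch.nn.ReLU(), example_input=torch.zeros(2, 2))
input_model = fe.load(decoder)
# The implicit TorchScript "self" (tensor index 0) is skipped now,
# leaving only the real data input:
print(len(input_model.get_inputs()))  # expected: 1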
@@ -46,7 +46,10 @@ void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
    FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index);
    auto input_id = m_decoder_inputs.at(index);
    FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist.");
    m_translate_session->encode_tensor_name(ov_output, input_id, m_decoder->get_input_debug_name(index));
    m_translate_session->encode_tensor_name(
        ov_output,
        input_id,
        {m_decoder->get_input_debug_name(index), m_decoder->get_input_signature_name(index)});
    (*m_tensor_map)[input_id] = ov_output;
    m_mutated_tensors->insert(input_id);
}
@@ -30,6 +30,11 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
        if (debug_name != m_names.at(0)) {
            m_names.push_back(debug_name);
        }
        const auto& signature_name =
            im->m_model_decoder->get_input_signature_name(std::distance(inputs.begin(), in_it));
        if (signature_name != m_names.at(0) && signature_name != debug_name) {
            m_names.push_back(signature_name);
        }
    }
    auto out_it = std::find(outputs.begin(), outputs.end(), tensor_index);
    if (out_it != outputs.end()) {
@@ -33,6 +33,14 @@ public:
        return m_tensor_index;
    }

    bool is_equal_data(const Ptr& another) const override {
        const auto another_pt = dynamic_cast<ov::frontend::pytorch::Place*>(another.get());
        if (!another_pt) {
            return false;
        }
        return m_tensor_index == another_pt->get_tensor_index();
    }

private:
    const ov::frontend::InputModel& m_input_model;
    const size_t m_tensor_index;
@@ -84,7 +84,10 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
        }
        if (!input_node) {
            auto parameter = std::make_shared<v0::Parameter>(type, pshape);
            encode_tensor_name(parameter->output(0), inputs.at(i), pytorch_model->get_input_debug_name(i));
            encode_tensor_name(
                parameter->output(0),
                inputs.at(i),
                {pytorch_model->get_input_debug_name(i), pytorch_model->get_input_signature_name(i)});
            parameters->push_back(parameter);
            input_node = parameter;
            auto order = pytorch_model->get_input_transpose_order(i);

@@ -148,7 +151,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
                                    "Duplicated producer for PT value with unique ID: ",
                                    fw_tensor_id);
            (*tensor_map)[fw_tensor_id] = converted_outputs[i];
            encode_tensor_name(converted_outputs[i], fw_tensor_id, node->get_output_debug_name(i));
            encode_tensor_name(converted_outputs[i], fw_tensor_id, {node->get_output_debug_name(i)});
        }
    };

@@ -221,32 +224,29 @@ OutputVector TranslateSession::convert_node(const NodeContext& context) {
    return make_framework_node(context);
}

void TranslateSession::encode_tensor_name(Output<Node> output, size_t tensor_idx, std::string debug_name) {
void TranslateSession::encode_tensor_name(Output<Node> output,
                                          size_t tensor_idx,
                                          std::vector<std::string> additional_names) {
    if (!output.get_names().empty()) {
        OPENVINO_DEBUG << "Tensor names already exist: " << output.get_any_name() << ". Rewriting with " << tensor_idx;
    }
    auto has_dname = !debug_name.empty();
    auto name = std::to_string(tensor_idx);
    if (has_dname && name == debug_name)
        has_dname = false;
    std::unordered_set<std::string> names;
    names.insert(name);
    if (additional_names.size() > 0) {
        names.insert(additional_names.begin(), additional_names.end());
    }

    if (m_counter_map.count(tensor_idx)) {
        auto&& pair = m_counter_map[tensor_idx];
        auto new_name = name + '_' + std::to_string(++pair.first);
        pair.second.set_names({new_name});
        pair.second = output;
        if (has_dname) {
            output.set_names({name, debug_name});
        } else {
            output.set_names({name});
        }
        output.set_names(names);
    } else {
        m_counter_map[tensor_idx] = {0, output};
        if (has_dname) {
            output.set_names({name, debug_name});
        } else {
            output.set_names({name});
        }
        output.set_names(names);
    }
}
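The rewritten encode_tensor_name collects the index-derived name plus all additional names (debug names and, for inputs, signature names) into one set before assigning them. From the user's side, the effect is that converted inputs now also answer to the forward() argument names. A hedged sketch of the observable behavior (module and shapes are illustrative):

import torch
from openvino.tools.mo import convert_model

class TwoInputs(torch.nn.Module):
    def forward(self, x, y):
        return x + y

ov_model = convert_model(TwoInputs(), example_input=(torch.zeros(1), torch.zeros(1)))
# Each input tensor now carries its tensor index, its TorchScript debug
# name and the signature name, e.g. something like {"1", "x.1", "x"}:
print(ov_model.input(0).get_names())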
@@ -35,7 +35,9 @@ public:
                                             const TensorMap& external_tensor_map = {},
                                             const std::unordered_map<size_t, PlaceDesc>& external_descriptors = {});

    void encode_tensor_name(Output<Node> tensor_desc, size_t tensor_idx, std::string debug_name = "");
    void encode_tensor_name(Output<Node> tensor_desc,
                            size_t tensor_idx,
                            std::vector<std::string> additional_names = {});
    size_t decode_tensor_name(const Output<Node>& tensor_desc);

    size_t m_friendly_name_counter = 0;
@@ -10,6 +10,7 @@ import pytest
import torch
import unittest
from openvino.runtime import PartialShape, Dimension, Model, Type
from openvino.tools.mo import InputCutInfo

from common.mo_convert_test_class import CommonMOConvertTest

@@ -27,6 +28,7 @@ class MyTorchOp(torch.autograd.Function):

def make_pt_model_one_input():
    from torch import nn

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()

@@ -44,6 +46,7 @@ def make_pt_model_one_input():

def make_pt_model_two_inputs():
    from torch import nn

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()

@@ -72,8 +75,10 @@ def make_ref_pt_model_one_input(shape, dtype=np.float32):

def make_ref_pt_model_two_inputs(shape, dtype=np.float32):
    if len(shape) == 2:
        param1 = ov.opset8.parameter(PartialShape(shape[0]), name="input_0", dtype=dtype)
        param2 = ov.opset8.parameter(PartialShape(shape[1]), name="input_1", dtype=dtype)
        param1 = ov.opset8.parameter(PartialShape(
            shape[0]), name="input_0", dtype=dtype)
        param2 = ov.opset8.parameter(PartialShape(
            shape[1]), name="input_1", dtype=dtype)
    else:
        shape = PartialShape(shape)
        param1 = ov.opset8.parameter(shape, name="input_0", dtype=dtype)

@@ -176,7 +181,8 @@ def create_pytorch_nn_module_sample_input_int32_two_inputs(tmp_dir):
    sample_input1 = torch.zeros(1, 3, 10, 10, dtype=torch.int32)
    sample_input2 = torch.zeros(1, 3, 10, 10, dtype=torch.int32)
    sample_input = sample_input1, sample_input2
    ref_model = make_ref_pt_model_two_inputs([PartialShape([-1, 3, -1, -1]), inp_shapes[1]], dtype=np.int32)
    ref_model = make_ref_pt_model_two_inputs(
        [PartialShape([-1, 3, -1, -1]), inp_shapes[1]], dtype=np.int32)

    return pt_model, ref_model, {'input_shape': inp_shapes,
                                 'example_input': sample_input, 'onnx_opset_version': 11, "use_legacy_frontend": True}

@@ -188,7 +194,8 @@ def create_pytorch_nn_module_compare_convert_paths_case1(tmp_dir):

    sample_input = torch.zeros(1, 3, 10, 10, dtype=torch.int32)
    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, sample_input, onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, sample_input,
                      onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)
    return pt_model, ref_model, {'example_input': sample_input, 'onnx_opset_version': 16, "use_legacy_frontend": True}

@@ -200,7 +207,8 @@ def create_pytorch_nn_module_compare_convert_paths_case2(tmp_dir):

    sample_input = torch.zeros(1, 3, 10, 10, dtype=torch.int32)
    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, sample_input, onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, sample_input,
                      onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)
    return pt_model, ref_model, {'example_input': sample_input,

@@ -216,7 +224,8 @@ def create_pytorch_nn_module_compare_convert_paths_case3(tmp_dir):

    sample_input = torch.zeros(1, 3, 10, 10, dtype=torch.float32)
    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, sample_input, onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, sample_input,
                      onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)
    return pt_model, ref_model, {'input_shape': [1, 3, 10, 10],

@@ -232,7 +241,8 @@ def create_pytorch_nn_module_compare_convert_paths_case4(tmp_dir):
    sample_input = (sample_input1, sample_input2)

    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, sample_input, onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, sample_input,
                      onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)

@@ -248,7 +258,8 @@ def create_pytorch_nn_module_compare_convert_paths_case5(tmp_dir):
    sample_input = tuple([sample_input1, sample_input2])

    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, sample_input, onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, sample_input,
                      onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)

@@ -266,7 +277,8 @@ def create_pytorch_nn_module_compare_convert_paths_case6(tmp_dir):
    sample_input = tuple([sample_input1, sample_input2])

    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, sample_input, onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, sample_input,
                      onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)

@@ -302,7 +314,8 @@ def create_pytorch_nn_module_sample_input_numpy(tmp_dir):

    example_inputs = np.array(torch.zeros(1, 3, 10, 10, dtype=torch.int32))
    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, torch.zeros(1, 3, 10, 10, dtype=torch.int32), onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, torch.zeros(
        1, 3, 10, 10, dtype=torch.int32), onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)
    return pt_model, ref_model, {'example_input': example_inputs,

@@ -314,9 +327,11 @@ def create_pytorch_nn_module_sample_input_dict(tmp_dir):
    from openvino.tools.mo import convert_model
    pt_model = make_pt_model_one_input()

    example_inputs = {"x": np.array(torch.zeros(1, 3, 10, 10, dtype=torch.int32))}
    example_inputs = {"x": np.array(
        torch.zeros(1, 3, 10, 10, dtype=torch.int32))}
    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, torch.zeros(1, 3, 10, 10, dtype=torch.int32), onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, torch.zeros(
        1, 3, 10, 10, dtype=torch.int32), onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)
    return pt_model, ref_model, {'example_input': example_inputs,

@@ -345,7 +360,8 @@ def create_pytorch_nn_module_sample_list_of_tensors(tmp_dir):
    example_inputs = [torch.zeros(3, 10, 10, dtype=torch.float32)]

    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, torch.unsqueeze(example_inputs[0], 0), onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, torch.unsqueeze(
        example_inputs[0], 0), onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)
    return pt_model, ref_model, {'example_input': example_inputs,

@@ -359,7 +375,8 @@ def create_pytorch_nn_module_sample_input_ov_host_tensor(tmp_dir):

    sample_input = Tensor(np.zeros([1, 3, 10, 10], dtype=np.int32))
    onnx_model_path = os.path.join(tmp_dir, 'export.onnx')
    torch.onnx.export(pt_model, torch.zeros(1, 3, 10, 10, dtype=torch.int32), onnx_model_path, opset_version=16)
    torch.onnx.export(pt_model, torch.zeros(
        1, 3, 10, 10, dtype=torch.int32), onnx_model_path, opset_version=16)

    ref_model = convert_model(onnx_model_path)
    return pt_model, ref_model, {'example_input': sample_input,

@@ -397,8 +414,10 @@ def create_pytorch_nn_module_layout_list(tmp_dir):
    ref_model.inputs[0].node.layout = Layout('nchw')
    ref_model.inputs[1].node.layout = Layout('nhwc')

    return pt_model, ref_model, {'input_shape': [shape, shape], 'layout': ['nchw', Layout('nhwc')],
                                 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {
        'input_shape': [shape, shape], 'layout': ['nchw', Layout('nhwc')],
        "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]
    }
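This is the pattern repeated through the remaining test updates: instead of routing through ONNX ('onnx_opset_version' plus 'use_legacy_frontend'), references are converted directly and the inputs are pinned with InputCutInfo, whose positional fields are name, shape, type, and value. As a standalone sketch (reusing the make_pt_model_two_inputs helper defined above):

from openvino.tools.mo import convert_model, InputCutInfo

ov_model = convert_model(
    make_pt_model_two_inputs(),
    input=[InputCutInfo("x", None, "f32", None),
           InputCutInfo("y", None, "f32", None)])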
def create_pytorch_nn_module_layout_list_case2(tmp_dir):

@@ -411,8 +430,9 @@ def create_pytorch_nn_module_layout_list_case2(tmp_dir):
    ref_model.inputs[0].node.layout = Layout('nchw')
    ref_model.inputs[1].node.layout = Layout('nhwc')

    return pt_model, ref_model, {'input_shape': [shape, shape], 'layout': ('nchw', Layout('nhwc')),
                                 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {
        'input_shape': [shape, shape], 'layout': ('nchw', Layout('nhwc')),
        "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_mean_list(tmp_dir):

@@ -433,8 +453,9 @@ def create_pytorch_nn_module_mean_list(tmp_dir):
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")

    return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]],
                                 'onnx_opset_version': 11, 'compress_to_fp16': False, "use_legacy_frontend": True}
    return pt_model, ref_model, {
        'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], 'compress_to_fp16': False,
        "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_mean_list_default_compression(tmp_dir):

@@ -447,9 +468,11 @@ def create_pytorch_nn_module_mean_list_default_compression(tmp_dir):
    param1 = ov.opset8.parameter(shape)
    param2 = ov.opset8.parameter(shape)
    const1 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float16)
    const1_decompressed = ov.opset8.convert(const1, destination_type=np.float32)
    const1_decompressed = ov.opset8.convert(
        const1, destination_type=np.float32)
    const2 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float16)
    const2_decompressed = ov.opset8.convert(const2, destination_type=np.float32)
    const2_decompressed = ov.opset8.convert(
        const2, destination_type=np.float32)
    sub1 = ov.opset8.subtract(param1, const1_decompressed)
    sub2 = ov.opset8.subtract(param2, const2_decompressed)
    add = ov.opset8.add(sub1, sub2)

@@ -459,8 +482,7 @@ def create_pytorch_nn_module_mean_list_default_compression(tmp_dir):
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")

    return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]],
                                 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]], "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_mean_list_compressin_enabled(tmp_dir):

@@ -470,12 +492,10 @@ def create_pytorch_nn_module_mean_list_compressin_enabled(tmp_dir):
    shape = PartialShape(shape)
    param1 = ov.opset8.parameter(shape)
    param2 = ov.opset8.parameter(shape)
    const1 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float16)
    const1_decompressed = ov.opset8.convert(const1, destination_type=np.float32)
    const2 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float16)
    const2_decompressed = ov.opset8.convert(const2, destination_type=np.float32)
    sub1 = ov.opset8.subtract(param1, const1_decompressed)
    sub2 = ov.opset8.subtract(param2, const2_decompressed)
    const1 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
    const2 = ov.opset8.constant([[[[0, 0, 0]]]], dtype=np.float32)
    sub1 = ov.opset8.subtract(param1, const1)
    sub2 = ov.opset8.subtract(param2, const2)
    add = ov.opset8.add(sub1, sub2)
    relu = ov.opset8.relu(add)
    sigm = ov.opset8.sigmoid(relu)

@@ -483,8 +503,9 @@ def create_pytorch_nn_module_mean_list_compressin_enabled(tmp_dir):
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")

    return pt_model, ref_model, {'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]],
                                 'onnx_opset_version': 11, 'compress_to_fp16': True, "use_legacy_frontend": True}
    return pt_model, ref_model, {
        'input_shape': [shape, shape], 'mean_values': [[0, 0, 0], [0, 0, 0]],
        'compress_to_fp16': False, "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_scale_list(tmp_dir):

@@ -505,8 +526,7 @@ def create_pytorch_nn_module_scale_list(tmp_dir):
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")

    return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]],
                                 'onnx_opset_version': 11, 'compress_to_fp16': False, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], 'compress_to_fp16': False, "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_scale_list_default_compression(tmp_dir):

@@ -519,9 +539,11 @@ def create_pytorch_nn_module_scale_list_default_compression(tmp_dir):
    param1 = ov.opset8.parameter(shape)
    param2 = ov.opset8.parameter(shape)
    const1 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float16)
    const1_decompressed = ov.opset8.convert(const1, destination_type=np.float32)
    const1_decompressed = ov.opset8.convert(
        const1, destination_type=np.float32)
    const2 = ov.opset8.constant([[[[1, 1, 1]]]], dtype=np.float16)
    const2_decompressed = ov.opset8.convert(const2, destination_type=np.float32)
    const2_decompressed = ov.opset8.convert(
        const2, destination_type=np.float32)
    sub1 = ov.opset8.multiply(param1, const1_decompressed)
    sub2 = ov.opset8.multiply(param2, const2_decompressed)
    add = ov.opset8.add(sub1, sub2)

@@ -531,8 +553,7 @@ def create_pytorch_nn_module_scale_list_default_compression(tmp_dir):
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")

    return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]],
                                 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir):

@@ -555,44 +576,47 @@ def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir):
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")

    return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]],
                                 'onnx_opset_version': 11, 'compress_to_fp16': True, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': [shape, shape], 'scale_values': [[1, 1, 1], [1, 1, 1]], "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)],
                                 'compress_to_fp16': True}


def create_pytorch_nn_module_shapes_list_static(tmp_dir):
    pt_model = make_pt_model_two_inputs()
    ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20])

    return pt_model, ref_model, {'input_shape': [[1, 3, 20, 20], [1, 3, 20, 20]], 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': [[1, 3, 20, 20], [1, 3, 20, 20]], "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_shapes_list_dynamic(tmp_dir):
    pt_model = make_pt_model_two_inputs()
    inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)], [-1, 3, 20, Dimension(-1, 20)]]
    inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)],
                  [-1, 3, 20, Dimension(-1, 20)]]

    param1 = ov.opset8.parameter(PartialShape(inp_shapes[0]), name="input_0", dtype=np.float32)
    param2 = ov.opset8.parameter(PartialShape(inp_shapes[1]), name="input_1", dtype=np.float32)
    param1 = ov.opset8.parameter(PartialShape(
        inp_shapes[0]), name="x", dtype=np.float32)
    param2 = ov.opset8.parameter(PartialShape(
        inp_shapes[1]), name="y", dtype=np.float32)
    add = ov.opset8.add(param1, param2)
    relu = ov.opset8.relu(add)
    sigm = ov.opset8.sigmoid(relu)

    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")
    return pt_model, ref_model, {'input_shape': inp_shapes, 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': inp_shapes, "input": [InputCutInfo("x", None, "f32", None), InputCutInfo("y", None, "f32", None)]}


def create_pytorch_nn_module_shapes_list_dynamic_single_input(tmp_dir):
    pt_model = make_pt_model_one_input()
    inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)]]
    ref_model = make_ref_pt_model_one_input(inp_shapes[0])
    return pt_model, ref_model, {'input_shape': inp_shapes, 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': inp_shapes, "input": InputCutInfo("x", None, "f32", None)}


def create_pytorch_nn_module_shapes_list_static_single_input(tmp_dir):
    pt_model = make_pt_model_one_input()
    inp_shapes = [[1, 3, 20, 20]]
    ref_model = make_ref_pt_model_one_input(inp_shapes[0])
    return pt_model, ref_model, {'input_shape': inp_shapes, 'onnx_opset_version': 11, "use_legacy_frontend": True}
    return pt_model, ref_model, {'input_shape': inp_shapes, "input": InputCutInfo("x", None, "f32", None)}


def create_pytorch_nn_module_convert_pytorch_frontend1(tmp_dir):
@@ -605,7 +629,10 @@ def create_pytorch_nn_module_convert_pytorch_frontend1(tmp_dir):

    parameter_list = [param]
    ref_model = Model([sigm], parameter_list, "test")
    return pt_model, ref_model, {"example_input": torch.zeros((1, 3, 10, 10))}
    return pt_model, ref_model, {
        "example_input": torch.zeros((1, 3, 10, 10)),
        "input": [InputCutInfo("x", [-1, -1, -1, -1], "f32", None)]
    }


def create_pytorch_nn_module_convert_pytorch_frontend2(tmp_dir):

@@ -620,39 +647,46 @@ def create_pytorch_nn_module_convert_pytorch_frontend2(tmp_dir):
    ref_model = Model([sigm], parameter_list, "test")
    ref_model.input(0).get_node().set_element_type(Type.i32)
    ref_model.validate_nodes_and_infer_types()
    return pt_model, ref_model, {"example_input": torch.zeros((1, 3, 10, 10), dtype=torch.int32)}
    return pt_model, ref_model, {
        "example_input": torch.zeros((1, 3, 10, 10), dtype=torch.int32),
        "input": [InputCutInfo("x", [-1, -1, -1, -1], "i32", None)]
    }


def create_pytorch_nn_module_convert_pytorch_frontend3(tmp_dir):
    pt_model = make_pt_model_two_inputs()
    shape = [-1, -1, -1, -1]
    shape = PartialShape(shape)
    param1 = ov.opset10.parameter(shape)
    param2 = ov.opset10.parameter(shape)
    param2_convert = ov.opset10.convert_like(param2, param1)
    add = ov.opset10.add(param1, param2_convert)
    param1 = ov.opset10.parameter(shape, dtype=np.float32)
    param2 = ov.opset10.parameter(shape, dtype=np.float32)
    add = ov.opset10.add(param1, param2)
    relu = ov.opset10.relu(add)
    sigm = ov.opset10.sigmoid(relu)

    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")
    return pt_model, ref_model, {"example_input": [torch.zeros((1, 3, 10, 10)), torch.ones((1, 3, 10, 10))]}
    return pt_model, ref_model, {
        "example_input": [torch.zeros((1, 3, 10, 10)), torch.ones((1, 3, 10, 10))],
        "input": [InputCutInfo("x", [-1, -1, -1, -1], "f32", None), InputCutInfo("y", [-1, -1, -1, -1], "f32", None)]
    }


def create_pytorch_nn_module_convert_pytorch_frontend4(tmp_dir):
    pt_model = make_pt_model_two_inputs()
    shape = [-1, -1, -1, -1]
    shape = PartialShape(shape)
    param1 = ov.opset10.parameter(shape)
    param2 = ov.opset10.parameter(shape)
    param2_convert = ov.opset10.convert_like(param2, param1)
    add = ov.opset10.add(param1, param2_convert)
    param1 = ov.opset10.parameter(shape, dtype=np.float32)
    param2 = ov.opset10.parameter(shape, dtype=np.float32)
    add = ov.opset10.add(param1, param2)
    relu = ov.opset10.relu(add)
    sigm = ov.opset10.sigmoid(relu)

    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")
    return pt_model, ref_model, {"example_input": {"x": torch.zeros((1, 3, 10, 10)), "y": torch.ones((1, 3, 10, 10))}}
    return pt_model, ref_model, {
        "example_input": {"x": torch.zeros((1, 3, 10, 10), dtype=torch.float32), "y": torch.ones((1, 3, 10, 10), dtype=torch.float32)},
        "input": [InputCutInfo("x", [-1, -1, -1, -1], "f32", None), InputCutInfo("y", [-1, -1, -1, -1], "f32", None)]
    }


def create_pytorch_jit_script_module_convert_pytorch_frontend(tmp_dir):

@@ -662,15 +696,16 @@ def create_pytorch_jit_script_module_convert_pytorch_frontend(tmp_dir):
    scripted_model = torch.jit.script(net)
    shape = [-1, -1, -1, -1]
    shape = PartialShape(shape)
    param1 = ov.opset10.parameter(shape)
    param2 = ov.opset10.parameter(shape)
    param2_convert = ov.opset10.convert_like(param2, param1)
    add = ov.opset10.add(param1, param2_convert)
    param1 = ov.opset10.parameter(shape, dtype=np.float32)
    param2 = ov.opset10.parameter(shape, dtype=np.float32)
    add = ov.opset10.add(param1, param2)
    relu = ov.opset10.relu(add)
    sigm = ov.opset10.sigmoid(relu)
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")
    return scripted_model, ref_model, {"example_input": {"x": torch.zeros((1, 3, 10, 10)), "y": torch.ones((1, 3, 10, 10))}}
    return scripted_model, ref_model, {
        "example_input": {"x": torch.zeros((1, 3, 10, 10)), "y": torch.ones((1, 3, 10, 10))},
        "input": [InputCutInfo("x.1", [-1, -1, -1, -1], "f32", None), InputCutInfo("y.1", [-1, -1, -1, -1], "f32", None)]}


def create_pytorch_jit_trace_module_convert_pytorch_frontend(tmp_dir):

@@ -681,15 +716,15 @@ def create_pytorch_jit_trace_module_convert_pytorch_frontend(tmp_dir):
    scripted_model = torch.jit.trace(net, example_input)
    shape = [-1, -1, -1, -1]
    shape = PartialShape(shape)
    param1 = ov.opset10.parameter(shape)
    param2 = ov.opset10.parameter(shape)
    param2_convert = ov.opset10.convert_like(param2, param1)
    add = ov.opset10.add(param1, param2_convert)
    param1 = ov.opset10.parameter(shape, dtype=np.float32)
    param2 = ov.opset10.parameter(shape, dtype=np.float32)
    add = ov.opset10.add(param1, param2)
    relu = ov.opset10.relu(add)
    sigm = ov.opset10.sigmoid(relu)
    parameter_list = [param1, param2]
    ref_model = Model([sigm], parameter_list, "test")
    return scripted_model, ref_model, {"example_input": example_input}
    return scripted_model, ref_model, {"example_input": example_input, "input": [
        InputCutInfo("x", [-1, -1, -1, -1], "f32", None), InputCutInfo("y", [-1, -1, -1, -1], "f32", None)]}


class TestMoConvertPyTorch(CommonMOConvertTest):

@@ -736,9 +771,9 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
        create_pytorch_jit_trace_module_convert_pytorch_frontend
    ]

    @pytest.mark.parametrize("create_model", test_data)
    @pytest.mark.nightly
    @pytest.mark.precommit
    @ pytest.mark.parametrize("create_model", test_data)
    @ pytest.mark.nightly
    @ pytest.mark.precommit
    def test_mo_import_from_memory(self, create_model, ie_device, precision, ir_version,
                                   temp_dir, use_new_frontend, use_old_api):
        fw_model, graph_ref, mo_params = create_model(temp_dir)

@@ -746,7 +781,8 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
        test_params = {'input_model': fw_model}
        if mo_params is not None:
            test_params.update(mo_params)
        self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False)
        self._test_by_ref_graph(temp_dir, test_params,
                                graph_ref, compare_tensor_names=False)


def create_pt_model_with_custom_op():

@@ -773,4 +809,5 @@ class ConvertONNXFallthroughTest(unittest.TestCase):

        # Check that ONNX conversion passed, so ONNX frontend raises error message of unsupported op.
        with self.assertRaisesRegex(RuntimeError, ".*OpenVINO does not support the following ONNX operations: MyTorchOp.*"):
            convert_model(pytorch_model, input_shape=[1, 2, 3], use_legacy_frontend=True)
            convert_model(pytorch_model, input_shape=[
                          1, 2, 3], use_legacy_frontend=True)
@@ -48,13 +48,17 @@ class PytorchLayerTest:
        inputs = self._prepare_input()
        with torch.no_grad():
            model.eval()
            if not kwargs.get('trace_model', False):
                model = torch.jit.script(model)
            torch_inputs = [torch.from_numpy(inp) if isinstance(
                inp, np.ndarray) else inp for inp in inputs]
            trace_model = kwargs.get('trace_model', False)
            freeze_model = kwargs.get('freeze_model', True)
            use_mo_convert = kwargs.get("use_mo_convert", True)
            if not freeze_model or not use_mo_convert:
                model, converted_model = self.convert_directly_via_frontend(
                    model, torch_inputs, trace_model, dynamic_shapes, inputs, freeze_model)
            else:
                torch_inputs = [torch.from_numpy(inp) for inp in inputs]
                model = torch.jit.trace(model, deepcopy(torch_inputs))
                if kwargs.get('freeze_model', True):
                    model = torch.jit.freeze(model)
                model, converted_model = self.convert_via_mo(
                    model, torch_inputs, trace_model, dynamic_shapes, inputs)
        graph = model.inlined_graph
        print(graph)

@@ -62,36 +66,12 @@ class PytorchLayerTest:
            kind = [kind]
        if kind is not None:
            for op in kind:
                assert self._check_kind_exist(graph, op), f"Operation {op} type doesn't exist in provided graph"

        fe_manager = FrontEndManager()
        fe = fe_manager.load_by_framework('pytorch')

        decoder = TorchScriptPythonDecoder(model)

        im = fe.load(decoder)
        om = fe.convert(im)

        torch_inps = [torch.from_numpy(inp) if isinstance(inp, np.ndarray) else inp for inp in inputs]

        params = om.get_parameters()
        # todo: support lists and dicts
        for i in range(len(inputs)):
            inp = inputs[i]
            if isinstance(inp, list):
                inputs[i] = np.array(inp)
                if inputs[i].dtype == np.int64:
                    inputs[i] = inputs[i].astype(np.int32)
                inp = inputs[i]
            assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
            params[i].set_element_type(self._type_map[inp.dtype.name])
            shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
            params[i].set_partial_shape(PartialShape(shape))
        om.validate_nodes_and_infer_types()
                assert self._check_kind_exist(
                    graph, op), f"Operation {op} type doesn't exist in provided graph"

        # OV infer:
        core = Core()
        compiled = core.compile_model(om, ie_device)
        compiled = core.compile_model(converted_model, ie_device)
        infer_res = compiled(deepcopy(inputs))

        if hasattr(self, 'skip_framework') and self.skip_framework:

@@ -99,7 +79,7 @@ class PytorchLayerTest:
            return

        # Framework infer:
        fw_res = model(*deepcopy(torch_inps))
        fw_res = model(*deepcopy(torch_inputs))

        if not isinstance(fw_res, (tuple)):
            fw_res = (fw_res,)

@@ -120,8 +100,8 @@ class PytorchLayerTest:
                    results.extend(decomposed_res)
                    continue
                results.append(res_item)
            return results

            return results

        flatten_fw_res = flattenize_list_outputs(fw_res)

        assert len(flatten_fw_res) == len(

@@ -130,7 +110,8 @@ class PytorchLayerTest:
        for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
            if not isinstance(fw_tensor, torch.Tensor):
                if np.isscalar(fw_tensor):
                    assert fw_tensor == np.array(ov_tensor).item(), f"{fw_tensor} != {np.array(ov_tensor).item()}"
                    assert fw_tensor == np.array(ov_tensor).item(
                    ), f"{fw_tensor} != {np.array(ov_tensor).item()}"
                else:
                    if isinstance(fw_tensor, list):
                        ov_tensor = ov_tensor.tolist()

@@ -169,6 +150,59 @@ class PytorchLayerTest:
    def _prepare_input(self):
        raise RuntimeError("Please provide inputs generation function")

    def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs):
        import torch
        from openvino.tools.mo import convert_model
        kwargs = {"example_input": example_input if len(
            example_input) > 1 else example_input[0], "compress_to_fp16": False}
        with torch.no_grad():
            if trace_model:
                model = torch.jit.trace(model, example_input)
            else:
                model = torch.jit.script(model)
            model = torch.jit.freeze(model)
        print(model)
        if not dynamic_shapes:
            input_shapes = [inp.shape for inp in ov_inputs]
            kwargs["input_shape"] = input_shapes
        om = convert_model(model, **kwargs)
        self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
        return model, om

    def convert_directly_via_frontend(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
        import torch

        fe_manager = FrontEndManager()
        fe = fe_manager.load_by_framework('pytorch')

        model.eval()
        with torch.no_grad():
            if trace_model:
                model = torch.jit.trace(model, example_input)
            else:
                model = torch.jit.script(model)
        decoder = TorchScriptPythonDecoder(model, freeze=freeze_model)
        im = fe.load(decoder)
        om = fe.convert(im)
        self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
        return model, om

    def _resolve_input_shape_dtype(self, om, ov_inputs, dynamic_shapes):
        params = list(om.inputs)
        for i in range(len(ov_inputs)):
            inp = ov_inputs[i]
            if isinstance(inp, list):
                ov_inputs[i] = np.array(inp)
                if ov_inputs[i].dtype == np.int64:
                    ov_inputs[i] = ov_inputs[i].astype(np.int32)
                inp = ov_inputs[i]
            assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
            params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
            shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
            params[i].get_node().set_partial_shape(PartialShape(shape))
        om.validate_nodes_and_infer_types()
        return om
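With these helpers the harness exercises two conversion routes: convert_via_mo (trace or script, freeze, then mo.convert_model) and convert_directly_via_frontend (decoder plus FrontEndManager, used when freezing is disabled or mo conversion is opted out). A sketch of how an individual layer test selects the direct route (the test class and model are illustrative, not from the diff):

class TestMyOp(PytorchLayerTest):
    def _prepare_input(self):
        return (np.zeros((1, 3), dtype=np.float32),)

    def create_model(self):
        import torch
        # model, reference net, expected op kind (None skips the kind check)
        return torch.nn.ReLU(), None, None

    def test_my_op(self, ie_device, precision, ir_version):
        # use_mo_convert=False forces convert_directly_via_frontend;
        # freeze_model=False would do the same while keeping the graph unfrozen.
        self._test(*self.create_model(), ie_device, precision, ir_version,
                   use_mo_convert=False)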
def get_params(ie_device=None, precision=None):
    """

@@ -39,7 +39,7 @@ class TestAddCMul(PytorchLayerTest):
        [np.int32, 10],
        [np.int32, 110],
        [np.float32, 2.0],
        [np.float32, 3.1],
        [np.float32, 3.123],
        [np.float32, 4.5],
        [np.float64, 41.5],
        [np.float64, 24.5],

@@ -45,4 +45,4 @@ class TestBatchNorm(PytorchLayerTest):
    @pytest.mark.precommit
    def test_batch_norm(self, weights, bias, eps, ie_device, precision, ir_version, kwargs_to_prepare_input):
        self._test(*self.create_model(weights, bias, eps),
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=False)
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=False, use_mo_convert=False)
@@ -30,7 +30,7 @@ class TestInstanceNorm(PytorchLayerTest):
                if mean_var:
                    self.mean = torch.randn(weights_shape)
                    self.var = torch.randn(weights_shape)

                self.eps = eps

            def forward(self, x):

@@ -42,15 +42,16 @@ class TestInstanceNorm(PytorchLayerTest):

    @pytest.mark.parametrize("params",
                             [
                                 {"eps": 0.0001},
                                 {'weights': True, 'eps': -0.05},
                                 {'weights': True},
                                 {'weights': True, 'bias': True},
                                 {"weights": True, 'bias': False, "mean_var": True},
                                 {"weights": True, 'bias': True, "mean_var": True},
                                 {"weights": False, 'bias': True, "mean_var": True},
                                 {"weights": False, 'bias': False, "mean_var": True},
                                 {"weights": False, 'bias': False, "mean_var": True, "eps": 1.5}
                                 {"eps": 0.0001},
                                 {'weights': True, 'eps': -0.05},
                                 {'weights': True},
                                 {'weights': True, 'bias': True},
                                 {"weights": True, 'bias': False, "mean_var": True},
                                 {"weights": True, 'bias': True, "mean_var": True},
                                 {"weights": False, 'bias': True, "mean_var": True},
                                 {"weights": False, 'bias': False, "mean_var": True},
                                 {"weights": False, 'bias': False,
                                  "mean_var": True, "eps": 1.5}
                             ])
    @pytest.mark.parametrize("kwargs_to_prepare_input", [
        {"ndim": 3},

@@ -61,4 +62,5 @@ class TestInstanceNorm(PytorchLayerTest):
    @pytest.mark.precommit
    def test_group_norm(self, params, ie_device, precision, ir_version, kwargs_to_prepare_input):
        self._test(*self.create_model(**params),
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, dynamic_shapes=not params.get("mean_var", False))
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input,
                   dynamic_shapes=not params.get("mean_var", False), use_mo_convert=False)
@@ -13,10 +13,10 @@ from pytorch_layer_test_class import PytorchLayerTest
class TestListUnpack(PytorchLayerTest):
    def _prepare_input(self):
        return (
            np.random.randn(8, 3, 512, 512),
            np.random.randn(1, 3, 224, 224),
            np.random.randn(10, 1, 8, 8),
            np.random.randn(1, 1, 1, 1),
            np.random.randn(8, 3, 512, 512).astype(np.float32),
            np.random.randn(1, 3, 224, 224).astype(np.float32),
            np.random.randn(10, 1, 8, 8).astype(np.float32),
            np.random.randn(1, 1, 1, 1).astype(np.float32),
        )

    def create_model_size_listunpack(self):

@@ -122,7 +122,7 @@ class TestListUnpack(PytorchLayerTest):
            *self.create_model_listconstruct_getitem_listunpack(idx),
            ie_device,
            precision,
            ir_version
            ir_version,
        )

class TestMeshgridListUnpack(PytorchLayerTest):

@@ -135,7 +135,7 @@ class TestPrimMax(PytorchLayerTest):
    @pytest.mark.precommit
    def test_min_max(self, case, kwargs_to_prepare_input, ie_device, precision, ir_version):
        self._test(*self.create_model(case),
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input)
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, use_mo_convert=False)

class TestPrimMin(PytorchLayerTest):
    def _prepare_input(self, first_input, second_input, dtype="float"):

@@ -199,4 +199,4 @@ class TestPrimMin(PytorchLayerTest):
    @pytest.mark.precommit
    def test_min(self, case, kwargs_to_prepare_input, ie_device, precision, ir_version):
        self._test(*self.create_model(case),
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input)
                   ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, use_mo_convert=False)

@@ -73,4 +73,4 @@ class TestRepeatInterleaveNonConstRepeats(PytorchLayerTest):
        self.repeats = input_data['repeats']
        dim = input_data['dim']
        self._test(*self.create_model_non_const_repeat(dim),
                   ie_device, precision, ir_version)
                   ie_device, precision, ir_version, dynamic_shapes=False, use_mo_convert=False)
@@ -765,10 +765,9 @@ def _convert(cli_parser: argparse.ArgumentParser, framework, args):
            args.pop("use_legacy_frontend")
            return convert_pytorch_via_onnx(args, example_inputs, cli_parser, framework, _convert)

        decoder, input_signature = get_pytorch_decoder(args['input_model'], parse_input_shapes(args), example_inputs)
        decoder = get_pytorch_decoder(args['input_model'], parse_input_shapes(args), example_inputs)
        args['input_model'] = decoder
        args["framework"] = "pytorch"
        args["input_signature"] = input_signature

    argv = pack_params_to_args_namespace(args, cli_parser)
@@ -187,11 +187,12 @@ def fe_input_user_data_repack(
    """
    _input_shapes = []
    _input_names = []
    model_inputs = input_model.get_inputs()

    if isinstance(input_user_shapes, list) and len(input_user_shapes) > 1 and isinstance(input_user_shapes[0],
                                                                                         PartialShape):
        for shape in input_user_shapes:
            assert isinstance(shape, PartialShape), "Got incorrect format of input shapes."
        model_inputs = input_model.get_inputs()
        assert len(model_inputs) == len(input_user_shapes)
        for idx, model_input in enumerate(model_inputs):
            _input_shapes.append({"node": model_input, "shape": input_user_shapes[idx]})

@@ -234,7 +235,6 @@ def fe_input_user_data_repack(
        # for example, --input_shape [3] --freeze_placeholder_with_value "is_training->False"
        # means the model has two inputs: one is is_training to be frozen, the other to re-write the shape
        # NOTE: the logic relies on parameters with the single name
        model_inputs = input_model.get_inputs()
        frozen_names = freeze_placeholder.keys()
        assert len(model_inputs) == len(frozen_names) + 1, \
            "Please check the conversion command-line. Total number of model inputs ({} detected) " \

@@ -259,7 +259,7 @@ def fe_input_user_data_repack(
        # and they should not be changed and their properties (shape and type) should not be over-written
        # NOTE: the logic relies on parameters with the single name
        assert input_user_shapes is None
        for node in input_model.get_inputs():
        for node in model_inputs:
            assert len(node.get_names()) > 0, "Original model inputs must have tensor names."
            input_name = node.get_names()[0]
            _input_shapes.append(
@@ -19,7 +19,6 @@ from openvino.tools.mo.moc_frontend.analysis import json_model_analysis_dump
from openvino.tools.mo.moc_frontend.extractor import fe_user_data_repack
from openvino.tools.mo.utils.class_registration import get_enabled_and_disabled_transforms
from openvino.tools.mo.utils.error import Error
from openvino.tools.mo.moc_frontend.pytorch_frontend_utils import pytorch_process_after_convert


def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):

@@ -75,9 +74,10 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
        # a model is not processed further in json analysis mode
        sys.exit(0)

    model_inputs = input_model.get_inputs()
    inputs_equal = True
    if user_shapes:
        inputs_equal = check_places_are_same(input_model.get_inputs(), user_shapes)
        inputs_equal = check_places_are_same(model_inputs, user_shapes)

    outputs_equal = True
    if outputs:

@@ -196,7 +196,7 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
    # Set batch size
    if argv.batch is not None and argv.batch > 0:
        log.debug('Setting batch size to {}'.format(argv.batch))
        for place in input_model.get_inputs():
        for place in model_inputs:
            old_partial_shape = input_model.get_partial_shape(place)
            old_shape_array = shape_to_array(old_partial_shape) if old_partial_shape.rank.is_static else []
            joined_name = ' '.join(place.get_names())

@@ -214,8 +214,4 @@ def moc_pipeline(argv: argparse.Namespace, moc_front_end: FrontEnd):
            input_model.set_partial_shape(place, new_partial_shape)

    ngraph_function = moc_front_end.convert(input_model)

    # TO DO: remove as part of PyTorch frontend productization CVS-103615
    if argv.framework == "pytorch":
        pytorch_process_after_convert(argv, ngraph_function)
    return ngraph_function
@@ -6,7 +6,7 @@ import logging as log
import numpy as np
from openvino.tools.mo.moc_frontend.shape_utils import get_static_shape, get_dynamic_dims, parse_input_shapes
from openvino.tools.mo.utils.error import Error
from openvino.runtime import PartialShape, Tensor
from openvino.runtime import Tensor

def get_onnx_temp_filename(output_dir):
    output_dir = output_dir if output_dir is not None else os.getcwd()

@@ -22,35 +22,15 @@ def remove_tmp_onnx_model(out_dir):


def get_pytorch_decoder(model, input_shape, example_inputs):
    import torch
    import inspect
    try:
        from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
    except Exception as e:
        log.error("PyTorch frontend loading failed")
        raise e
    inputs = prepare_torch_inputs(example_inputs, input_shape, allow_none=True)
    model.eval()
    input_signature = None
    if isinstance(model, torch.nn.Module) and not isinstance(model, torch.jit._trace.TopLevelTracedModule):
        input_signature = list(inspect.signature(model.forward).parameters.keys())
        try:
            scripted = torch.jit.script(model)
        except Exception as scripting_err:
            if example_inputs is not None:
                try:
                    scripted = torch.jit.trace(model, inputs)
                except Exception as tracing_e:
                    log.error('Both tracing and scripting failed')
                    raise tracing_e
            else:
                log.error("Model scripting failed")
                raise scripting_err
    else:
        scripted = model
    f_model = torch.jit.freeze(scripted)
    decoder = TorchScriptPythonDecoder(f_model)
    return decoder, input_signature
    decoder = TorchScriptPythonDecoder(model, example_input=inputs)

    return decoder
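get_pytorch_decoder is now a thin wrapper: input preparation stays here, while scripting, the tracing fallback and freezing are delegated to TorchScriptPythonDecoder through example_input, and the separate input-signature return value disappears. End to end, conversion reduces to the following (the Linear model and shapes are illustrative):

import torch
from openvino.tools.mo import convert_model

model = torch.nn.Linear(8, 4)
# convert_model builds the TorchScriptPythonDecoder internally and hands
# example_input through to the scripting/tracing logic shown earlier.
ov_model = convert_model(model, example_input=torch.zeros(1, 8))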
def to_torch_tensor(tensor):

@@ -63,6 +43,8 @@ def to_torch_tensor(tensor):
        return torch.tensor(tensor)
    if isinstance(tensor, Tensor):
        return torch.tensor(tensor.data)
    if isinstance(tensor, (float, int, bool)):
        return tensor
    else:
        raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                    "Got {}".format(type(tensor)))

@@ -148,7 +130,6 @@ def convert_pytorch_via_onnx(args, example_inputs, cli_parser, framework, main_c
    args['example_input'] = None
    args['onnx_opset_version'] = None
    try:

        model_onnx = convert_pytorch_to_onnx(args['input_model'],
                                             parse_input_shapes(args),
                                             opset_version,
@@ -163,53 +144,3 @@ def convert_pytorch_via_onnx(args, example_inputs, cli_parser, framework, main_c
    finally:
        remove_tmp_onnx_model(out_dir)
    return ov_model, argv


def pytorch_process_after_convert(argv, ov_model):
    import torch
    from openvino.frontend.pytorch.decoder import pt_to_ov_type_map

    def add_tensor_name(input_desc, input_name):
        tensor = input_desc.get_tensor()
        input_names = tensor.names
        input_names.update(input_name)
        tensor.set_names(input_names)

    example_inputs = getattr(argv, "example_input", None)
    input_signature = getattr(argv, "input_signature", None)
    provide_shapes = argv.input_shape is not None
    if example_inputs is not None:
        inputs = [example_inputs] if isinstance(example_inputs, torch.Tensor) else example_inputs
        if input_signature is not None and isinstance(inputs, dict):
            ordered_inputs = []
            upd_sign = []
            for key in input_signature:
                if key not in inputs:
                    continue
                ordered_inputs.append(inputs[key])
                upd_sign.append(key)
            inputs = ordered_inputs
            input_signature = upd_sign
        for idx, input_tensor in enumerate(ov_model.inputs):
            if isinstance(inputs, (list, tuple)):
                input_data = inputs[idx]
            else:
                input_data = list(inputs.values())[idx]
            pt_dtype = input_data.dtype if isinstance(input_data, torch.Tensor) else type(input_data)
            dtype = pt_to_ov_type_map.get(str(pt_dtype))
            if dtype is None:
                raise f"Unknown input dtype {pt_dtype}"

            input_tensor.get_node().set_element_type(dtype)
            if input_signature is not None:
                add_tensor_name(input_tensor, input_signature[idx])
            if not provide_shapes:
                # prevent dynamic rank issue
                shape = [-1] * len(input_data.shape)
                input_tensor.get_node().set_partial_shape(PartialShape(shape))

        ov_model.validate_nodes_and_infer_types()
    elif input_signature is not None:
        for idx, input_tensor in enumerate(ov_model.inputs):
            add_tensor_name(input_tensor, input_signature[idx])
    return ov_model
@@ -707,10 +707,6 @@ mo_convert_params = {
                                        'For PyTorch it can be torch.Tensor.', '', '', None),
    'onnx_opset_version': ParamDescription('Version of ONNX opset that is used for converting from PyTorch to ONNX.',
                                           '', '', None),
    'input_signature': ParamDescription('PyTorch model forward method input signature, '
                                        'will be detected automatically for torch.nn.Module based model instances, '
                                        'for scripted models it may require setting manually. Example of usage: for forward method defined as'
                                        ' def forward(self, x, y), it will be ["x", "y"]', '', '', None)
    }
}
@@ -2010,7 +2006,7 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
    res = {}
    # collect input names
    if argv_input:
        inputs = argv_input.split(',')
        inputs = [get_node_name_with_port_from_input_value(input_value) for input_value in split_inputs(argv_input)]
    else:
        inputs = []
    if type(mean_values) is dict: