Remove None at outputs of the model, improve types handling in frontend (#15258)
* Remove None at outputs of the model, improve types handling in frontend
* Fix Python code style
* Add torch dependency in pybind tests
* Fix tests if frontend is disabled and add backward type conversion
* Move decoder tests to layer tests
* Fix code style
* Add comment
* Move tests to separate folder
* Update .ci/azure/linux.yml
This commit is contained in:
parent 8e073819c3
commit 994b227b86
@@ -525,6 +525,11 @@ jobs:
           $(RUN_PREFIX) python3 -m pytest $(LAYER_TESTS_DIR)/mo_python_api_tests/test_mo_convert_pytorch.py --ir_version=11 --junitxml=./TEST-test_mo_convert_pytorch.xml
         displayName: 'MO Python API Tests - Import PyTorch model from memory'

+      - script: |
+          python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt
+          $(RUN_PREFIX) python3 -m pytest $(LAYER_TESTS_DIR)/py_frontend_tests --junitxml=./TEST-test_py_frontend.xml
+        displayName: 'Python Frontend tests'
+
       - task: PublishTestResults@2
         condition: always()
         inputs:
@@ -37,4 +37,4 @@ tox
 types-pkg_resources
 wheel>=0.38.1
 protobuf~=3.18.1
-numpy>=1.16.6,<=1.23.4
+numpy>=1.16.6,<=1.23.4
@@ -109,22 +109,22 @@ class TorchScriptPythonDecoder (Decoder):
     def inputs(self):
         return [x.unique() for x in self.graph_element.inputs()]

-    def get_input(self, index):
+    def get_input(self, index: int):
         return self.inputs()[index]

-    def get_input_shape(self, index):
+    def get_input_shape(self, index: int):
         raw_input = self._raw_input(index)
         return self.get_shape_for_value(raw_input)

-    def get_input_type(self, index):
+    def get_input_type(self, index: int):
         raw_input = self._raw_input(index)
         return self.get_type_for_value(raw_input)

-    def get_output_shape(self, index):
+    def get_output_shape(self, index: int):
         output = self._raw_output(index)
         return self.get_shape_for_value(output)

-    def get_output_type(self, index):
+    def get_output_type(self, index: int):
         output = self._raw_output(index)
         return self.get_type_for_value(output)
@@ -136,12 +136,16 @@ class TorchScriptPythonDecoder (Decoder):
         # TODO: Don't use str, use native types
         if str(pt_type) in pt_to_ov_type_map:
             return OVAny(pt_to_ov_type_map[str(pt_type)])
-        elif pt_type.__class__ is torch.TensorType:
+        elif isinstance(pt_type, torch.TensorType):
             # Tensor type, parse element type
             return OVAny(DecoderType.Tensor(self._get_known_type_for_value(pt_type.dtype())))
-        elif pt_type.__class__ is torch.ListType:
+        elif isinstance(pt_type, torch.ListType):
             element_type = pt_type.getElementType()
             return OVAny(DecoderType.List(self._get_known_type_for_value(element_type)))
+        elif isinstance(pt_type, torch.StringType):
+            return OVAny(DecoderType.Str())
+        elif isinstance(pt_type, torch.NoneType):
+            return OVAny(DecoderType.PyNone())
         else:
             # Not yet recognized
             return OVAny(OVType.dynamic)
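A note on the isinstance switch: besides being idiomatic, isinstance also accepts subclasses of a JIT type, which the identity check on __class__ never can. A minimal sketch using the torch JIT type singletons (the .get() constructors are assumed from the same public torch.* type names the decoder already relies on):

    import torch

    str_type = torch.StringType.get()   # singleton JIT type for str values
    none_type = torch.NoneType.get()    # singleton JIT type for None values
    assert isinstance(str_type, torch.StringType)
    assert isinstance(none_type, torch.NoneType)
    assert not isinstance(none_type, torch.StringType)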
@@ -161,7 +165,7 @@ class TorchScriptPythonDecoder (Decoder):
         full_type = self._get_known_type_for_value(value.type())
         return full_type

-    def get_input_transpose_order(self, index):
+    def get_input_transpose_order(self, index: int) -> list:
         raw_input = self._raw_input(index)
         if raw_input.type() is not None and raw_input.type().kind() == "TensorType":
             strides = raw_input.type().strides()
@@ -169,7 +173,7 @@ class TorchScriptPythonDecoder (Decoder):
             return [s[0] for s in sorted(enumerate(strides), key=lambda x: x[1], reverse=True)]
         return []

-    def get_output_transpose_order(self, index):
+    def get_output_transpose_order(self, index: int) -> list:
         output = self._raw_output(index)
         if output.type() is not None and output.type().kind() == "TensorType":
             strides = output.type().strides()
@@ -177,7 +181,7 @@ class TorchScriptPythonDecoder (Decoder):
             return [s[0] for s in sorted(enumerate(strides), key=lambda x: x[1], reverse=True)]
         return []

-    def get_subgraph_size(self):
+    def get_subgraph_size(self) -> int:
         return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1

     def visit_subgraph(self, node_visitor):
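The transpose-order helpers recover a dimension permutation from tensor strides. The one-liner above, as a standalone worked example (the stride values are illustrative):

    # sort dimension indices by stride, largest first: this yields the order
    # in which the dimensions are actually laid out in memory
    strides = [1, 12, 4]  # hypothetical strides of a permuted 3-D tensor
    order = [s[0] for s in sorted(enumerate(strides), key=lambda x: x[1], reverse=True)]
    assert order == [1, 2, 0]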
@@ -238,16 +242,16 @@ class TorchScriptPythonDecoder (Decoder):
             return None
         pt_value = self._raw_output(0)

-        pt_type_class = pt_value.type().__class__
-        if pt_type_class is torch.TensorType:
+        pt_type = pt_value.type()
+        if isinstance(pt_type, torch.TensorType):
             return self.as_constant_tensor(pt_value)
-        if pt_type_class is torch.ListType:
+        if isinstance(pt_type, torch.ListType):
             return self.as_constant_list(pt_value)
-        if str(pt_value.type()) in ["torch.int32", "int"]:
+        if str(pt_type) in ["torch.int32", "int"]:
             return op.Constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
-        if str(pt_value.type()) in ["torch.float", "torch.FloatType", "float"]:
+        if str(pt_type) in ["torch.float", "torch.FloatType", "float"]:
             return op.Constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
-        if str(pt_value.type()) in ["torch.bool", "bool"]:
+        if str(pt_type) in ["torch.bool", "bool"]:
             return op.Constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()

         return None
@@ -30,4 +30,6 @@ void regclass_frontend_pytorch_decoder(py::module m) {
         def(py::init<Any>());
     py::class_<type::Str>(type_module, "Str").
         def(py::init<>());
+    py::class_<type::PyNone>(type_module, "PyNone").
+        def(py::init<>());
 }
@@ -124,6 +124,15 @@ py::object from_ov_any(const ov::Any& any) {
         std::stringstream uuid_stream;
         uuid_stream << any.as<ov::device::UUID>();
         return py::cast(uuid_stream.str());
+        // Custom FrontEnd Types
+    } else if (any.is<ov::frontend::type::List>()) {
+        return py::cast(any.as<ov::frontend::type::List>());
+    } else if (any.is<ov::frontend::type::Tensor>()) {
+        return py::cast(any.as<ov::frontend::type::Tensor>());
+    } else if (any.is<ov::frontend::type::Str>()) {
+        return py::cast(any.as<ov::frontend::type::Str>());
+    } else if (any.is<ov::frontend::type::PyNone>()) {
+        return py::cast(any.as<ov::frontend::type::PyNone>());
     } else {
         PyErr_SetString(PyExc_TypeError, "Failed to convert parameter to Python representation!");
         return py::cast<py::object>((PyObject*)NULL);
@@ -242,6 +251,10 @@ ov::Any py_object_to_any(const py::object& py_obj) {
         return py::cast<ov::frontend::type::Tensor>(py_obj);
     } else if (py::isinstance<ov::frontend::type::List>(py_obj)) {
         return py::cast<ov::frontend::type::List>(py_obj);
+    } else if (py::isinstance<ov::frontend::type::Str>(py_obj)) {
+        return py::cast<ov::frontend::type::Str>(py_obj);
+    } else if (py::isinstance<ov::frontend::type::PyNone>(py_obj)) {
+        return py::cast<ov::frontend::type::PyNone>(py_obj);
         // If there is no match fallback to py::object
     } else if (py::isinstance<py::object>(py_obj)) {
         return py_obj;
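With both pybind directions in place, the frontend marker types survive a round trip through OVAny. A small sketch (module paths assumed to match the ones used by the tests below):

    from openvino.runtime import OVAny
    from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType

    # py_object_to_any wraps the marker type; from_ov_any unwraps it again
    any_str = OVAny(DecoderType.Str())
    assert isinstance(any_str.value, DecoderType.Str)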
@@ -33,6 +33,8 @@ struct List {

 struct Str {};

+struct PyNone {};
+
 struct Optional;
 struct Dict;
 struct NamedTuple;
@@ -55,15 +55,17 @@ public:
         return res;
     }

     Any get_input_type(size_t index) const {
         return m_decoder->get_input_type(index);
     }

+    bool input_is_none(size_t index) const {
+        return m_decoder->input_is_none(index);
+    }
+
     // Convert the resulting value of this node to ov Constant; works correctly only for nodes that produce
     // constant value, naturally for prim::Constant
-    OutputVector as_constant() const {
-        return m_decoder->as_constant();
-    }
+    OutputVector as_constant() const;

     /*
     TODO: Should be uncommented when explicit NodeContext ctor won't require passing op_type
@@ -14,6 +14,28 @@ namespace ov {
 namespace frontend {
 namespace pytorch {

+OutputVector NodeContext::as_constant() const {
+    auto dtype = m_decoder->get_output_type(0);
+    if (dtype.is<type::Str>()) {
+        // Cannot represent string as Constant, creating FrameworkNode
+        auto str = m_decoder->as_string();
+        auto fw_node = std::make_shared<PtFrameworkNode>(m_decoder, OutputVector{});
+        auto attrs = fw_node->get_attrs();
+        attrs["string_value"] = str;
+        fw_node->set_attrs(attrs);
+        return {fw_node};
+    } else if (dtype.is<type::PyNone>()) {
+        // Cannot represent None as Constant, creating FrameworkNode
+        auto fw_node = std::make_shared<PtFrameworkNode>(m_decoder, OutputVector{});
+        auto attrs = fw_node->get_attrs();
+        attrs["none_value"] = "";
+        fw_node->set_attrs(attrs);
+        return {fw_node};
+    } else {
+        return m_decoder->as_constant();
+    }
+}
+
 Output<Node> NodeContext::get_tensor_from_model_or_create_input(size_t index) {
     if (m_tensor_map->find(index) != m_tensor_map->end()) {
         return m_tensor_map->at(index);
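For context, this is where a PyNone constant typically originates on the PyTorch side; a hedged sketch (TorchScript captures the None argument as a prim::Constant with NoneType):

    import torch

    @torch.jit.script
    def scripted_div(x: torch.Tensor, y: torch.Tensor):
        # rounding_mode=None is embedded in the graph as a NoneType constant
        return torch.div(x, y, rounding_mode=None)

    none_consts = [n for n in scripted_div.graph.nodes()
                   if n.kind() == "prim::Constant" and isinstance(n.output().type(), torch.NoneType)]
    assert len(none_consts) >= 1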
@@ -3,8 +3,6 @@
 //

 #include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/opsets/opset10.hpp"
-#include "utils.hpp"

 namespace ov {
 namespace frontend {
@@ -16,10 +16,11 @@ OutputVector translate_convnd(NodeContext& context) {
     // In torch pads at beginning are same as at end
     auto pads = CoordinateDiff(strides.size(), 0);
     auto pad_type = ov::op::PadType::EXPLICIT;
-    try {
+    auto dtype = context.get_input_type(4);
+    if (dtype.is<type::Str>()) {
         auto pad_mode = context.const_input<std::string>(4);
         pad_type = convert_pad(pad_mode);
-    } catch (ov::frontend::GeneralFailure) {
+    } else {
         pads = context.const_input<CoordinateDiff>(4);
     }
     auto dilations = context.const_input<Strides>(5);
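The type check replaces a try/catch because the same torch argument can legally hold either an explicit padding or a string mode; a short reminder of that overload on the PyTorch side:

    import torch

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(8, 3, 3, 3)
    y1 = torch.nn.functional.conv2d(x, w, padding=1)       # explicit pads -> CoordinateDiff branch
    y2 = torch.nn.functional.conv2d(x, w, padding="same")  # string mode -> convert_pad branch
    assert y2.shape[-2:] == x.shape[-2:]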
@@ -27,9 +27,18 @@ bool DecomposeTupleResults::run_on_model(const std::shared_ptr<Model>& model) {
         if (!tuple_construct) {
             continue;
         }
-        auto inputs = input_node->inputs();
-        for (auto input : inputs) {
-            model->add_results({std::make_shared<opset10::Result>(input.get_source_output())});
+        for (const auto& input : input_node->inputs()) {
+            const auto& out = input.get_source_output();
+            if (const auto& fw_node = cast_fw_node(out.get_node_shared_ptr(), "prim::Constant")) {
+                const auto& attrs = fw_node->get_attrs();
+                if (attrs.find("none_value") != attrs.end()) {
+                    // Skip None if it reaches a model output. For example, it can be an embedded loss
+                    // calculation that is only used during training: once the model is switched to eval
+                    // mode and no annotation is provided, the loss is not computed and None is returned.
+                    continue;
+                }
+            }
+            model->add_results({std::make_shared<opset10::Result>(out)});
         }

         model->remove_result(result);
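The scenario this transform guards against, sketched as a hypothetical PyTorch module whose auxiliary loss only exists in training mode:

    import torch

    class WithAuxLoss(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            loss = None  # in eval mode with no targets, the loss is never computed
            if self.training:
                loss = (x * x).mean()
            return x + 1, loss  # run in eval mode, the second output is None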
tests/layer_tests/py_frontend_tests/test_torch_decoder.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+
+
+class AtenDiv(torch.nn.Module):
+    # aten::div can have str or NoneType constant
+    def __init__(self, rounding_mode):
+        super(AtenDiv, self).__init__()
+        self.rounding_mode = rounding_mode
+
+    def forward(self, input_tensor, other_tensor):
+        return torch.div(input_tensor, other_tensor, rounding_mode=self.rounding_mode)
+
+
+def get_scripted_model(model):
+    with torch.no_grad():
+        model = torch.jit.script(model)
+        model.eval()
+        model = torch.jit.freeze(model)
+        return model
+
+
+@pytest.mark.precommit
+def test_pytorch_decoder_get_output_type_str():
+    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
+    from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
+
+    model = get_scripted_model(AtenDiv("trunc"))
+    consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
+              "prim::Constant"]
+    # div model has exactly 1 constant
+    assert len(consts) > 0
+    str_const = consts[0]
+    assert isinstance(list(str_const.outputs())[0].type(), torch.StringType)
+    nc_decoder = TorchScriptPythonDecoder(model, str_const)
+    assert isinstance(nc_decoder.get_output_type(0).value, DecoderType.Str)
+
+
+@pytest.mark.precommit
+def test_pytorch_decoder_get_output_type_none():
+    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
+    from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
+
+    model = get_scripted_model(AtenDiv(None))
+    consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
+              "prim::Constant"]
+    # div model has exactly 1 constant
+    assert len(consts) > 0
+    none_const = consts[0]
+    assert isinstance(list(none_const.outputs())[0].type(), torch.NoneType)
+    nc_decoder = TorchScriptPythonDecoder(model, none_const)
+    assert isinstance(nc_decoder.get_output_type(0).value, DecoderType.PyNone)
+
+
+@pytest.mark.precommit
+def test_pytorch_decoder_get_input_type_str():
+    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
+    from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
+
+    model = get_scripted_model(AtenDiv("trunc"))
+    divs = [n for n in model.inlined_graph.nodes() if n.kind() == "aten::div"]
+    assert len(divs) > 0
+    div_node = divs[0]
+    assert isinstance(list(div_node.inputs())[2].type(), torch.StringType)
+    nc_decoder = TorchScriptPythonDecoder(model, div_node)
+    assert isinstance(nc_decoder.get_input_type(2).value, DecoderType.Str)
+
+
+@pytest.mark.precommit
+def test_pytorch_decoder_get_input_type_none():
+    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
+    from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
+
+    model = get_scripted_model(AtenDiv(None))
+    divs = [n for n in model.inlined_graph.nodes() if n.kind() == "aten::div"]
+    assert len(divs) > 0
+    div_node = divs[0]
+    assert isinstance(list(div_node.inputs())[2].type(), torch.NoneType)
+    nc_decoder = TorchScriptPythonDecoder(model, div_node)
+    assert isinstance(nc_decoder.get_input_type(2).value, DecoderType.PyNone)
@@ -94,10 +94,18 @@ class PytorchLayerTest:
             fw_res = (fw_res,)

         output_list = list(infer_res.values())
-        assert len(fw_res) == len(
-            output_list), f'number of outputs not equal, {len(fw_res)} != {len(output_list)}'
+
+        flatten_fw_res = []
+        for res_item in fw_res:
+            # if None is at output we skip it
+            if res_item is None:
+                continue
+            flatten_fw_res.append(res_item)
+
+        assert len(flatten_fw_res) == len(
+            output_list), f'number of outputs not equal, {len(flatten_fw_res)} != {len(output_list)}'
         # check if results dtypes match
-        for fw_tensor, ov_tensor in zip(fw_res, output_list):
+        for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
             if not isinstance(fw_tensor, torch.Tensor):
                 if np.isscalar(fw_tensor):
                     assert fw_tensor == np.array(ov_tensor).item()
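The new filtering step, reduced to its essence:

    # drop None entries from the framework results before comparing with OV outputs
    fw_res = (1.0, None, 2.0, None)
    flatten_fw_res = [r for r in fw_res if r is not None]
    assert flatten_fw_res == [1.0, 2.0]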
@@ -119,8 +127,8 @@ class PytorchLayerTest:
         fw_eps = custom_eps if precision == 'FP32' else 5e-2
         is_ok = True
         for i in range(len(infer_res)):
-            cur_fw_res = fw_res[i].to(memory_format=torch.contiguous_format).numpy(
-            ) if isinstance(fw_res[i], torch.Tensor) else fw_res[i]
+            cur_fw_res = flatten_fw_res[i].to(memory_format=torch.contiguous_format).numpy(
+            ) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
             cur_ov_res = infer_res[compiled.output(i)]
             print(f"fw_re: {cur_fw_res};\n ov_res: {cur_ov_res}")
             if not np.allclose(cur_ov_res, cur_fw_res,
@@ -35,7 +35,7 @@ class TestLog(PytorchLayerTest):
         return aten_log(op_fn), ref_net, f"aten::{op}"

     @pytest.mark.nightly
-    @pytest.mark.precomit
+    @pytest.mark.precommit
     @pytest.mark.parametrize(("op", "input_dtype"),
                              [["log", "float32"],
                               ["log", "int32"],
tests/layer_tests/pytorch_tests/test_tuple_construct.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestTupleConstruct(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.uniform(0, 50, (1, 10)).astype(np.float32),)
+
+    def create_model(self, case):
+        import torch
+
+        class prim_tuple_construct_single_value(torch.nn.Module):
+
+            def forward(self, x):
+                return (x,)
+
+        class prim_tuple_construct(torch.nn.Module):
+
+            def forward(self, x):
+                return (x, x + x)
+
+        class prim_tuple_construct_with_none(torch.nn.Module):
+
+            def forward(self, x):
+                return (x, None, x + x, None)
+
+        cases = {
+            "single": prim_tuple_construct_single_value,
+            "multiple": prim_tuple_construct,
+            "none": prim_tuple_construct_with_none
+        }
+
+        ref_net = None
+        model = cases[case]
+
+        return model(), ref_net, "prim::TupleConstruct"
+
+    @pytest.mark.parametrize("case", ["single", "multiple", "none"])
+    @pytest.mark.nightly
+    def test_tuple_construct(self, case, ie_device, precision, ir_version):
+        self._test(*self.create_model(case), ie_device, precision, ir_version)
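End to end, the change means None outputs simply disappear from the converted model. A hedged sketch (it assumes the MO Python API accepts torch modules with example_input, as exercised by the test_mo_convert_pytorch.py CI step above; WithNones is a hypothetical module):

    import torch
    from openvino.tools.mo import convert_model

    class WithNones(torch.nn.Module):
        def forward(self, x):
            return x, None, x + x, None

    # DecomposeTupleResults should skip the two None entries,
    # leaving exactly two model outputs
    ov_model = convert_model(WithNones(), example_input=torch.zeros(1, 10))
    assert len(ov_model.outputs) == 2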