[PT FE] Support dict on input (#21450)

* [PT FE] Support dict on input

* Check that keys are strings

* Revert changes in PtFrameworkNode
Maxim Vafin 2023-12-12 18:58:04 +01:00 committed by GitHub
parent 3e42ddbde5
commit 182b44f6cc
8 changed files with 137 additions and 23 deletions
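In short, this change lets the PyTorch frontend accept a Python dict as a model input, with each string key becoming a separate model input. A minimal usage sketch (the module and tensor names are illustrative, not taken from the commit; it roughly mirrors the new TestDictParam test further down):

from typing import Dict
import torch
from openvino import convert_model

class DictInput(torch.nn.Module):
    # Illustrative module: indexes its dict argument with string keys.
    def forward(self, x: Dict[str, torch.Tensor]):
        return x["x1"] + x["x2"]

# The dict is wrapped in a tuple so it is passed as a single positional argument,
# the same way the new test passes its example inputs.
example = ({"x1": torch.zeros(2, 5, 3, 4), "x2": torch.ones(2, 5, 3, 4)},)
ov_model = convert_model(DictInput(), example_input=example)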

View File

@@ -200,7 +200,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
manager.register_pass<ov::frontend::pytorch::pass::DictParameterResolver>();
manager.register_pass<ov::frontend::pytorch::pass::DictResultResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();

View File

@@ -19,6 +19,9 @@ using namespace ov::op;
OutputVector translate_getitem(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto input = context.get_input(0);
const auto idx_type = context.get_input_type(1);
FRONT_END_OP_CONVERSION_CHECK(!idx_type.is<type::Str>(),
"String index in aten::__getitem__ means dict input, this is not supported.");
if (ov::as_type_ptr<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
FRONT_END_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"),
"special case for aten::__getitem__");

View File

@@ -5,6 +5,7 @@
#include "dict_resolver.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -18,7 +19,56 @@ namespace pass {
using namespace ov::pass;
using namespace ov::op;
bool DictResolver::run_on_model(const std::shared_ptr<Model>& model) {
bool DictParameterResolver::run_on_model(const std::shared_ptr<Model>& model) {
bool changed = false;
const auto parameters = model->get_parameters();
ParameterVector new_params;
for (const auto& p : parameters) {
bool at_least_one_unused = false;
if (p->get_output_size() == 1) {
const auto targets = p->get_output_target_inputs(0);
for (const auto inp : targets) {
const auto getitem_node = cast_fw_node(inp.get_node()->shared_from_this(), "aten::__getitem__");
if (getitem_node) {
const auto index_node = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(
getitem_node->get_input_node_shared_ptr(1));
if (!index_node) {
at_least_one_unused = true;
continue;
}
const auto attrs = index_node->get_attrs();
if (attrs.find("string_value") == attrs.end()) {
// index node must contain string value
at_least_one_unused = true;
continue;
}
const auto name = attrs.at("string_value");
auto new_param = std::make_shared<v0::Parameter>(getitem_node->get_output_element_type(0),
getitem_node->get_output_partial_shape(0));
new_param->set_friendly_name(name);
getitem_node->output(0).replace(new_param);
new_params.push_back(new_param);
changed = true;
} else {
at_least_one_unused = true;
}
}
}
if (changed) {
model->remove_parameter(p);
if (at_least_one_unused || p->get_output_size() != 1) {
new_params.push_back(p);
}
}
}
if (changed) {
model->add_parameters(new_params);
}
return changed;
};
bool DictResultResolver::run_on_model(const std::shared_ptr<Model>& model) {
bool changed = false;
const auto results = model->get_results();
for (const auto& res : results) {

View File

@@ -12,9 +12,16 @@ namespace frontend {
namespace pytorch {
namespace pass {
class DictResolver : public ov::pass::ModelPass {
// This transformation replaces pattern Parameter(Dict)->aten::__getitem__
class DictParameterResolver : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::DictResolver");
OPENVINO_RTTI("ov::frontend::pytorch::pass::DictParameterResolver");
bool run_on_model(const std::shared_ptr<Model>& model) override;
};
// This transformation replaces pattern prim::DictConstruct->Result
class DictResultResolver : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::DictResultResolver");
bool run_on_model(const std::shared_ptr<Model>& model) override;
};
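To make the intent of the two passes concrete: DictParameterResolver takes a Parameter whose value is consumed by string-keyed aten::__getitem__ nodes and replaces each such access with a new Parameter whose friendly name is the key string, while DictResultResolver does the analogous rewrite for prim::DictConstruct nodes feeding a Result. A hedged Python sketch of the observable effect on the parameter side (the module, key names and expected output are assumptions based on the code above, not guaranteed by the commit):

from typing import Dict
import torch
from openvino import convert_model

class AddFromDict(torch.nn.Module):
    # Illustrative module whose dict input is only read via string-keyed __getitem__ calls.
    def forward(self, x: Dict[str, torch.Tensor]):
        return x["left"] + x["right"]

example = ({"left": torch.zeros(2, 3), "right": torch.ones(2, 3)},)
ov_model = convert_model(AddFromDict(), example_input=example)
# After DictParameterResolver the dict-typed Parameter should be gone; each key
# should now be its own Parameter carrying the key as its friendly name.
print([p.get_friendly_name() for p in ov_model.get_parameters()])
# expected to contain 'left' and 'right' (order not guaranteed)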

View File

@@ -53,6 +53,8 @@ class PytorchLayerTest:
def numpy_to_torch_recursively(x):
if isinstance(x, tuple):
return tuple(numpy_to_torch_recursively(y) for y in x)
elif isinstance(x, dict):
return dict((k, numpy_to_torch_recursively(y)) for k, y in x.items())
elif isinstance(x, np.ndarray):
return torch.from_numpy(x)
else:
@@ -81,9 +83,11 @@ class PytorchLayerTest:
freeze_model = kwargs.get('freeze_model', True)
with torch.no_grad():
if kwargs.get('use_convert_model', False):
smodel, converted_model = self.convert_via_mo(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
smodel, converted_model = self.convert_via_mo(
model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
else:
smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
smodel, converted_model = self.convert_directly_via_frontend(
model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
if kind is not None and not isinstance(kind, (tuple, list)):
kind = [kind]
@@ -124,7 +128,7 @@ class PytorchLayerTest:
continue
assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
continue
ov_tensor_fw_format = torch.tensor(np.array(ov_tensor))
assert ov_tensor_fw_format.dtype == fw_tensor.dtype, f"dtype validation failed: {ov_tensor_fw_format.dtype} != {fw_tensor.dtype}"
# Compare Ie results with Framework results
@@ -137,15 +141,17 @@ class PytorchLayerTest:
assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
quant_size = kwargs['quant_size']
for i in range(len(infer_res)):
cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(
flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
if np.array(cur_fw_res).size == 0:
continue
cur_ov_res = infer_res[compiled.output(i)]
print(f"fw_res: {cur_fw_res};\n ov_res: {cur_ov_res}")
n_is_not_close = np.array(cur_fw_res).size - np.isclose(cur_ov_res, cur_fw_res,
atol=fw_eps,
rtol=fw_eps, equal_nan=True).sum()
max_diff = np.array(abs(np.array(cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
atol=fw_eps,
rtol=fw_eps, equal_nan=True).sum()
max_diff = np.array(abs(np.array(
cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
if not quantized_ops and n_is_not_close > 0:
is_ok = False
print("Max diff is {}".format(max_diff))
@@ -166,11 +172,15 @@ class PytorchLayerTest:
def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
from openvino import convert_model, PartialShape
if trace_model:
decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]}
decoder = TorchScriptPythonDecoder(
model, example_input=example_input, skip_freeze=not freeze_model)
kwargs = {"example_input": example_input if len(
example_input) > 1 or isinstance(example_input[0], dict) else example_input[0]}
else:
decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
kwargs = {"input": [(i.dtype, PartialShape([-1] * len(i.shape))) for i in example_input]}
decoder = TorchScriptPythonDecoder(
model, skip_freeze=not freeze_model)
kwargs = {"input": [(i.dtype, PartialShape(
[-1] * len(i.shape))) for i in example_input]}
smodel = decoder.pt_module
print(smodel.inlined_graph)
if not dynamic_shapes:
@@ -185,9 +195,11 @@ class PytorchLayerTest:
fe = fe_manager.load_by_framework('pytorch')
if trace_model:
decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
decoder = TorchScriptPythonDecoder(
model, example_input=example_input, skip_freeze=not freeze_model)
else:
decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
decoder = TorchScriptPythonDecoder(
model, skip_freeze=not freeze_model)
smodel = decoder.pt_module
print(smodel.inlined_graph)
im = fe.load(decoder)
@@ -206,7 +218,8 @@ class PytorchLayerTest:
inp = ov_inputs[i]
assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
if params[i].get_node().get_element_type().is_dynamic():
params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
params[i].get_node().set_element_type(
self._type_map[inp.dtype.name])
shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
params[i].get_node().set_partial_shape(PartialShape(shape))
om.validate_nodes_and_infer_types()
@@ -235,7 +248,6 @@ class PytorchLayerTest:
flatten_ov_res
), f'number of outputs are not equal, {len(flatten_fw_res)} != {len(flatten_ov_res)}'
# Check if output data types match
for fw_tensor, ov_tensor in zip(flatten_fw_res, flatten_ov_res):
if not isinstance(fw_tensor, torch.Tensor) and not isinstance(ov_tensor, torch.Tensor):
@@ -264,7 +276,6 @@ class PytorchLayerTest:
assert is_ok, "Accuracy validation failed"
def get_params(ie_device=None, precision=None):
"""
:param ie_device: list of devices
@@ -309,4 +320,4 @@ def flattenize_outputs(res):
def flattenize_inputs(res):
return flattenize(res, [tuple])
return flattenize(res, [tuple, dict])

View File

@@ -4,6 +4,7 @@
import numpy as np
import pytest
import torch
from typing import Dict
from pytorch_layer_test_class import PytorchLayerTest
@@ -15,7 +16,7 @@ class TestDict(PytorchLayerTest):
def create_model(self):
class aten_dict(torch.nn.Module):
def forward(self, x):
return {"b": x, "a": x + x, "c": 2 * x}, x / 2
return aten_dict(), None, "prim::DictConstruct"
@@ -23,4 +24,41 @@ class TestDict(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_dict(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)
self._test(*self.create_model(), ie_device, precision,
ir_version, use_convert_model=True)
class aten_dict_with_types(torch.nn.Module):
def forward(self, x_dict: Dict[str, torch.Tensor]):
return x_dict["x1"].to(torch.float32) + x_dict["x2"].to(torch.float32)
class aten_dict_no_types(torch.nn.Module):
def forward(self, x_dict: Dict[str, torch.Tensor]):
return x_dict["x1"] + x_dict["x2"]
class TestDictParam(PytorchLayerTest):
def _prepare_input(self):
return ({"x1": np.random.randn(2, 5, 3, 4).astype(np.float32),
"x2": np.random.randn(2, 5, 3, 4).astype(np.float32)},)
@pytest.mark.nightly
@pytest.mark.precommit
def test_dict_param(self, ie_device, precision, ir_version):
self._test(aten_dict_with_types(), None, "aten::__getitem__", ie_device, precision,
ir_version, trace_model=True)
@pytest.mark.nightly
@pytest.mark.precommit
def test_dict_param_convert_model(self, ie_device, precision, ir_version):
self._test(aten_dict_with_types(), None, "aten::__getitem__", ie_device, precision,
ir_version, trace_model=True, use_convert_model=True)
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.xfail(reason="Type is not propagated from PtFrameworkNode.")
def test_dict_param_no_types(self, ie_device, precision, ir_version):
self._test(aten_dict_no_types(), None, "aten::__getitem__", ie_device, precision,
ir_version, trace_model=True, freeze_model=False)

View File

@@ -152,6 +152,8 @@ def to_torch_tensor(tensor):
if isinstance(tensor, (tuple, list)):
# TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
return tuple(to_torch_tensor(x) for x in tensor)
if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
else:
raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
"Got {}".format(type(tensor)))

View File

@@ -154,6 +154,8 @@ def to_torch_tensor(tensor):
if isinstance(tensor, (tuple, list)):
# TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
return tuple(to_torch_tensor(x) for x in tensor)
if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
else:
raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
"Got {}".format(type(tensor)))