[PT FE] Support dict on input (#21450)

* [PT FE] Support dict on input
* Check that keys are strings
* Revert changes in PtFrameworkNode
parent 3e42ddbde5
commit 182b44f6cc
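For orientation, a minimal usage sketch of what this change enables, based on the tests added below: a model whose forward takes a Dict[str, torch.Tensor] can be traced and converted, with the dict becoming model inputs. The convert/compile calls use the public openvino API; the "CPU" device, the positional feeding of the dict values (mirroring the updated flattenize_inputs helper), and the resulting input naming are assumptions of this sketch, not something asserted by the diff.

import numpy as np
import torch
from typing import Dict
import openvino as ov

class DictInputModel(torch.nn.Module):
    # Mirrors aten_dict_with_types from the new tests: dict input with string keys.
    def forward(self, x_dict: Dict[str, torch.Tensor]):
        return x_dict["x1"].to(torch.float32) + x_dict["x2"].to(torch.float32)

example = {"x1": np.random.randn(2, 5, 3, 4).astype(np.float32),
           "x2": np.random.randn(2, 5, 3, 4).astype(np.float32)}

# example_input is a tuple whose single element is the dict, matching the updated
# convert_via_mo logic in the test harness further down.
ov_model = ov.convert_model(DictInputModel(), example_input=(example,))
compiled = ov.compile_model(ov_model, "CPU")
# The dict values are fed positionally, as the updated flattenize_inputs suggests.
result = compiled([example["x1"], example["x2"]])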
@@ -200,7 +200,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
-    manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
+    manager.register_pass<ov::frontend::pytorch::pass::DictParameterResolver>();
+    manager.register_pass<ov::frontend::pytorch::pass::DictResultResolver>();
     manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
     manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
@@ -19,6 +19,9 @@ using namespace ov::op;
 OutputVector translate_getitem(const NodeContext& context) {
     num_inputs_check(context, 2, 2);
     auto input = context.get_input(0);
+    const auto idx_type = context.get_input_type(1);
+    FRONT_END_OP_CONVERSION_CHECK(!idx_type.is<type::Str>(),
+                                  "String index in aten::__getitem__ means dict input, this is not supported.");
     if (ov::as_type_ptr<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
         FRONT_END_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"),
                                       "special case for aten::__getitem__");
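The check added above keys off how dict access appears after tracing: indexing a Dict[str, Tensor] input with a literal key produces aten::__getitem__ whose second input is a constant string, which this conversion rule deliberately rejects so that DictParameterResolver (below) can turn it into a named Parameter instead. A rough illustration, assuming plain torch.jit.trace as the new tests use (graph details vary across torch versions):

import torch
from typing import Dict

class IndexDict(torch.nn.Module):
    def forward(self, x_dict: Dict[str, torch.Tensor]):
        return x_dict["x1"] + x_dict["x2"]

example = ({"x1": torch.zeros(2, 3), "x2": torch.ones(2, 3)},)
traced = torch.jit.trace(IndexDict(), example, strict=False)
# The inlined graph contains aten::__getitem__ nodes whose index input is a
# prim::Constant string ("x1" / "x2") rather than an integer.
print(traced.inlined_graph)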
@@ -5,6 +5,7 @@
 #include "dict_resolver.hpp"

 #include "openvino/core/rt_info.hpp"
+#include "openvino/op/parameter.hpp"
 #include "openvino/op/result.hpp"
 #include "openvino/op/util/framework_node.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -18,7 +19,56 @@ namespace pass {
 using namespace ov::pass;
 using namespace ov::op;

-bool DictResolver::run_on_model(const std::shared_ptr<Model>& model) {
+bool DictParameterResolver::run_on_model(const std::shared_ptr<Model>& model) {
+    bool changed = false;
+    const auto parameters = model->get_parameters();
+    ParameterVector new_params;
+
+    for (const auto& p : parameters) {
+        bool at_least_one_unused = false;
+        if (p->get_output_size() == 1) {
+            const auto targets = p->get_output_target_inputs(0);
+            for (const auto inp : targets) {
+                const auto getitem_node = cast_fw_node(inp.get_node()->shared_from_this(), "aten::__getitem__");
+                if (getitem_node) {
+                    const auto index_node = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(
+                        getitem_node->get_input_node_shared_ptr(1));
+                    if (!index_node) {
+                        at_least_one_unused = true;
+                        continue;
+                    }
+                    const auto attrs = index_node->get_attrs();
+                    if (attrs.find("string_value") == attrs.end()) {
+                        // index node must contain string value
+                        at_least_one_unused = true;
+                        continue;
+                    }
+                    const auto name = attrs.at("string_value");
+                    auto new_param = std::make_shared<v0::Parameter>(getitem_node->get_output_element_type(0),
+                                                                     getitem_node->get_output_partial_shape(0));
+                    new_param->set_friendly_name(name);
+                    getitem_node->output(0).replace(new_param);
+                    new_params.push_back(new_param);
+                    changed = true;
+                } else {
+                    at_least_one_unused = true;
+                }
+            }
+        }
+        if (changed) {
+            model->remove_parameter(p);
+            if (at_least_one_unused || p->get_output_size() != 1) {
+                new_params.push_back(p);
+            }
+        }
+    }
+    if (changed) {
+        model->add_parameters(new_params);
+    }
+    return changed;
+};
+
+bool DictResultResolver::run_on_model(const std::shared_ptr<Model>& model) {
     bool changed = false;
     const auto results = model->get_results();
     for (const auto& res : results) {
@@ -12,9 +12,16 @@ namespace frontend {
 namespace pytorch {
 namespace pass {

-class DictResolver : public ov::pass::ModelPass {
+// This transformation replaces pattern Parameter(Dict)->aten::__getitem__
+class DictParameterResolver : public ov::pass::ModelPass {
 public:
-    OPENVINO_RTTI("ov::frontend::pytorch::pass::DictResolver");
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::DictParameterResolver");
+    bool run_on_model(const std::shared_ptr<Model>& model) override;
+};
+
+// This transformation replaces pattern prim::DictConstruct->Result
+class DictResultResolver : public ov::pass::ModelPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::DictResultResolver");
     bool run_on_model(const std::shared_ptr<Model>& model) override;
 };

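The header comments above name the two patterns. On the output side, a forward that returns a dict scripts to a prim::DictConstruct feeding the graph output, and DictResultResolver is expected to unpack it so each entry becomes its own model output. A small sketch along the lines of the TestDict model further down; the behaviour after conversion described in the comments is an expectation, not something this snippet verifies.

import torch

class DictOutputModel(torch.nn.Module):
    # Mirrors aten_dict from TestDict: a dict plus a plain tensor is returned.
    def forward(self, x):
        return {"b": x, "a": x + x, "c": 2 * x}, x / 2

# Scripting keeps the prim::DictConstruct node in the TorchScript graph.
scripted = torch.jit.script(DictOutputModel())
# The scripted graph ends in prim::DictConstruct feeding the return value; after
# conversion, DictResultResolver should unpack it so each entry ("b", "a", "c")
# becomes a separate model output alongside x / 2.
print(scripted.inlined_graph)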
@@ -53,6 +53,8 @@ class PytorchLayerTest:
         def numpy_to_torch_recursively(x):
             if isinstance(x, tuple):
                 return tuple(numpy_to_torch_recursively(y) for y in x)
+            elif isinstance(x, dict):
+                return dict((k, numpy_to_torch_recursively(y)) for k, y in x.items())
             elif isinstance(x, np.ndarray):
                 return torch.from_numpy(x)
             else:
@@ -81,9 +83,11 @@ class PytorchLayerTest:
         freeze_model = kwargs.get('freeze_model', True)
         with torch.no_grad():
             if kwargs.get('use_convert_model', False):
-                smodel, converted_model = self.convert_via_mo(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
+                smodel, converted_model = self.convert_via_mo(
+                    model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
             else:
-                smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
+                smodel, converted_model = self.convert_directly_via_frontend(
+                    model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)

         if kind is not None and not isinstance(kind, (tuple, list)):
             kind = [kind]
@@ -124,7 +128,7 @@ class PytorchLayerTest:
                     continue
                 assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
                 continue
             ov_tensor_fw_format = torch.tensor(np.array(ov_tensor))
             assert ov_tensor_fw_format.dtype == fw_tensor.dtype, f"dtype validation failed: {ov_tensor_fw_format.dtype} != {fw_tensor.dtype}"

         # Compare Ie results with Framework results
@@ -137,15 +141,17 @@ class PytorchLayerTest:
             assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
             quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
-            cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
+            cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(
+                flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
             if np.array(cur_fw_res).size == 0:
                 continue
             cur_ov_res = infer_res[compiled.output(i)]
             print(f"fw_res: {cur_fw_res};\n ov_res: {cur_ov_res}")
             n_is_not_close = np.array(cur_fw_res).size - np.isclose(cur_ov_res, cur_fw_res,
                                                                     atol=fw_eps,
                                                                     rtol=fw_eps, equal_nan=True).sum()
-            max_diff = np.array(abs(np.array(cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
+            max_diff = np.array(abs(np.array(
+                cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
             if not quantized_ops and n_is_not_close > 0:
                 is_ok = False
                 print("Max diff is {}".format(max_diff))
@@ -166,11 +172,15 @@ class PytorchLayerTest:
     def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
         from openvino import convert_model, PartialShape
         if trace_model:
-            decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
-            kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]}
+            decoder = TorchScriptPythonDecoder(
+                model, example_input=example_input, skip_freeze=not freeze_model)
+            kwargs = {"example_input": example_input if len(
+                example_input) > 1 or isinstance(example_input[0], dict) else example_input[0]}
         else:
-            decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
-            kwargs = {"input": [(i.dtype, PartialShape([-1] * len(i.shape))) for i in example_input]}
+            decoder = TorchScriptPythonDecoder(
+                model, skip_freeze=not freeze_model)
+            kwargs = {"input": [(i.dtype, PartialShape(
+                [-1] * len(i.shape))) for i in example_input]}
         smodel = decoder.pt_module
         print(smodel.inlined_graph)
         if not dynamic_shapes:
@@ -185,9 +195,11 @@ class PytorchLayerTest:
         fe = fe_manager.load_by_framework('pytorch')

         if trace_model:
-            decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
+            decoder = TorchScriptPythonDecoder(
+                model, example_input=example_input, skip_freeze=not freeze_model)
         else:
-            decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
+            decoder = TorchScriptPythonDecoder(
+                model, skip_freeze=not freeze_model)
         smodel = decoder.pt_module
         print(smodel.inlined_graph)
         im = fe.load(decoder)
@@ -206,7 +218,8 @@ class PytorchLayerTest:
             inp = ov_inputs[i]
             assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
             if params[i].get_node().get_element_type().is_dynamic():
-                params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
+                params[i].get_node().set_element_type(
+                    self._type_map[inp.dtype.name])
             shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
             params[i].get_node().set_partial_shape(PartialShape(shape))
         om.validate_nodes_and_infer_types()
@@ -235,7 +248,6 @@ class PytorchLayerTest:
             flatten_ov_res
         ), f'number of outputs are not equal, {len(flatten_fw_res)} != {len(flatten_ov_res)}'

-
         # Check if output data types match
         for fw_tensor, ov_tensor in zip(flatten_fw_res, flatten_ov_res):
             if not isinstance(fw_tensor, torch.Tensor) and not isinstance(ov_tensor, torch.Tensor):
@@ -264,7 +276,6 @@ class PytorchLayerTest:
         assert is_ok, "Accuracy validation failed"


-
 def get_params(ie_device=None, precision=None):
     """
     :param ie_device: list of devices
@@ -309,4 +320,4 @@ def flattenize_outputs(res):


 def flattenize_inputs(res):
-    return flattenize(res, [tuple])
+    return flattenize(res, [tuple, dict])
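For reference, a hypothetical standalone equivalent of what flattenize over [tuple, dict] now does for inputs (the real helper lives elsewhere in this file and may differ in details): nested tuples and dicts are flattened into a flat list of leaf values, with dict values taken in insertion order.

def flattenize_sketch(res, container_types=(tuple, dict)):
    # Hypothetical illustration only, not the project's flattenize implementation.
    flat = []
    for item in (res.values() if isinstance(res, dict) else res):
        if isinstance(item, container_types):
            flat.extend(flattenize_sketch(item, container_types))
        else:
            flat.append(item)
    return flat

assert flattenize_sketch(({"b": 1, "a": 2}, 3)) == [1, 2, 3]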
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 import torch
+from typing import Dict

 from pytorch_layer_test_class import PytorchLayerTest

@@ -15,7 +16,7 @@ class TestDict(PytorchLayerTest):

     def create_model(self):
         class aten_dict(torch.nn.Module):
             def forward(self, x):
                 return {"b": x, "a": x + x, "c": 2 * x}, x / 2

         return aten_dict(), None, "prim::DictConstruct"
@@ -23,4 +24,41 @@
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_dict(self, ie_device, precision, ir_version):
-        self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)
+        self._test(*self.create_model(), ie_device, precision,
+                   ir_version, use_convert_model=True)
+
+
+class aten_dict_with_types(torch.nn.Module):
+    def forward(self, x_dict: Dict[str, torch.Tensor]):
+        return x_dict["x1"].to(torch.float32) + x_dict["x2"].to(torch.float32)
+
+
+class aten_dict_no_types(torch.nn.Module):
+    def forward(self, x_dict: Dict[str, torch.Tensor]):
+        return x_dict["x1"] + x_dict["x2"]
+
+
+class TestDictParam(PytorchLayerTest):
+
+    def _prepare_input(self):
+        return ({"x1": np.random.randn(2, 5, 3, 4).astype(np.float32),
+                 "x2": np.random.randn(2, 5, 3, 4).astype(np.float32)},)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_dict_param(self, ie_device, precision, ir_version):
+        self._test(aten_dict_with_types(), None, "aten::__getitem__", ie_device, precision,
+                   ir_version, trace_model=True)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_dict_param_convert_model(self, ie_device, precision, ir_version):
+        self._test(aten_dict_with_types(), None, "aten::__getitem__", ie_device, precision,
+                   ir_version, trace_model=True, use_convert_model=True)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.xfail(reason="Type is not propagated from PtFrameworkNode.")
+    def test_dict_param_no_types(self, ie_device, precision, ir_version):
+        self._test(aten_dict_no_types(), None, "aten::__getitem__", ie_device, precision,
+                   ir_version, trace_model=True, freeze_model=False)
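The new cases run like the other PyTorch layer tests. A minimal way to invoke just the dict-parameter tests from Python; the file path is an assumption based on the usual layout of these tests in the repository:

import pytest

# Hypothetical path; adjust to where test_dict.py actually lives in your checkout.
pytest.main(["tests/layer_tests/pytorch_tests/test_dict.py", "-k", "dict_param", "-v"])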
@@ -152,6 +152,8 @@ def to_torch_tensor(tensor):
     if isinstance(tensor, (tuple, list)):
         # TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
         return tuple(to_torch_tensor(x) for x in tensor)
+    if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
+        return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
     else:
         raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                     "Got {}".format(type(tensor)))
@@ -154,6 +154,8 @@ def to_torch_tensor(tensor):
     if isinstance(tensor, (tuple, list)):
         # TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
         return tuple(to_torch_tensor(x) for x in tensor)
+    if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
+        return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
     else:
         raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                     "Got {}".format(type(tensor)))