From 182b44f6ccc110e67b75e9811303ba72e84a2db0 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Tue, 12 Dec 2023 18:58:04 +0100
Subject: [PATCH] [PT FE] Support dict on input (#21450)

* [PT FE] Support dict on input

* Check that keys are strings

* Revert changes in PtFrameworkNode
---
 src/frontends/pytorch/src/frontend.cpp        |  3 +-
 src/frontends/pytorch/src/op/getitem.cpp      |  3 ++
 .../pytorch/src/transforms/dict_resolver.cpp  | 52 ++++++++++++++++++-
 .../pytorch/src/transforms/dict_resolver.hpp  | 11 +++-
 .../pytorch_tests/pytorch_layer_test_class.py | 45 ++++++++++------
 tests/layer_tests/pytorch_tests/test_dict.py  | 42 ++++++++++++++-
 .../mo/moc_frontend/pytorch_frontend_utils.py |  2 +
 .../moc_frontend/pytorch_frontend_utils.py    |  2 +
 8 files changed, 137 insertions(+), 23 deletions(-)

diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index 15d9be8dafc..af427798a2c 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -200,7 +200,8 @@ void FrontEnd::normalize(const std::shared_ptr<Model>& model) const {
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
-    manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
+    manager.register_pass<ov::frontend::pytorch::pass::DictParameterResolver>();
+    manager.register_pass<ov::frontend::pytorch::pass::DictResultResolver>();
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
diff --git a/src/frontends/pytorch/src/op/getitem.cpp b/src/frontends/pytorch/src/op/getitem.cpp
index 2dc698015f9..58d3639cc8a 100644
--- a/src/frontends/pytorch/src/op/getitem.cpp
+++ b/src/frontends/pytorch/src/op/getitem.cpp
@@ -19,6 +19,9 @@ using namespace ov::op;
 OutputVector translate_getitem(const NodeContext& context) {
     num_inputs_check(context, 2, 2);
     auto input = context.get_input(0);
+    const auto idx_type = context.get_input_type(1);
+    FRONT_END_OP_CONVERSION_CHECK(!idx_type.is<type::Str>(),
+                                  "String index in aten::__getitem__ means dict input, this is not supported.");
     if (ov::as_type_ptr<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
         FRONT_END_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"),
                                       "special case for aten::__getitem__");
diff --git a/src/frontends/pytorch/src/transforms/dict_resolver.cpp b/src/frontends/pytorch/src/transforms/dict_resolver.cpp
index 455a1fc2cbc..d51eb793813 100644
--- a/src/frontends/pytorch/src/transforms/dict_resolver.cpp
+++ b/src/frontends/pytorch/src/transforms/dict_resolver.cpp
@@ -5,6 +5,7 @@
 #include "dict_resolver.hpp"
 
 #include "openvino/core/rt_info.hpp"
+#include "openvino/op/parameter.hpp"
 #include "openvino/op/result.hpp"
 #include "openvino/op/util/framework_node.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -18,7 +19,56 @@ namespace pass {
 using namespace ov::pass;
 using namespace ov::op;
 
-bool DictResolver::run_on_model(const std::shared_ptr<Model>& model) {
+bool DictParameterResolver::run_on_model(const std::shared_ptr<Model>& model) {
+    bool changed = false;
+    const auto parameters = model->get_parameters();
+    ParameterVector new_params;
+
+    for (const auto& p : parameters) {
+        bool at_least_one_unused = false;
+        if (p->get_output_size() == 1) {
+            const auto targets = p->get_output_target_inputs(0);
+            for (const auto inp : targets) {
+                const auto getitem_node = cast_fw_node(inp.get_node()->shared_from_this(), "aten::__getitem__");
+                if (getitem_node) {
+                    const auto index_node = std::dynamic_pointer_cast<util::FrameworkNode>(
+                        getitem_node->get_input_node_shared_ptr(1));
+                    if (!index_node) {
+                        at_least_one_unused = true;
+                        continue;
+                    }
+                    const auto attrs = index_node->get_attrs();
+                    if (attrs.find("string_value") == attrs.end()) {
+                        // index node must contain string value
+                        at_least_one_unused = true;
+                        continue;
+                    }
+                    const auto name = attrs.at("string_value");
+                    auto new_param = std::make_shared<v0::Parameter>(getitem_node->get_output_element_type(0),
+                                                                     getitem_node->get_output_partial_shape(0));
+                    new_param->set_friendly_name(name);
+                    getitem_node->output(0).replace(new_param);
+                    new_params.push_back(new_param);
+                    changed = true;
+                } else {
+                    at_least_one_unused = true;
+                }
+            }
+        }
+        if (changed) {
+            model->remove_parameter(p);
+            if (at_least_one_unused || p->get_output_size() != 1) {
+                new_params.push_back(p);
+            }
+        }
+    }
+    if (changed) {
+        model->add_parameters(new_params);
+    }
+    return changed;
+};
+
+bool DictResultResolver::run_on_model(const std::shared_ptr<Model>& model) {
     bool changed = false;
     const auto results = model->get_results();
     for (const auto& res : results) {
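For context: a module traced with a dict input is exactly what produces the Parameter -> aten::__getitem__ pattern that DictParameterResolver rewrites above. A minimal sketch (module and key names are illustrative, not taken from the patch):

    import torch
    from typing import Dict

    class DictInput(torch.nn.Module):
        def forward(self, d: Dict[str, torch.Tensor]):
            return d["a"] + d["b"]

    # Tracing keeps the dict as a single graph input; each d[key] access is
    # lowered to aten::__getitem__ with a string constant, which the pass above
    # converts into a separate Parameter named after the key.
    traced = torch.jit.trace(DictInput(), ({"a": torch.ones(2, 3), "b": torch.ones(2, 3)},))
    print(traced.inlined_graph)  # contains aten::__getitem__ nodes indexed by "a"/"b"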
diff --git a/src/frontends/pytorch/src/transforms/dict_resolver.hpp b/src/frontends/pytorch/src/transforms/dict_resolver.hpp
index 7cdec639cf2..0494c7dc32d 100644
--- a/src/frontends/pytorch/src/transforms/dict_resolver.hpp
+++ b/src/frontends/pytorch/src/transforms/dict_resolver.hpp
@@ -12,9 +12,16 @@ namespace frontend {
 namespace pytorch {
 namespace pass {
 
-class DictResolver : public ov::pass::ModelPass {
+// This transformation replaces pattern Parameter(Dict)->aten::__getitem__
+class DictParameterResolver : public ov::pass::ModelPass {
 public:
-    OPENVINO_RTTI("ov::frontend::pytorch::pass::DictResolver");
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::DictParameterResolver");
+    bool run_on_model(const std::shared_ptr<Model>& model) override;
+};
+// This transformation replaces pattern prim::DictConstruct->Result
+class DictResultResolver : public ov::pass::ModelPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::DictResultResolver");
     bool run_on_model(const std::shared_ptr<Model>& model) override;
 };
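The header documents both directions. On the output side, the observable effect should be that a dict result is unpacked into individual model outputs. A sketch (assumes a build that includes this patch; the module mirrors the existing test below):

    import torch
    import openvino as ov

    class DictOutput(torch.nn.Module):
        def forward(self, x):
            return {"b": x, "a": x + x, "c": 2 * x}

    # DictResultResolver splits the prim::DictConstruct feeding the Result into
    # one Result per dict entry, so three outputs are expected here.
    ov_model = ov.convert_model(DictOutput(), example_input=torch.randn(2, 3))
    print(len(ov_model.outputs))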
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index adee1fb3ccc..be361704c5a 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -53,6 +53,8 @@ class PytorchLayerTest:
     def numpy_to_torch_recursively(x):
         if isinstance(x, tuple):
             return tuple(numpy_to_torch_recursively(y) for y in x)
+        elif isinstance(x, dict):
+            return dict((k, numpy_to_torch_recursively(y)) for k, y in x.items())
         elif isinstance(x, np.ndarray):
             return torch.from_numpy(x)
         else:
@@ -81,9 +83,11 @@ class PytorchLayerTest:
         freeze_model = kwargs.get('freeze_model', True)
         with torch.no_grad():
             if kwargs.get('use_convert_model', False):
-                smodel, converted_model = self.convert_via_mo(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
+                smodel, converted_model = self.convert_via_mo(
+                    model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
             else:
-                smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
+                smodel, converted_model = self.convert_directly_via_frontend(
+                    model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
 
         if kind is not None and not isinstance(kind, (tuple, list)):
             kind = [kind]
@@ -124,7 +128,7 @@ class PytorchLayerTest:
                     continue
                 assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
                 continue
-            ov_tensor_fw_format = torch.tensor(np.array(ov_tensor)) 
+            ov_tensor_fw_format = torch.tensor(np.array(ov_tensor))
             assert ov_tensor_fw_format.dtype == fw_tensor.dtype, f"dtype validation failed: {ov_tensor_fw_format.dtype} != {fw_tensor.dtype}"
 
         # Compare Ie results with Framework results
@@ -137,15 +141,17 @@ class PytorchLayerTest:
             assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
             quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
-            cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
+            cur_fw_res = flatten_fw_res[i].contiguous().numpy(force=True) if isinstance(
+                flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
             if np.array(cur_fw_res).size == 0:
                 continue
             cur_ov_res = infer_res[compiled.output(i)]
             print(f"fw_res: {cur_fw_res};\n ov_res: {cur_ov_res}")
             n_is_not_close = np.array(cur_fw_res).size - np.isclose(cur_ov_res, cur_fw_res,
-                                                                    atol=fw_eps,
-                                                                    rtol=fw_eps, equal_nan=True).sum()
-            max_diff = np.array(abs(np.array(cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
+                                                                    atol=fw_eps,
+                                                                    rtol=fw_eps, equal_nan=True).sum()
+            max_diff = np.array(abs(np.array(
+                cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
             if not quantized_ops and n_is_not_close > 0:
                 is_ok = False
                 print("Max diff is {}".format(max_diff))
@@ -166,11 +172,15 @@ class PytorchLayerTest:
     def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
         from openvino import convert_model, PartialShape
         if trace_model:
-            decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
-            kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]}
+            decoder = TorchScriptPythonDecoder(
+                model, example_input=example_input, skip_freeze=not freeze_model)
+            kwargs = {"example_input": example_input if len(
+                example_input) > 1 or isinstance(example_input[0], dict) else example_input[0]}
         else:
-            decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
-            kwargs = {"input": [(i.dtype, PartialShape([-1] * len(i.shape))) for i in example_input]}
+            decoder = TorchScriptPythonDecoder(
+                model, skip_freeze=not freeze_model)
+            kwargs = {"input": [(i.dtype, PartialShape(
+                [-1] * len(i.shape))) for i in example_input]}
         smodel = decoder.pt_module
         print(smodel.inlined_graph)
         if not dynamic_shapes:
@@ -185,9 +195,11 @@ class PytorchLayerTest:
         fe = fe_manager.load_by_framework('pytorch')
 
         if trace_model:
-            decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
+            decoder = TorchScriptPythonDecoder(
+                model, example_input=example_input, skip_freeze=not freeze_model)
         else:
-            decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
+            decoder = TorchScriptPythonDecoder(
+                model, skip_freeze=not freeze_model)
         smodel = decoder.pt_module
         print(smodel.inlined_graph)
         im = fe.load(decoder)

             inp = ov_inputs[i]
             assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
             if params[i].get_node().get_element_type().is_dynamic():
-                params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
+                params[i].get_node().set_element_type(
+                    self._type_map[inp.dtype.name])
                 shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
                 params[i].get_node().set_partial_shape(PartialShape(shape))
             om.validate_nodes_and_infer_types()
@@ -235,7 +248,6 @@ class PytorchLayerTest:
                 flatten_ov_res
             ), f'number of outputs are not equal, {len(flatten_fw_res)} != {len(flatten_ov_res)}'
 
-
         # Check if output data types match
         for fw_tensor, ov_tensor in zip(flatten_fw_res, flatten_ov_res):
             if not isinstance(fw_tensor, torch.Tensor) and not isinstance(ov_tensor, torch.Tensor):
@@ -264,7 +276,6 @@ class PytorchLayerTest:
 
         assert is_ok, "Accuracy validation failed"
 
-
 def get_params(ie_device=None, precision=None):
     """
     :param ie_device: list of devices
@@ -309,4 +320,4 @@ def flattenize_outputs(res):
 
 
 def flattenize_inputs(res):
-    return flattenize(res, [tuple])
+    return flattenize(res, [tuple, dict])
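The harness change relies on flattenize expanding dicts the same way it already expands tuples. A standalone sketch that mirrors the expected behavior of the existing helper (a re-implementation for illustration, not the code in the file):

    def flattenize(res, types):
        # Recursively expand containers of the listed types, keeping leaf order.
        flat = []
        items = res.values() if isinstance(res, dict) else res
        for item in items:
            if type(item) in types:
                flat.extend(flattenize(item, types))
            else:
                flat.append(item)
        return flat

    # Dict inputs unpack to their values, matching flattenize_inputs above.
    print(flattenize(({"x1": 1, "x2": 2}, 3), [tuple, dict]))  # [1, 2, 3]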
diff --git a/tests/layer_tests/pytorch_tests/test_dict.py b/tests/layer_tests/pytorch_tests/test_dict.py
index 6e4db9dea82..4dfbf0f85c6 100644
--- a/tests/layer_tests/pytorch_tests/test_dict.py
+++ b/tests/layer_tests/pytorch_tests/test_dict.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 import torch
+from typing import Dict
 
 from pytorch_layer_test_class import PytorchLayerTest
 
@@ -15,7 +16,7 @@ class TestDict(PytorchLayerTest):
 
     def create_model(self):
         class aten_dict(torch.nn.Module):
-            def forward(self, x): 
+            def forward(self, x):
                 return {"b": x, "a": x + x, "c": 2 * x}, x / 2
 
         return aten_dict(), None, "prim::DictConstruct"
@@ -23,4 +24,41 @@ class TestDict(PytorchLayerTest):
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_dict(self, ie_device, precision, ir_version):
-        self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)
+        self._test(*self.create_model(), ie_device, precision,
+                   ir_version, use_convert_model=True)
+
+
+class aten_dict_with_types(torch.nn.Module):
+    def forward(self, x_dict: Dict[str, torch.Tensor]):
+        return x_dict["x1"].to(torch.float32) + x_dict["x2"].to(torch.float32)
+
+
+class aten_dict_no_types(torch.nn.Module):
+    def forward(self, x_dict: Dict[str, torch.Tensor]):
+        return x_dict["x1"] + x_dict["x2"]
+
+
+class TestDictParam(PytorchLayerTest):
+
+    def _prepare_input(self):
+        return ({"x1": np.random.randn(2, 5, 3, 4).astype(np.float32),
+                 "x2": np.random.randn(2, 5, 3, 4).astype(np.float32)},)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_dict_param(self, ie_device, precision, ir_version):
+        self._test(aten_dict_with_types(), None, "aten::__getitem__", ie_device, precision,
+                   ir_version, trace_model=True)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_dict_param_convert_model(self, ie_device, precision, ir_version):
+        self._test(aten_dict_with_types(), None, "aten::__getitem__", ie_device, precision,
+                   ir_version, trace_model=True, use_convert_model=True)
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.xfail(reason="Type is not propagated from PtFrameworkNode.")
+    def test_dict_param_no_types(self, ie_device, precision, ir_version):
+        self._test(aten_dict_no_types(), None, "aten::__getitem__", ie_device, precision,
+                   ir_version, trace_model=True, freeze_model=False)
diff --git a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py
index 214fbbc4ff7..cf4c611feee 100644
--- a/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py
+++ b/tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py
@@ -152,6 +152,8 @@ def to_torch_tensor(tensor):
     if isinstance(tensor, (tuple, list)):
         # TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
         return tuple(to_torch_tensor(x) for x in tensor)
+    if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
+        return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
     else:
         raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                     "Got {}".format(type(tensor)))
diff --git a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py
index 882a075b7de..3a24e84af1a 100644
--- a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py
+++ b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py
@@ -154,6 +154,8 @@ def to_torch_tensor(tensor):
     if isinstance(tensor, (tuple, list)):
         # TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
         return tuple(to_torch_tensor(x) for x in tensor)
+    if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
+        return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
     else:
         raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                     "Got {}".format(type(tensor)))
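Taken together, the patch makes the following flow work: example_input may now be a dict with string keys (accepted by to_torch_tensor above), and the converted model exposes one input per key. A hedged end-to-end sketch (assumes an OpenVINO build containing this patch; module and key names are illustrative):

    import numpy as np
    import torch
    import openvino as ov
    from typing import Dict

    class AddDict(torch.nn.Module):
        def forward(self, x: Dict[str, torch.Tensor]):
            return x["x1"] + x["x2"]

    # The dict keys become separate named Parameters via DictParameterResolver.
    example = {"x1": torch.randn(2, 3), "x2": torch.randn(2, 3)}
    ov_model = ov.convert_model(AddDict(), example_input=example)
    compiled = ov.compile_model(ov_model, "CPU")
    result = compiled({"x1": np.ones((2, 3), np.float32),
                       "x2": np.ones((2, 3), np.float32)})
    print(result[compiled.output(0)])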