diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index 05a0b81fc91..d6bc13fa7b9 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -34,6 +34,7 @@
 #include "transforms/prim_list_construct_pad.hpp"
 #include "transforms/prim_list_tuple_construct_replacer.hpp"
 #include "transforms/prim_list_unpack_replacer.hpp"
+#include "transforms/prim_tuple_unpack_parameter_replacer.hpp"
 #include "transforms/rfftn_complex_replacer.hpp"
 #include "transforms/string_equality_replacer.hpp"
 #include "transforms/tuple_unpack_replacer.hpp"
@@ -174,6 +175,7 @@ void FrontEnd::normalize(const std::shared_ptr<Model>& model) const {
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
+    manager.register_pass<ov::frontend::pytorch::pass::DecomposeTupleParameters>();
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
diff --git a/src/frontends/pytorch/src/transforms/prim_list_tuple_construct_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_tuple_construct_replacer.cpp
index 097cdfd5f64..29f0d6962f4 100644
--- a/src/frontends/pytorch/src/transforms/prim_list_tuple_construct_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/prim_list_tuple_construct_replacer.cpp
@@ -3,7 +3,7 @@
 //
 #include "prim_list_tuple_construct_replacer.hpp"
 
-#include <queue>
+#include <deque>
 
 #include "openvino/frontend/pytorch/decoder.hpp"
 #include "openvino/op/result.hpp"
@@ -17,21 +17,24 @@ namespace pass {
 
 bool DecomposeListTupleResults::run_on_model(const std::shared_ptr<Model>& model) {
     bool at_least_one_decomposed = false;
-    std::queue<std::shared_ptr<ov::op::v0::Result>> results;
-    for (auto res : model->get_results()) {
-        results.push(res);
-    }
+    const auto& orig_results = model->get_results();
+    std::deque<std::shared_ptr<ov::op::v0::Result>> results(orig_results.begin(), orig_results.end());
+    ov::ResultVector updated_results;  // will hold final fully unpacked results list
+
     while (!results.empty()) {
         auto result = results.front();
-        results.pop();
+        results.pop_front();
         auto input_node = result->get_input_node_shared_ptr(0);
         auto tuple_construct = cast_fw_node(input_node, "prim::TupleConstruct");
         auto list_construct = cast_fw_node(input_node, "prim::ListConstruct");
         if (!tuple_construct && !list_construct) {
+            updated_results.push_back(result);
             continue;
         }
-        for (const auto& input : input_node->inputs()) {
-            const auto& out = input.get_source_output();
+        const auto& inputs = input_node->inputs();
+        // enumerating inputs in reverse order because of results.push_front below
+        for (auto pinput = inputs.rbegin(); pinput != inputs.rend(); ++pinput) {
+            const auto& out = pinput->get_source_output();
             if (const auto& fw_node = cast_fw_node(out.get_node_shared_ptr(), "prim::Constant")) {
                 const auto& attrs = fw_node->get_attrs();
                 if (attrs.find("none_value") != attrs.end()) {
@@ -42,13 +45,19 @@ bool DecomposeListTupleResults::run_on_model(const std::shared_ptr<Model>& model
                 }
             }
             auto new_result = std::make_shared<ov::op::v0::Result>(out);
-            model->add_results({new_result});
-            results.push(new_result);
-            model->remove_result(result);
+            results.push_front(new_result);
             at_least_one_decomposed = true;
         }
     }
 
+    if (at_least_one_decomposed) {
+        // remove all results
+        while (!model->get_results().empty())
+            model->remove_result(model->get_results()[0]);
+        // and replace them all by updated list of results
+        model->add_results(updated_results);
+    }
+
     return at_least_one_decomposed;
 };
 }  // namespace pass
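The reworked loop above is subtle: instead of appending new Results to the model as they are discovered, children of a tuple/list are pushed back onto the *front* of a deque in reverse order, so fully unpacked results come out in the original left-to-right order even for nested structures. Below is a minimal Python sketch of the same worklist scheme; the `is_composite`/`children` helpers are hypothetical stand-ins for the `prim::TupleConstruct`/`prim::ListConstruct` checks in the pass:

```python
from collections import deque

def flatten_results(results, is_composite, children):
    """Worklist flattening that preserves left-to-right order.

    is_composite(r) -> True if r is a tuple/list-like node;
    children(r)     -> its child outputs, left to right.
    """
    work = deque(results)
    flat = []
    while work:
        item = work.popleft()
        if not is_composite(item):
            flat.append(item)
            continue
        # push children to the front in reverse order so the leftmost
        # child is processed (and therefore emitted) first
        for child in reversed(children(item)):
            work.appendleft(child)
    return flat

# usage: nested tuples flatten in-order
nested = ("a", (("b", "c"), "d"), "e")
print(flatten_results([nested],
                      lambda x: isinstance(x, tuple),
                      lambda x: list(x)))  # -> ['a', 'b', 'c', 'd', 'e']
```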
diff --git a/src/frontends/pytorch/src/transforms/prim_tuple_unpack_parameter_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_tuple_unpack_parameter_replacer.cpp
new file mode 100644
index 00000000000..12577daa6f2
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/prim_tuple_unpack_parameter_replacer.cpp
@@ -0,0 +1,122 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "prim_tuple_unpack_parameter_replacer.hpp"
+
+#include <deque>
+#include <sstream>
+
+#include "openvino/frontend/pytorch/decoder.hpp"
+#include "openvino/op/result.hpp"
+#include "openvino/op/util/framework_node.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+bool DecomposeTupleParameters::run_on_model(const std::shared_ptr<Model>& model) {
+    bool at_least_one_decomposed = false;
+    const auto& orig_parameters = model->get_parameters();
+    std::deque<std::shared_ptr<ov::op::v0::Parameter>> parameters(orig_parameters.begin(), orig_parameters.end());
+    ov::ParameterVector updated_parameters;  // will hold final fully unpacked parameters list
+
+    while (!parameters.empty()) {
+        auto parameter = parameters.front();
+        parameters.pop_front();
+        auto consumers = parameter->get_output_target_inputs(0);
+        size_t num_outputs = 0;  // number of outputs in each unpack consumer should match
+        bool all_unpacks = true;
+
+        // collects all outputs of each consumer operation of this tuple Parameter
+        std::vector<ov::OutputVector> consumer_outputs;
+
+        // The following vector tracks consumer nodes of prim::TupleUnpack type to form a detailed
+        // error message in case parameter replacement is required but not possible.
+        std::vector<std::shared_ptr<ov::Node>> consumer_unpacks;
+
+        for (const auto& consumer : consumers) {
+            auto node = consumer.get_node()->shared_from_this();
+            auto tuple_unpack = cast_fw_node(node, "prim::TupleUnpack");
+            if (!tuple_unpack) {
+                all_unpacks = false;
+                continue;  // need to look at all consumers to form good diagnostics
+            }
+            consumer_unpacks.push_back(node);
+            if (num_outputs == 0) {
+                num_outputs = node->get_output_size();
+            } else if (num_outputs != node->get_output_size()) {
+                std::stringstream message;
+                message << "Unpack node " << node
+                        << ", one of the consumers of a tuple introduced by parameter "
+                        << parameter->output(0) << ", has " << node->get_output_size()
+                        << " outputs, which does not match the " << num_outputs
+                        << " outputs of the consumer(s) found earlier.";
+                add_exception_to_fw_node(node, message.str());
+                all_unpacks = false;
+                break;
+            }
+            consumer_outputs.push_back(node->outputs());
+        }
+
+        if (!all_unpacks || consumer_outputs.empty()) {
+            // If at least one consumer is not an unpack-like op, or the numbers of unpacked
+            // elements do not match, we cannot replace any of the unpacks even if they exist,
+            // leaving the Unpack op(s) in the graph for this Parameter.
+            updated_parameters.push_back(parameter);
+            // If at least one Unpack exists, there is an opportunity to attach diagnostics
+            for (const auto& consumer : consumer_unpacks) {
+                std::stringstream message;
+                message << "Consumers other than prim::TupleUnpack exist besides this one: " << consumer
+                        << ", found as one of the consumers of a tuple introduced by parameter "
+                        << parameter->output(0) << ".";
+                add_exception_to_fw_node(consumer, message.str());
+            }
+            continue;
+        }
+
+        // enumerating outputs in reverse order because of parameters.push_front below
+        for (size_t i = num_outputs; i--;) {
+            // Merged partial shape and element type among all the consumers of i-th result of unpack ops
+            PartialShape ps = PartialShape::dynamic();
+            element::Type et = element::dynamic;
+            std::set<ov::Input<ov::Node>> inputs;
+
+            for (const auto& outputs : consumer_outputs) {
+                auto output = outputs[i];
+                OPENVINO_ASSERT(PartialShape::merge_into(ps, output.get_partial_shape()),
+                                "Consumers for unpack op have incompatible shape");
+                OPENVINO_ASSERT(element::Type::merge(et, et, output.get_element_type()),
+                                "Consumers for unpack op have incompatible types");
+                auto target_inputs = output.get_target_inputs();
+                inputs.insert(target_inputs.begin(), target_inputs.end());
+            }
+
+            auto new_parameter = std::make_shared<ov::op::v0::Parameter>(et, ps);
+
+            for (auto input : inputs) {
+                auto names = input.get_tensor().get_names();
+                input.replace_source_output(new_parameter->output(0));
+                new_parameter->output(0).add_names(names);
+            }
+
+            // TODO: Assign correct names
+            parameters.push_front(new_parameter);
+            at_least_one_decomposed = true;
+        }
+    }
+
+    if (at_least_one_decomposed) {
+        // remove all parameters
+        while (!model->get_parameters().empty())
+            model->remove_parameter(model->get_parameters()[0]);
+        // and replace them by updated list of parameters
+        model->add_parameters(updated_parameters);
+    }
+
+    return at_least_one_decomposed;
+};
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
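With both passes in place, a TorchScript model whose signature contains tuples converts to a flat tensor-in/tensor-out ov::Model. A hedged end-to-end sketch, assuming openvino's Python `convert_model`/`compile_model` API accepts a scripted module with a tuple `example_input` the way the ovc changes below suggest (the `Pair` module and its shapes are illustrative, not part of the patch):

```python
from typing import Tuple
import torch
import openvino as ov

class Pair(torch.nn.Module):
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        x1, x2 = x
        return x1 + x2

pair = (torch.rand(2, 3), torch.rand(2, 3))
# DecomposeTupleParameters turns the single tuple input into two
# tensor Parameters, so the converted model is fed flat tensors.
converted = ov.convert_model(torch.jit.script(Pair()), example_input=(pair,))
compiled = ov.compile_model(converted, "CPU")
result = compiled([pair[0].numpy(), pair[1].numpy()])
```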
diff --git a/src/frontends/pytorch/src/transforms/prim_tuple_unpack_parameter_replacer.hpp b/src/frontends/pytorch/src/transforms/prim_tuple_unpack_parameter_replacer.hpp
new file mode 100644
index 00000000000..46007a5c12a
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/prim_tuple_unpack_parameter_replacer.hpp
@@ -0,0 +1,37 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+// This transformation replaces each prim::TupleUnpack operation coming after a Parameter with
+// new Parameters -- one new Parameter for each prim::TupleUnpack output. The original Parameter
+// is replaced with these new Parameters, preserving their position relative to the other
+// Parameters in the model. The order of the new Parameters matches the order of the
+// prim::TupleUnpack outputs. If a prim::TupleUnpack has a consumer that is also a
+// prim::TupleUnpack, the transformation applies the replacement recursively until all
+// prim::TupleUnpacks that take a Parameter output are eliminated.
+//
+// For example, if a model has the signature a, (b, (c, d)), e, where a, b, c, d, and e are
+// tensors and (x1, x2) denotes a tuple consisting of the two elements x1 and x2, then the model
+// resulting from the transformation will have a, b, c, d, e as inputs (flattened, without
+// tuples). Note that there is no dedicated 'tuple' input type: the tuple structure is recovered
+// purely by following the prim::TupleUnpack operations in the graph, assuming they are applied
+// to tuples only and that the most deeply nested objects in those tuples are tensors.
+class DecomposeTupleParameters : public ov::pass::ModelPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::DecomposeTupleParameters");
+    bool run_on_model(const std::shared_ptr<Model>& model) override;
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index ae30a7e3ed8..d69a5c33dba 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -50,8 +50,15 @@ class PytorchLayerTest:
         else:
             inputs = self._prepare_input()
 
-        torch_inputs = [torch.from_numpy(inp) if isinstance(
-            inp, np.ndarray) else inp for inp in inputs]
+        def numpy_to_torch_recursively(x):
+            if isinstance(x, tuple):
+                return tuple(numpy_to_torch_recursively(y) for y in x)
+            elif isinstance(x, np.ndarray):
+                return torch.from_numpy(x)
+            else:
+                return x
+
+        torch_inputs = [numpy_to_torch_recursively(inp) for inp in inputs]
 
         if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
             custom_eps = kwargs['custom_eps']
@@ -61,6 +68,8 @@ class PytorchLayerTest:
         def use_ts_backend():
             return(os.environ.get('USE_TS_BACKEND', False))
 
+        ov_inputs = flattenize_inputs(inputs)
+
         if use_ts_backend():
             self.ts_backend_test(model, torch_inputs, custom_eps)
         else:
@@ -68,7 +77,7 @@ class PytorchLayerTest:
             model.eval()
             trace_model = kwargs.get('trace_model', False)
             freeze_model = kwargs.get('freeze_model', True)
-            model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, inputs, freeze_model)
+            model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
             graph = model.inlined_graph
 
             if kind is not None and not isinstance(kind, (tuple, list)):
@@ -80,7 +89,7 @@ class PytorchLayerTest:
             # OV infer:
             core = Core()
             compiled = core.compile_model(converted_model, ie_device)
-            infer_res = compiled(deepcopy(inputs))
+            infer_res = compiled(deepcopy(ov_inputs))
 
             if hasattr(self, 'skip_framework') and self.skip_framework:
                 warnings.warn('Framework is skipped')
@@ -266,25 +275,33 @@ def get_params(ie_device=None, precision=None):
     return test_args
 
 
-def flattenize_dict_outputs(res):
+def flattenize_dict_outputs(res, types):
     if isinstance(res, dict):
-        return flattenize_outputs(res.values())
+        return flattenize(res.values(), types)
 
 
-def flattenize_outputs(res):
+def flattenize(res, types: list):
     results = []
     for res_item in res:
         # if None is at output we skip it
         if res_item is None:
             continue
         # If input is list or tuple flattenize it
-        if isinstance(res_item, (list, tuple)):
-            decomposed_res = flattenize_outputs(res_item)
+        if isinstance(res_item, (list, tuple)) and type(res_item) in types:
+            decomposed_res = flattenize(res_item, types)
             results.extend(decomposed_res)
             continue
-        if isinstance(res_item, dict):
-            decomposed_res = flattenize_dict_outputs(res_item)
+        if isinstance(res_item, dict) and type(res_item) in types:
+            decomposed_res = flattenize_dict_outputs(res_item, types)
             results.extend(decomposed_res)
             continue
         results.append(res_item)
     return results
+
+
+def flattenize_outputs(res):
+    return flattenize(res, [list, tuple, dict])
+
+
+def flattenize_inputs(res):
+    return flattenize(res, [tuple])
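The split between the two new wrappers matters: outputs are flattened across lists, tuples, and dicts, while inputs are flattened across tuples only, since a list or dict can remain a legitimate single input object. A quick illustration of the intended behavior, assuming the helpers are importable from `pytorch_layer_test_class` the same way the layer tests import `PytorchLayerTest`, and using strings as stand-ins for tensors:

```python
from pytorch_layer_test_class import flattenize_inputs, flattenize_outputs

# only the container structure matters here
print(flattenize_inputs(("a", ("b", "c"), ["d"])))
# -> ['a', 'b', 'c', ['d']]    (the list stays a single input)
print(flattenize_outputs(("a", ("b", "c"), ["d"], {"k": "e"}, None)))
# -> ['a', 'b', 'c', 'd', 'e'] (lists, tuples and dicts flattened; None dropped)
```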
diff --git a/tests/layer_tests/pytorch_tests/test_tuple_construct.py b/tests/layer_tests/pytorch_tests/test_tuple_construct.py
index 9e782079965..a8bd03731c6 100644
--- a/tests/layer_tests/pytorch_tests/test_tuple_construct.py
+++ b/tests/layer_tests/pytorch_tests/test_tuple_construct.py
@@ -33,6 +33,11 @@ class TestTupleConstruct(PytorchLayerTest):
         def forward(self, x):
             return (x, [None, x + x], None)
 
+    class prim_tuple_construct_with_tensor_tail(torch.nn.Module):
+
+        def forward(self, x):
+            return ((x, x + x), x + x + x)
+
     class prim_tuple_construct_with_list_and_tuple(torch.nn.Module):
 
         def forward(self, x):
@@ -43,6 +48,7 @@ class TestTupleConstruct(PytorchLayerTest):
             "multiple": prim_tuple_construct,
             "none": prim_tuple_construct_with_none,
             "list": prim_tuple_construct_with_list,
+            "tensor_tail": prim_tuple_construct_with_tensor_tail,
             "list_and_tuple": prim_tuple_construct_with_list_and_tuple
         }
 
@@ -51,11 +57,11 @@ class TestTupleConstruct(PytorchLayerTest):
 
         return model(), ref_net, "prim::TupleConstruct"
 
-    @pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "list_and_tuple"])
+    @pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "tensor_tail", "list_and_tuple"])
     @pytest.mark.nightly
     def test_tuple_construct(self, case, ie_device, precision, ir_version):
         self._test(*self.create_model(case), ie_device, precision, ir_version)
-    
+
 
 class TestTupleConstructTupleUnpack(PytorchLayerTest):
     def _prepare_input(self):
@@ -69,7 +75,7 @@ class TestTupleConstructTupleUnpack(PytorchLayerTest):
         def forward(self, x):
             x1, x2, x3, x4, x5 = self.prepare_input(x)
             return x1, x2, x3, x4, x5
-    
+
     def prepare_input(self, x):
         return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)
 
@@ -80,4 +86,106 @@ class TestTupleConstructTupleUnpack(PytorchLayerTest):
 
     @pytest.mark.nightly
     def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
-        self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
\ No newline at end of file
+        self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
+
+
+class TestTupleUnpackParameterSingle(PytorchLayerTest):
+    def _prepare_input(self):
+        def tensor_gen():
+            return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
+        return ((tensor_gen(), tensor_gen()),)
+
+    def create_model(self):
+        import torch
+        from typing import Tuple
+
+        class model(torch.nn.Module):
+
+            def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
+                x1, x2 = x
+                return x1, x2
+
+        return model(), None, ["prim::TupleUnpack"]
+
+    @pytest.mark.nightly
+    def test(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
+
+
+class TestTupleUnpackParameterSingleMixed(PytorchLayerTest):
+    def _prepare_input(self):
+        def tensor_gen():
+            return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
+
+        # generate a tensor with a different shape for easier mismatch detection in case of mixed input order
+        def tensor_gen_2():
+            return np.random.uniform(0, 50, (2, 3)).astype(np.float32)
+        return (tensor_gen_2(), (tensor_gen(), tensor_gen()), tensor_gen_2())
+
+    def create_model(self):
+        import torch
+        from typing import Tuple
+
+        class model(torch.nn.Module):
+
+            def forward(self, y1, x: Tuple[torch.Tensor, torch.Tensor], y2):
+                x1, x2 = x
+                return x1, x2, y1, y2
+
+        return model(), None, ["prim::TupleUnpack"]
+
+    @pytest.mark.nightly
+    def test(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
+
+
+class TestTupleUnpackParameterNested(PytorchLayerTest):
+    def _prepare_input(self):
+        def tensor_gen():
+            return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
+        return (((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())),)
+
+    def create_model(self):
+        import torch
+        from typing import Tuple
+
+        class model(torch.nn.Module):
+
+            def forward(self, x: Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]):
+                x1, x2 = x
+                y1, y2 = x1
+                y3, y4 = x2
+                return y1, y2, y3, y4
+
+        return model(), None, ["prim::TupleUnpack"]
+
+    @pytest.mark.nightly
+    def test(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
+
+
+class TestTupleUnpackParameterMultiple(PytorchLayerTest):
+    def _prepare_input(self):
+        def tensor_gen():
+            return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
+        return ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()))
+
+    def create_model(self):
+        import torch
+        from typing import Tuple
+
+        class model(torch.nn.Module):
+
+            def forward(self, x: Tuple[torch.Tensor, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor]):
+                z1, z2 = x
+                z3, z4 = y
+                return z1, z2, z3, z4
+
+        return model(), None, ["prim::TupleUnpack"]
+
+    @pytest.mark.nightly
+    def test(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
diff --git a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py
index ae9973ebab8..fd7c3ee7384 100644
--- a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py
+++ b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py
@@ -43,7 +43,7 @@ def update_list_or_dict(container, name, idx, value):
         container[idx] = value
         return
-    
+
 
 def get_value_from_list_or_dict(container, name, idx):
     if isinstance(container, dict):
         if name is None:
@@ -87,8 +87,8 @@ def extract_input_info_from_example(args, inputs):
             example_dtype = pt_to_ov_type_map.get(str(dtype))
             user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
             if user_dtype is not None and example_dtype.to_dtype() != user_dtype:
-                raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
-    
+                raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
+
             data_rank = getattr(example_input, "ndim", 0)
             user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
             if user_input_shape.rank.get_length() != data_rank:
@@ -108,7 +108,7 @@ def extract_input_info_from_example(args, inputs):
             input_name = input_names[input_id] if input_names else None
         update_list_or_dict(input_shapes, input_name, input_id, input_shape)
         update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
-    
+
     args.placeholder_data_types = data_types
     args.placeholder_shapes = input_shapes
     if not args.input and input_names:
@@ -126,6 +126,9 @@ def to_torch_tensor(tensor):
         return torch.tensor(tensor.data)
     if isinstance(tensor, (float, int, bool)):
         return tensor
+    if isinstance(tensor, tuple):
+        # TODO: rename to_torch_tensor, since it handles more than just tensors
+        return tuple(to_torch_tensor(x) for x in tensor)
     else:
         raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                     "Got {}".format(type(tensor)))
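The recursive branch added to to_torch_tensor is what lets example_input contain nested tuples on the ovc side. A standalone sketch of the same recursion; the name `to_torch_recursively` is made up for illustration and is not part of the patch:

```python
import numpy as np
import torch

def to_torch_recursively(value):
    """Mirror of the recursion added to to_torch_tensor: tuples are
    converted element-wise, numpy arrays are wrapped as torch tensors,
    and plain scalars pass through unchanged."""
    if isinstance(value, tuple):
        return tuple(to_torch_recursively(v) for v in value)
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value)
    if isinstance(value, (float, int, bool)):
        return value
    raise TypeError(f"Unexpected example_input element: {type(value)}")

example = (np.zeros((2, 3), dtype=np.float32),
           (np.ones((2, 3), dtype=np.float32), 1.5))
converted = to_torch_recursively(example)
print(type(converted[1][0]))  # <class 'torch.Tensor'>
```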