diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index b4b53e6ce3a..e4b429cad6c 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -15,6 +15,7 @@ #include "transforms/append_list_unpack_replacer.hpp" #include "transforms/aten_cat_replacer.hpp" #include "transforms/aten_getitem_replacer.hpp" +#include "transforms/aten_stack_list_construct_replacer.hpp" #include "transforms/listconstruct_reshape_replacer.hpp" #include "transforms/max_prim_list_construct_replacer.hpp" #include "transforms/prim_list_construct_pad.hpp" @@ -86,6 +87,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.cpp new file mode 100644 index 00000000000..241110bc612 --- /dev/null +++ b/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.cpp @@ -0,0 +1,66 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "aten_stack_list_construct_replacer.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/op/util/framework_node.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "utils.hpp" + +using namespace ov::pass::pattern; + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +AtenStackListConstructReplacer::AtenStackListConstructReplacer() { + auto list_construct = ov::pass::pattern::wrap_type(); + auto axis = ov::pass::pattern::wrap_type(); + + // We search for a pattern: ListConstruct -> aten::stack <- Constant + auto stack = ov::pass::pattern::wrap_type({list_construct, axis}); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + auto stack = cast_fw_node(m.get_match_root(), "aten::stack"); + if (!stack) { + return false; + } + const auto& pattern_map = m.get_pattern_value_map(); + auto input_node = pattern_map.at(list_construct).get_node_shared_ptr(); + auto axis_node = pattern_map.at(axis).get_node_shared_ptr(); + auto axis_const = std::dynamic_pointer_cast(axis_node); + auto axis = axis_const->cast_vector(); + // Check if ListConstruct is an input + if (auto list_construct_node = cast_fw_node(input_node, "prim::ListConstruct")) { + const auto& list_inputs = list_construct_node->input_values(); + OutputVector node_vector; + auto zero = opset10::Constant::create(element::i32, Shape{}, {0}); + // Iterate over values in ListConstruct + for (const auto& list_input : list_inputs) { + auto node = concat_list_construct(list_input.get_node_shared_ptr()); + auto unsqueezed_node = std::make_shared(node, axis_const); + node_vector.push_back(unsqueezed_node); + } + // Concat vectors on provided axis + auto concat = std::make_shared(node_vector, axis[0]); + + copy_runtime_info({stack, input_node}, concat); + replace_node(stack, concat); + return true; + } + return false; + }; + + auto m = std::make_shared(stack, "ov::frontend::pytorch::pass::AtenStackListConstructReplacer"); + this->register_matcher(m, callback); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.hpp b/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.hpp new file mode 100644 index 00000000000..4ac808c817c --- /dev/null +++ b/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +class AtenStackListConstructReplacer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenStackListConstructReplacer"); + AtenStackListConstructReplacer(); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/tests/layer_tests/pytorch_tests/test_stack.py b/tests/layer_tests/pytorch_tests/test_stack.py new file mode 100644 index 00000000000..670033c7b29 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_stack.py @@ -0,0 +1,77 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestStack2D(PytorchLayerTest): + def _prepare_input(self): + return self.input_tensors + + def create_model(self, dim): + import torch + + class aten_stack(torch.nn.Module): + def __init__(self, dim): + super(aten_stack, self).__init__() + self.dim = dim + + def forward(self, x, y): + inputs = [x, y] + return torch.stack(inputs, self.dim) + + ref_net = None + + return aten_stack(dim), ref_net, "aten::stack" + + @pytest.mark.parametrize("input_tensor", ([ + [np.random.rand(1, 3, 3), np.random.rand(1, 3, 3)], + [np.random.rand(4, 4, 2), np.random.rand(4, 4, 2)], + [np.random.rand(8, 1, 1, 9), np.random.rand(8, 1, 1, 9)] + ])) + @pytest.mark.parametrize("dim", ([ + 0, 1, 2, + ])) + @pytest.mark.nightly + @pytest.mark.precommit + def test_stack2D(self, input_tensor, dim, ie_device, precision, ir_version): + self.input_tensors = input_tensor + self._test(*self.create_model(dim), ie_device, precision, ir_version) + + +class TestStack3D(PytorchLayerTest): + def _prepare_input(self): + return self.input_tensors + + def create_model(self, dim): + import torch + + class aten_stack(torch.nn.Module): + def __init__(self, dim): + super(aten_stack, self).__init__() + self.dim = dim + + def forward(self, x, y, z): + inputs = [x, y, z] + return torch.stack(inputs, self.dim) + + ref_net = None + + return aten_stack(dim), ref_net, "aten::stack" + + @pytest.mark.parametrize("input_tensor", ([ + [np.random.rand(1, 3, 3), np.random.rand(1, 3, 3), np.random.rand(1, 3, 3)], + [np.random.rand(4, 4, 2), np.random.rand(4, 4, 2), np.random.rand(4, 4, 2)], + [np.random.rand(8, 1, 1, 9), np.random.rand(8, 1, 1, 9), np.random.rand(8, 1, 1, 9)] + ])) + @pytest.mark.parametrize("dim", ([ + 0, 1, 2, + ])) + @pytest.mark.nightly + @pytest.mark.precommit + def test_stack3D(self, input_tensor, dim, ie_device, precision, ir_version): + self.input_tensors = input_tensor + self._test(*self.create_model(dim), ie_device, precision, ir_version)