From 9fd8a13fe6bea701edf99c49b9dc06dd6ace2405 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Fri, 11 Aug 2023 14:49:18 +0300
Subject: [PATCH] [PT FE] aten::concat (#19101)

* [PT FE] aten::concat

* Update src/frontends/pytorch/src/op/cat.cpp

* add out

* fix tests
---
 src/frontends/pytorch/src/op/cat.cpp        |  8 +-
 src/frontends/pytorch/src/op_table.cpp      |  1 +
 .../src/transforms/aten_cat_replacer.cpp    |  2 +
 tests/layer_tests/pytorch_tests/test_cat.py | 90 +++++++++++++------
 4 files changed, 74 insertions(+), 27 deletions(-)

diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp
index 76b5a542cf4..7c2a43f0c38 100644
--- a/src/frontends/pytorch/src/op/cat.cpp
+++ b/src/frontends/pytorch/src/op/cat.cpp
@@ -41,10 +41,14 @@ OutputVector translate_cat_common(const NodeContext& context,
 
 OutputVector translate_cat(const NodeContext& context) {
     // This translator is only needed to get axis as constant from external scope
-    num_inputs_check(context, 2, 2);
+    num_inputs_check(context, 2, 3);
     const auto&& list_elems = get_list_as_outputs(context.get_input(0));
     auto axis = context.const_input<int64_t>(1);
-    return translate_cat_common(context, list_elems, axis);
+    auto out = translate_cat_common(context, list_elems, axis);
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, out[0]);
+    }
+    return out;
 };
 
 OutputVector translate_cat_fx(const NodeContext& context) {
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index f3815d17369..6ba7e70ab73 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -248,6 +248,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten::Bool", op::translate_bool},
         {"aten::cat", op::translate_cat},
+        {"aten::concat", op::translate_cat},
         {"aten::cdist", op::translate_cdist},
         {"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
         {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
index 79296bcd6c5..fe0c828a33c 100644
--- a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
@@ -37,6 +37,8 @@ AtenCatToConcat::AtenCatToConcat() {
 
     ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
         auto cat = cast_fw_node(m.get_match_root(), "aten::cat");
+        if (!cat)
+            cat = cast_fw_node(m.get_match_root(), "aten::concat");
         if (!cat)
             cat = cast_fw_node(m.get_match_root(), "quantized::cat");
         if (!cat)
diff --git a/tests/layer_tests/pytorch_tests/test_cat.py b/tests/layer_tests/pytorch_tests/test_cat.py
index b1d3fcef5ea..7d590336ad1 100644
--- a/tests/layer_tests/pytorch_tests/test_cat.py
+++ b/tests/layer_tests/pytorch_tests/test_cat.py
@@ -7,60 +7,100 @@
 import torch
 from pytorch_layer_test_class import PytorchLayerTest
 
+
 class aten_cat(torch.nn.Module):
     def forward(self, x):
-        return torch.cat([x, x], 1)
+        return torch.cat(self.prepare_input(x), 1)
 
+    def prepare_input(self, x):
+        return [x, x]
 
-class aten_append_cat(torch.nn.Module):
-    def forward(self, x):
+class aten_cat_out(aten_cat):
+    def forward(self, x, out):
+        return torch.cat(self.prepare_input(x), 1, out=out), out
+
+class aten_append_cat(aten_cat):
+    def prepare_input(self, x):
         list = []
         list.append(x)
         list.append(x)
-        return torch.cat(list, 1)
+        return list
 
+class aten_append_cat_out(aten_cat_out):
+    def prepare_input(self, x):
+        list = []
+        list.append(x)
+        list.append(x)
+        return list
 
-class aten_loop_append_cat(torch.nn.Module):
-    def forward(self, x):
+class aten_loop_append_cat(aten_cat):
+    def prepare_input(self, x):
         list = []
         for i in range(3):
             list.append(x)
-        return torch.cat(list, 1)
+        return list
 
+class aten_loop_append_cat_out(aten_cat_out):
+    def prepare_input(self, x):
+        list = []
+        for i in range(3):
+            list.append(x)
+        return list
 
-class aten_add_cat(torch.nn.Module):
+class aten_add_cat(aten_cat):
     def forward(self, x):
-        list = [x, x]
-        list2 = list + [x, x]
-        return torch.cat(list2, 1)
+        list1 = self.prepare_input(x)
+        list2 = self.prepare_input(x)
+        return torch.cat(list1 + list2, dim=1)
 
+class aten_add_cat_out(aten_cat_out):
+    def forward(self, x, out):
+        list1 = self.prepare_input(x)
+        list2 = self.prepare_input(x)
+        return torch.cat(list1 + list2, dim=1, out=out)
+
 class TestCat(PytorchLayerTest):
-    def _prepare_input(self):
+    def _prepare_input(self, out=False, num_repeats=2):
         import numpy as np
-        return (np.random.randn(2, 1, 3),)
+        data = np.random.randn(2, 1, 3)
+        if not out:
+            return (data, )
+        concat = [data for _ in range(num_repeats)]
+        out = np.zeros_like(np.concatenate(concat, axis=1))
+        return (data, out)
 
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_cat(self, ie_device, precision, ir_version):
-        self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
-                   ie_device, precision, ir_version)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_cat(self, out, ie_device, precision, ir_version):
+        model = aten_cat() if not out else aten_cat_out()
+        self._test(model, None, ["aten::cat", "prim::ListConstruct"],
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_append_cat(self, ie_device, precision, ir_version):
-        self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
-                   ie_device, precision, ir_version)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_append_cat(self, out, ie_device, precision, ir_version):
+        model = aten_append_cat() if not out else aten_append_cat_out()
+        self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct"],
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.xfail(reason="Transformation RemoveMultiSubGraphOpDanglingParamsResults doesn't support removing unused merged inputs, ticket 112833.")
-    def test_loop_append_cat(self, ie_device, precision, ir_version):
-        self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
-                   ie_device, precision, ir_version, freeze_model=False)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_loop_append_cat(self, out, ie_device, precision, ir_version):
+        model = aten_loop_append_cat() if not out else aten_loop_append_cat_out()
+        self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
+                   ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 3})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_add_cat(self, ie_device, precision, ir_version):
-        self._test(aten_add_cat(), None, ["aten::cat", "aten::add", "prim::ListConstruct"],
-                   ie_device, precision, ir_version, freeze_model=False)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_add_cat(self, out, ie_device, precision, ir_version):
+        model = aten_add_cat() if not out else aten_add_cat_out()
+        self._test(model, None, ["aten::cat", "aten::add", "prim::ListConstruct"],
+                   ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 4})
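
For context beyond the diff: torch.concat is a Python-level alias of torch.cat, but TorchScript records it as a distinct aten::concat op, which is why both the translator table and the AtenCatToConcat pass need the extra entry above. The sketch below is illustrative rather than taken from the patch (the module name and tensor shapes are made up); it shows the kind of scripted graph the change makes convertible, including the optional `out` argument that translate_cat now routes through context.mutate_input:

    import torch

    class ConcatOutModel(torch.nn.Module):  # illustrative module, not from the patch
        def forward(self, x, out):
            # scripting records this call as aten::concat, not aten::cat
            return torch.concat([x, x], 1, out=out)

    scripted = torch.jit.script(ConcatOutModel())
    print(scripted.graph)  # expected to contain aten::concat

    x = torch.randn(2, 1, 3)
    out = torch.zeros(2, 2, 3)  # two (2, 1, 3) tensors joined along dim 1
    scripted(x, out)  # after the call, `out` holds the concatenation

Once such a graph reaches the frontend, aten::concat resolves to the same translate_cat translator as aten::cat, and AtenCatToConcat later folds the framework node into a single OpenVINO Concat.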