[PT FE] aten::concat (#19101)

* [PT FE] aten::concat

* Update src/frontends/pytorch/src/op/cat.cpp

* add out

* fix tests
This commit is contained in:
Ekaterina Aidova 2023-08-11 14:49:18 +03:00 committed by GitHub
parent a957764362
commit 9fd8a13fe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 27 deletions

View File

@ -41,10 +41,14 @@ OutputVector translate_cat_common(const NodeContext& context,
OutputVector translate_cat(const NodeContext& context) {
// This translator is only needed to get axis as constant from external scope
num_inputs_check(context, 2, 2);
num_inputs_check(context, 2, 3);
const auto&& list_elems = get_list_as_outputs(context.get_input(0));
auto axis = context.const_input<int64_t>(1);
return translate_cat_common(context, list_elems, axis);
auto out = translate_cat_common(context, list_elems, axis);
if (!context.input_is_none(2)) {
context.mutate_input(2, out[0]);
}
return out;
};
OutputVector translate_cat_fx(const NodeContext& context) {

View File

@ -248,6 +248,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::Bool", op::translate_bool},
{"aten::cat", op::translate_cat},
{"aten::concat", op::translate_cat},
{"aten::cdist", op::translate_cdist},
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},

View File

@ -37,6 +37,8 @@ AtenCatToConcat::AtenCatToConcat() {
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
auto cat = cast_fw_node(m.get_match_root(), "aten::cat");
if (!cat)
cat = cast_fw_node(m.get_match_root(), "aten::concat");
if (!cat)
cat = cast_fw_node(m.get_match_root(), "quantized::cat");
if (!cat)

View File

@ -7,60 +7,100 @@ import torch
from pytorch_layer_test_class import PytorchLayerTest
class aten_cat(torch.nn.Module):
def forward(self, x):
return torch.cat([x, x], 1)
return torch.cat(self.prepare_input(x), 1)
def prepare_input(self, x):
return [x, x]
class aten_append_cat(torch.nn.Module):
def forward(self, x):
class aten_cat_out(aten_cat):
def forward(self, x, out):
return torch.cat(self.prepare_input(x), 1, out=out), out
class aten_append_cat(aten_cat):
def prepare_input(self, x):
list = []
list.append(x)
list.append(x)
return torch.cat(list, 1)
return list
class aten_append_cat_out(aten_cat_out):
def prepare_input(self, x):
list = []
list.append(x)
list.append(x)
return list
class aten_loop_append_cat(torch.nn.Module):
def forward(self, x):
class aten_loop_append_cat(aten_cat):
def prepare_input(self, x):
list = []
for i in range(3):
list.append(x)
return torch.cat(list, 1)
return list
class aten_add_cat(torch.nn.Module):
class aten_loop_append_cat_out(aten_cat_out):
def prepare_input(self, x):
list = []
for i in range(3):
list.append(x)
return list
class aten_add_cat(aten_cat):
def forward(self, x):
list = [x, x]
list2 = list + [x, x]
return torch.cat(list2, 1)
list1 = self.prepare_input(x)
list2 = self.prepare_input(x)
return torch.cat(list1 + list2, dim=1)
class aten_add_cat_out(aten_cat_out):
def forward(self, x, out):
list1 = self.prepare_input(x)
list2 = self.prepare_input(x)
return torch.cat(list1 + list2, dim=1, out=out)
class TestCat(PytorchLayerTest):
def _prepare_input(self):
def _prepare_input(self, out=False, num_repeats=2):
import numpy as np
return (np.random.randn(2, 1, 3),)
data = np.random.randn(2, 1, 3)
if not out:
return (data, )
concat = [data for _ in range(num_repeats)]
out = np.zeros_like(np.concatenate(concat, axis=1))
return (data, out)
@pytest.mark.nightly
@pytest.mark.precommit
def test_cat(self, ie_device, precision, ir_version):
self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
ie_device, precision, ir_version)
@pytest.mark.parametrize("out", [False, True])
def test_cat(self, out, ie_device, precision, ir_version):
model = aten_cat() if not out else aten_cat_out()
self._test(model, None, ["aten::cat", "prim::ListConstruct"],
ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
@pytest.mark.nightly
@pytest.mark.precommit
def test_append_cat(self, ie_device, precision, ir_version):
self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
ie_device, precision, ir_version)
@pytest.mark.parametrize("out", [False, True])
def test_append_cat(self, out, ie_device, precision, ir_version):
model = aten_append_cat() if not out else aten_append_cat_out()
self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct"],
ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.xfail(reason="Transformation RemoveMultiSubGraphOpDanglingParamsResults doesn't support removing unused merged inputs, ticket 112833.")
def test_loop_append_cat(self, ie_device, precision, ir_version):
self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
ie_device, precision, ir_version, freeze_model=False)
@pytest.mark.parametrize("out", [False, True])
def test_loop_append_cat(self, out, ie_device, precision, ir_version):
model = aten_loop_append_cat() if not out else aten_loop_append_cat_out()
self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 3})
@pytest.mark.nightly
@pytest.mark.precommit
def test_add_cat(self, ie_device, precision, ir_version):
self._test(aten_add_cat(), None, ["aten::cat", "aten::add", "prim::ListConstruct"],
ie_device, precision, ir_version, freeze_model=False)
@pytest.mark.parametrize("out", [False, True])
def test_add_cat(self, out, ie_device, precision, ir_version):
model = aten_add_cat() if not out else aten_add_cat_out()
self._test(model, None, ["aten::cat", "aten::add", "prim::ListConstruct"],
ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 4})