[PT FE] aten::concat (#19101)
* [PT FE] aten::concat * Update src/frontends/pytorch/src/op/cat.cpp * add out * fix tests
This commit is contained in:
parent
a957764362
commit
9fd8a13fe6
@ -41,10 +41,14 @@ OutputVector translate_cat_common(const NodeContext& context,
|
||||
|
||||
OutputVector translate_cat(const NodeContext& context) {
|
||||
// This translator is only needed to get axis as constant from external scope
|
||||
num_inputs_check(context, 2, 2);
|
||||
num_inputs_check(context, 2, 3);
|
||||
const auto&& list_elems = get_list_as_outputs(context.get_input(0));
|
||||
auto axis = context.const_input<int64_t>(1);
|
||||
return translate_cat_common(context, list_elems, axis);
|
||||
auto out = translate_cat_common(context, list_elems, axis);
|
||||
if (!context.input_is_none(2)) {
|
||||
context.mutate_input(2, out[0]);
|
||||
}
|
||||
return out;
|
||||
};
|
||||
|
||||
OutputVector translate_cat_fx(const NodeContext& context) {
|
||||
|
@ -248,6 +248,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::Bool", op::translate_bool},
|
||||
{"aten::cat", op::translate_cat},
|
||||
{"aten::concat", op::translate_cat},
|
||||
{"aten::cdist", op::translate_cdist},
|
||||
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
|
||||
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
|
||||
|
@ -37,6 +37,8 @@ AtenCatToConcat::AtenCatToConcat() {
|
||||
|
||||
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
|
||||
auto cat = cast_fw_node(m.get_match_root(), "aten::cat");
|
||||
if (!cat)
|
||||
cat = cast_fw_node(m.get_match_root(), "aten::concat");
|
||||
if (!cat)
|
||||
cat = cast_fw_node(m.get_match_root(), "quantized::cat");
|
||||
if (!cat)
|
||||
|
@ -7,60 +7,100 @@ import torch
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
|
||||
class aten_cat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.cat([x, x], 1)
|
||||
return torch.cat(self.prepare_input(x), 1)
|
||||
|
||||
def prepare_input(self, x):
|
||||
return [x, x]
|
||||
|
||||
class aten_append_cat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
class aten_cat_out(aten_cat):
|
||||
def forward(self, x, out):
|
||||
return torch.cat(self.prepare_input(x), 1, out=out), out
|
||||
|
||||
class aten_append_cat(aten_cat):
|
||||
def prepare_input(self, x):
|
||||
list = []
|
||||
list.append(x)
|
||||
list.append(x)
|
||||
return torch.cat(list, 1)
|
||||
return list
|
||||
|
||||
class aten_append_cat_out(aten_cat_out):
|
||||
def prepare_input(self, x):
|
||||
list = []
|
||||
list.append(x)
|
||||
list.append(x)
|
||||
return list
|
||||
|
||||
class aten_loop_append_cat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
class aten_loop_append_cat(aten_cat):
|
||||
def prepare_input(self, x):
|
||||
list = []
|
||||
for i in range(3):
|
||||
list.append(x)
|
||||
return torch.cat(list, 1)
|
||||
return list
|
||||
|
||||
|
||||
class aten_add_cat(torch.nn.Module):
|
||||
class aten_loop_append_cat_out(aten_cat_out):
|
||||
def prepare_input(self, x):
|
||||
list = []
|
||||
for i in range(3):
|
||||
list.append(x)
|
||||
return list
|
||||
|
||||
class aten_add_cat(aten_cat):
|
||||
def forward(self, x):
|
||||
list = [x, x]
|
||||
list2 = list + [x, x]
|
||||
return torch.cat(list2, 1)
|
||||
list1 = self.prepare_input(x)
|
||||
list2 = self.prepare_input(x)
|
||||
return torch.cat(list1 + list2, dim=1)
|
||||
|
||||
|
||||
class aten_add_cat_out(aten_cat_out):
|
||||
def forward(self, x, out):
|
||||
list1 = self.prepare_input(x)
|
||||
list2 = self.prepare_input(x)
|
||||
return torch.cat(list1 + list2, dim=1, out=out)
|
||||
|
||||
class TestCat(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
def _prepare_input(self, out=False, num_repeats=2):
|
||||
import numpy as np
|
||||
return (np.random.randn(2, 1, 3),)
|
||||
data = np.random.randn(2, 1, 3)
|
||||
if not out:
|
||||
return (data, )
|
||||
concat = [data for _ in range(num_repeats)]
|
||||
out = np.zeros_like(np.concatenate(concat, axis=1))
|
||||
return (data, out)
|
||||
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_cat(self, ie_device, precision, ir_version):
|
||||
self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version)
|
||||
@pytest.mark.parametrize("out", [False, True])
|
||||
def test_cat(self, out, ie_device, precision, ir_version):
|
||||
model = aten_cat() if not out else aten_cat_out()
|
||||
self._test(model, None, ["aten::cat", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_append_cat(self, ie_device, precision, ir_version):
|
||||
self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version)
|
||||
@pytest.mark.parametrize("out", [False, True])
|
||||
def test_append_cat(self, out, ie_device, precision, ir_version):
|
||||
model = aten_append_cat() if not out else aten_append_cat_out()
|
||||
self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.xfail(reason="Transformation RemoveMultiSubGraphOpDanglingParamsResults doesn't support removing unused merged inputs, ticket 112833.")
|
||||
def test_loop_append_cat(self, ie_device, precision, ir_version):
|
||||
self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
|
||||
ie_device, precision, ir_version, freeze_model=False)
|
||||
@pytest.mark.parametrize("out", [False, True])
|
||||
def test_loop_append_cat(self, out, ie_device, precision, ir_version):
|
||||
model = aten_loop_append_cat() if not out else aten_loop_append_cat_out()
|
||||
self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
|
||||
ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 3})
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_add_cat(self, ie_device, precision, ir_version):
|
||||
self._test(aten_add_cat(), None, ["aten::cat", "aten::add", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version, freeze_model=False)
|
||||
@pytest.mark.parametrize("out", [False, True])
|
||||
def test_add_cat(self, out, ie_device, precision, ir_version):
|
||||
model = aten_add_cat() if not out else aten_add_cat_out()
|
||||
self._test(model, None, ["aten::cat", "aten::add", "prim::ListConstruct"],
|
||||
ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 4})
|
||||
|
Loading…
Reference in New Issue
Block a user