[PT FE] aten::concat (#19101)
* [PT FE] aten::concat
* Update src/frontends/pytorch/src/op/cat.cpp
* add out
* fix tests
parent a957764362, commit 9fd8a13fe6
@@ -41,10 +41,14 @@ OutputVector translate_cat_common(const NodeContext& context,
 
 OutputVector translate_cat(const NodeContext& context) {
     // This translator is only needed to get axis as constant from external scope
-    num_inputs_check(context, 2, 2);
+    num_inputs_check(context, 2, 3);
     const auto&& list_elems = get_list_as_outputs(context.get_input(0));
     auto axis = context.const_input<int64_t>(1);
-    return translate_cat_common(context, list_elems, axis);
+    auto out = translate_cat_common(context, list_elems, axis);
+    if (!context.input_is_none(2)) {
+        context.mutate_input(2, out[0]);
+    }
+    return out;
 };
 
 OutputVector translate_cat_fx(const NodeContext& context) {
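For context, a minimal PyTorch-side sketch of the out= form that the extra optional input above is meant to cover (shapes are illustrative and chosen to mirror the tests, not taken from this patch):

    import torch

    x = torch.randn(2, 1, 3)
    out = torch.empty(2, 2, 3)           # destination must already have the result's shape and dtype
    torch.cat([x, x], dim=1, out=out)    # writes the concatenation into `out` in place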
@@ -248,6 +248,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten::Bool", op::translate_bool},
         {"aten::cat", op::translate_cat},
+        {"aten::concat", op::translate_cat},
         {"aten::cdist", op::translate_cdist},
         {"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
         {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
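A small repro sketch for the new table entry, assuming torch.concat (the alias of torch.cat) is what TorchScript records as aten::concat; the module name M is made up for illustration:

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.concat([x, x], 1)   # alias of torch.cat

    scripted = torch.jit.script(M())
    print(scripted.graph)                    # expect an aten::concat node in the printed graph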
@@ -37,6 +37,8 @@ AtenCatToConcat::AtenCatToConcat() {
 
     ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
         auto cat = cast_fw_node(m.get_match_root(), "aten::cat");
+        if (!cat)
+            cat = cast_fw_node(m.get_match_root(), "aten::concat");
         if (!cat)
             cat = cast_fw_node(m.get_match_root(), "quantized::cat");
         if (!cat)
@@ -7,60 +7,100 @@ import torch
 from pytorch_layer_test_class import PytorchLayerTest
 
 
 class aten_cat(torch.nn.Module):
     def forward(self, x):
-        return torch.cat([x, x], 1)
+        return torch.cat(self.prepare_input(x), 1)
+
+    def prepare_input(self, x):
+        return [x, x]
 
 
-class aten_append_cat(torch.nn.Module):
-    def forward(self, x):
+class aten_cat_out(aten_cat):
+    def forward(self, x, out):
+        return torch.cat(self.prepare_input(x), 1, out=out), out
+
+
+class aten_append_cat(aten_cat):
+    def prepare_input(self, x):
         list = []
         list.append(x)
         list.append(x)
-        return torch.cat(list, 1)
+        return list
+
+
+class aten_append_cat_out(aten_cat_out):
+    def prepare_input(self, x):
+        list = []
+        list.append(x)
+        list.append(x)
+        return list
 
 
-class aten_loop_append_cat(torch.nn.Module):
-    def forward(self, x):
+class aten_loop_append_cat(aten_cat):
+    def prepare_input(self, x):
         list = []
         for i in range(3):
             list.append(x)
-        return torch.cat(list, 1)
+        return list
+
+
+class aten_loop_append_cat_out(aten_cat_out):
+    def prepare_input(self, x):
+        list = []
+        for i in range(3):
+            list.append(x)
+        return list
 
 
-class aten_add_cat(torch.nn.Module):
+class aten_add_cat(aten_cat):
     def forward(self, x):
-        list = [x, x]
-        list2 = list + [x, x]
-        return torch.cat(list2, 1)
+        list1 = self.prepare_input(x)
+        list2 = self.prepare_input(x)
+        return torch.cat(list1 + list2, dim=1)
+
+
+class aten_add_cat_out(aten_cat_out):
+    def forward(self, x, out):
+        list1 = self.prepare_input(x)
+        list2 = self.prepare_input(x)
+        return torch.cat(list1 + list2, dim=1, out=out)
 
 
 class TestCat(PytorchLayerTest):
-    def _prepare_input(self):
+    def _prepare_input(self, out=False, num_repeats=2):
         import numpy as np
-        return (np.random.randn(2, 1, 3),)
+        data = np.random.randn(2, 1, 3)
+        if not out:
+            return (data, )
+        concat = [data for _ in range(num_repeats)]
+        out = np.zeros_like(np.concatenate(concat, axis=1))
+        return (data, out)
 
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_cat(self, ie_device, precision, ir_version):
-        self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
-                   ie_device, precision, ir_version)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_cat(self, out, ie_device, precision, ir_version):
+        model = aten_cat() if not out else aten_cat_out()
+        self._test(model, None, ["aten::cat", "prim::ListConstruct"],
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_append_cat(self, ie_device, precision, ir_version):
-        self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
-                   ie_device, precision, ir_version)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_append_cat(self, out, ie_device, precision, ir_version):
+        model = aten_append_cat() if not out else aten_append_cat_out()
+        self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct"],
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.xfail(reason="Transformation RemoveMultiSubGraphOpDanglingParamsResults doesn't support removing unused merged inputs, ticket 112833.")
-    def test_loop_append_cat(self, ie_device, precision, ir_version):
-        self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
-                   ie_device, precision, ir_version, freeze_model=False)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_loop_append_cat(self, out, ie_device, precision, ir_version):
+        model = aten_loop_append_cat() if not out else aten_loop_append_cat_out()
+        self._test(model, None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
+                   ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 3})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_add_cat(self, ie_device, precision, ir_version):
-        self._test(aten_add_cat(), None, ["aten::cat", "aten::add", "prim::ListConstruct"],
-                   ie_device, precision, ir_version, freeze_model=False)
+    @pytest.mark.parametrize("out", [False, True])
+    def test_add_cat(self, out, ie_device, precision, ir_version):
+        model = aten_add_cat() if not out else aten_add_cat_out()
+        self._test(model, None, ["aten::cat", "aten::add", "prim::ListConstruct"],
+                   ie_device, precision, ir_version, freeze_model=False, kwargs_to_prepare_input={"out": out, "num_repeats": 4})
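A quick standalone check (not part of the patch) of why the num_repeats values passed above differ per test: the out= tensor must match the concatenated result, and for the (2, 1, 3) input used in _prepare_input the shapes work out as follows:

    import numpy as np

    data = np.random.randn(2, 1, 3)
    for num_repeats in (2, 3, 4):    # cat/append, loop_append, and add_cat variants respectively
        out = np.zeros_like(np.concatenate([data] * num_repeats, axis=1))
        print(num_repeats, out.shape)    # (2, 2, 3), (2, 3, 3), (2, 4, 3)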