From 838d792d961caf84954d9570ee12d3c352c55374 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Wed, 14 Jun 2023 11:28:19 +0400 Subject: [PATCH] [PT FE]: fix unflatten for list construct sizes (#18039) --- src/frontends/pytorch/src/op/unflatten.cpp | 3 ++ .../pytorch_tests/test_unflatten.py | 36 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/unflatten.cpp b/src/frontends/pytorch/src/op/unflatten.cpp index eff0a5130cc..673efbc1480 100644 --- a/src/frontends/pytorch/src/op/unflatten.cpp +++ b/src/frontends/pytorch/src/op/unflatten.cpp @@ -25,6 +25,9 @@ OutputVector translate_unflatten(const NodeContext& context) { auto input = context.get_input(0); auto dim = context.get_input(1); auto sizes = context.get_input(2); + if (context.get_input_type(2).is()) { + sizes = concat_list_construct(sizes); + } auto input_shape = context.mark_node(std::make_shared(input, element::i32)); auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); diff --git a/tests/layer_tests/pytorch_tests/test_unflatten.py b/tests/layer_tests/pytorch_tests/test_unflatten.py index e260b125e11..3f8e9de3a2b 100644 --- a/tests/layer_tests/pytorch_tests/test_unflatten.py +++ b/tests/layer_tests/pytorch_tests/test_unflatten.py @@ -32,4 +32,38 @@ class TestUnflatten(PytorchLayerTest): @pytest.mark.nightly @pytest.mark.precommit def test_unflatten(self, dim, shape, dtype, ie_device, precision, ir_version): - self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype}) \ No newline at end of file + self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype}) + + +class TestUnflattenListSizes(PytorchLayerTest): + def _prepare_input(self, dtype): + return (np.random.uniform(0, 50, (6, 2, 4)).astype(dtype),) + + def create_model(self, dim): + import torch + + class aten_unflatten(torch.nn.Module): + def __init__(self, dim): + super(aten_unflatten, self).__init__() + self.dim = dim + + def forward(self, x): + dim1, dim2, dim3 = x.shape + if self.dim == 0: + sizes = [dim1, -1] + elif self.dim == 1: + sizes = [dim2 // 2, -1] + else: + sizes = [2, dim3 // 2, -1] + return x.unflatten(self.dim, sizes) + + ref_net = None + + return aten_unflatten(dim), ref_net, "aten::unflatten" + + @pytest.mark.parametrize("dim", [0, 1, 2]) + @pytest.mark.parametrize("dtype", ["float32", "int32"]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_unflatten(self, dim, dtype, ie_device, precision, ir_version): + self._test(*self.create_model(dim), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype}) \ No newline at end of file