diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
index cb7704a99a6..931901b1933 100644
--- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
@@ -74,14 +74,51 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
         }
 
         if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
-            // Using number of ListUnpack outputs instead of 1st input to chunk.
-            // TODO: confirm it works for all cases
-            auto split = std::make_shared<opset10::Split>(chunk->get_input_source_output(0),
-                                                          chunk->get_input_source_output(2),
-                                                          list_unpack->get_output_size());
+            auto input_tensor = chunk->get_input_source_output(0);
+            auto chunks_i32 = chunk->get_input_source_output(1);
+            auto dim = chunk->get_input_source_output(2);
 
-            copy_runtime_info({list_unpack, input_node}, split);
-            replace_node(list_unpack, split);
+            auto chunks = std::make_shared<opset10::Convert>(chunks_i32, element::i64);
+            auto const_0 = opset10::Constant::create(element::i64, Shape{1}, {0});
+            auto const_1 = opset10::Constant::create(element::i64, Shape{1}, {1});
+            auto const_0_nodim = opset10::Constant::create(element::i64, Shape{}, {0});
+            auto const_1_nodim = opset10::Constant::create(element::i64, Shape{}, {1});
+            auto const_shape = opset10::Constant::create(element::i64, Shape{1}, {list_unpack->get_output_size()});
+
+            auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor);
+            auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, const_0);
+            auto input_size = std::make_shared<opset10::Squeeze>(input_dimension);
+
+            auto chunk_size = std::make_shared<opset10::Divide>(input_size, chunks, true);
+            auto last_chunk_size = std::make_shared<opset10::Mod>(input_size, chunks);
+            auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, const_0_nodim);
+            auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i64);
+
+            auto computed_chunk_size = std::make_shared<opset10::Add>(chunk_size, is_last_nonzero_int);
+            auto computed_chunk_size_incr = std::make_shared<opset10::Add>(computed_chunk_size, const_1_nodim);
+            auto computed_last_chunk_size = std::make_shared<opset10::Mod>(input_size, computed_chunk_size);
+            auto computed_is_last_nonzero = std::make_shared<opset10::Greater>(computed_last_chunk_size, const_0_nodim);
+            auto computed_is_last_nonzero_int =
+                std::make_shared<opset10::Convert>(computed_is_last_nonzero, element::i64);
+            auto computed_is_last_nonzero_int_unsq =
+                std::make_shared<opset10::Unsqueeze>(computed_is_last_nonzero_int, const_0);
+            auto computed_chunks = std::make_shared<opset10::Divide>(input_size, computed_chunk_size, true);
+            auto computed_chunks_unsq = std::make_shared<opset10::Unsqueeze>(computed_chunks, const_0);
+
+            auto chunk_lengths = std::make_shared<opset10::RandomUniform>(computed_chunks_unsq,
+                                                                          computed_chunk_size,
+                                                                          computed_chunk_size_incr,
+                                                                          element::i64);
+            auto split_lengths = std::make_shared<opset10::Pad>(chunk_lengths,
+                                                                const_0,
+                                                                computed_is_last_nonzero_int_unsq,
+                                                                computed_last_chunk_size,
+                                                                ov::op::PadMode::CONSTANT);
+            auto split_lengths_static = std::make_shared<opset10::Reshape>(split_lengths, const_shape, false);
+            auto sliced_chunks = std::make_shared<opset10::VariadicSplit>(input_tensor, dim, split_lengths_static);
+
+            copy_runtime_info({list_unpack, input_node}, sliced_chunks);
+            replace_node(list_unpack, sliced_chunks);
             return true;
         }
 
diff --git a/tests/layer_tests/pytorch_tests/test_chunk.py b/tests/layer_tests/pytorch_tests/test_chunk.py
new file mode 100644
index 00000000000..76b2776b10a
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_chunk.py
@@ -0,0 +1,84 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+class aten_chunk_2(torch.nn.Module):
+    def __init__(self, dim) -> None:
+        torch.nn.Module.__init__(self)
+        self.dim = dim
+
+    def forward(self, input_tensor):
+        a, b = torch.chunk(input_tensor,
+                           chunks=2,
+                           dim=self.dim
+                           )
+        return a, b
+
+class aten_chunk_3(torch.nn.Module):
+    def __init__(self, dim) -> None:
+        torch.nn.Module.__init__(self)
+        self.dim = dim
+
+    def forward(self, input_tensor):
+        a, b, c = torch.chunk(input_tensor,
+                              chunks=3,
+                              dim=self.dim
+                              )
+        return a, b, c
+
+class aten_chunk_4(torch.nn.Module):
+    def __init__(self, dim) -> None:
+        torch.nn.Module.__init__(self)
+        self.dim = dim
+
+    def forward(self, input_tensor):
+        a, b, c, d = torch.chunk(input_tensor,
+                                 chunks=4,
+                                 dim=self.dim
+                                 )
+        return a, b, c, d
+
+class TestChunk(PytorchLayerTest):
+    def _prepare_input(self):
+        return (self.input_tensor,)
+
+    @pytest.mark.parametrize("input_tensor", [
+        np.random.rand(4, 4),
+        np.random.rand(5, 9, 7),
+        np.random.rand(10, 13, 11),
+        np.random.rand(8, 7, 6, 5, 4),
+        np.random.rand(11, 11),
+        np.random.rand(7, 7),
+    ])
+    @pytest.mark.parametrize("chunks", [
+        # 1,  # chunks=1 does not work without a dedicated translation
+        2,
+        3,
+        4
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_chunk(self, input_tensor, chunks, ie_device, precision, ir_version):
+        self.input_tensor = input_tensor
+
+        for dim in range(len(input_tensor.shape)):
+            chunk_size = input_tensor.shape[dim] // chunks
+            chunk_size += 1 if input_tensor.shape[dim] % chunks > 0 else 0
+
+            output_chunks = input_tensor.shape[dim] // chunk_size
+            output_chunks += 1 if input_tensor.shape[dim] % chunk_size > 0 else 0
+
+            if output_chunks == 2:
+                cls = aten_chunk_2
+            elif output_chunks == 3:
+                cls = aten_chunk_3
+            elif output_chunks == 4:
+                cls = aten_chunk_4
+
+            self._test(cls(dim), None, "aten::chunk",
+                       ie_device, precision, ir_version)
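For reference, the subgraph built in the replacer reproduces torch.chunk's sizing rule: every full chunk has length ceil(dim_size / chunks), and a smaller trailing chunk carries any remainder; the integer RandomUniform over [chunk_size, chunk_size + 1) only serves to materialize a vector filled with the chunk size, which Pad then extends with the trailing length before VariadicSplit. Below is a minimal Python sketch of that arithmetic, cross-checked against torch.chunk; the helper name is illustrative and not part of the patch.

import torch

def chunk_split_lengths(dim_size, chunks):
    # Size of every full chunk: ceil(dim_size / chunks),
    # mirroring the Divide / Mod / Greater / Add subgraph in the replacer.
    chunk_size = dim_size // chunks + (1 if dim_size % chunks else 0)
    # Number of full chunks and the (possibly absent) smaller trailing chunk,
    # mirroring the second Divide / Mod pair that feeds Pad and VariadicSplit.
    full_chunks = dim_size // chunk_size
    last = dim_size % chunk_size
    return [chunk_size] * full_chunks + ([last] if last else [])

# Cross-check against torch.chunk for a few sizes and chunk counts.
for dim_size, chunks in [(4, 2), (9, 4), (13, 4), (11, 3)]:
    expected = [t.shape[0] for t in torch.chunk(torch.zeros(dim_size), chunks)]
    assert chunk_split_lengths(dim_size, chunks) == expected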