[PT FE] Add aten::Chunk implementation (#16035)
* [PT FE] Add chunk implementation: * [PT FE] Fix chunk int64 instead of const node errors, add tests for chunking * [PT FE] Test Chunk-If implementation * [PT FE] Change the translate to replace chunk implementation, use VariadicSplit instead of Slice * [PT FE] Reduce artifacts from debugging * Update test_chunk.py * [PT FE] Improve & debug chunk implementation: * [PT FE] Simplify implementation, fix remaining bugs * [PT FE] Statify the split lengths output * [PT FE] Clear code, remove debugging artifacts
This commit is contained in:
parent
7d16ee1835
commit
6b70c449ba
@ -74,14 +74,51 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
|
||||
}
|
||||
|
||||
if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
    // aten::chunk(input, chunks, dim) followed by prim::ListUnpack is lowered to a
    // single VariadicSplit. PyTorch semantics: every chunk has
    // ceil(dim_size / chunks) elements except the last one, which may be smaller,
    // so the split lengths must be derived from the (possibly dynamic) input shape.
    //
    // NOTE: a previous revision also built an equal-sized opset10::Split here and
    // called replace_node(list_unpack, split) before the VariadicSplit was
    // constructed; that first replacement already rewired all ListUnpack consumers,
    // turning the VariadicSplit replacement below into a no-op. That debug
    // leftover has been removed.
    auto input_tensor = chunk->get_input_source_output(0);
    auto chunks_i32 = chunk->get_input_source_output(1);
    auto dim = chunk->get_input_source_output(2);

    auto chunks = std::make_shared<opset10::Convert>(chunks_i32, element::i64);
    auto const_0 = opset10::Constant::create(element::i64, Shape{1}, {0});
    auto const_0_nodim = opset10::Constant::create(element::i64, Shape{}, {0});
    auto const_1_nodim = opset10::Constant::create(element::i64, Shape{}, {1});
    // The number of ListUnpack outputs is the number of chunks actually produced.
    // TODO: confirm it works for all cases
    auto const_shape = opset10::Constant::create(element::i64, Shape{1}, {list_unpack->get_output_size()});

    // Size of the dimension being chunked.
    auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor);
    auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, const_0);
    auto input_size = std::make_shared<opset10::Squeeze>(input_dimension);

    // computed_chunk_size = ceil(input_size / chunks): truncating division plus one
    // when the remainder is non-zero.
    auto chunk_size = std::make_shared<opset10::Divide>(input_size, chunks, true);
    auto last_chunk_size = std::make_shared<opset10::Mod>(input_size, chunks);
    auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, const_0_nodim);
    auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i64);
    auto computed_chunk_size = std::make_shared<opset10::Add>(chunk_size, is_last_nonzero_int);

    // Number of full chunks and the length of the trailing partial chunk (if any).
    auto computed_chunk_size_incr = std::make_shared<opset10::Add>(computed_chunk_size, const_1_nodim);
    auto computed_last_chunk_size = std::make_shared<opset10::Mod>(input_size, computed_chunk_size);
    auto computed_is_last_nonzero = std::make_shared<opset10::Greater>(computed_last_chunk_size, const_0_nodim);
    auto computed_is_last_nonzero_int =
        std::make_shared<opset10::Convert>(computed_is_last_nonzero, element::i64);
    auto computed_is_last_nonzero_int_unsq =
        std::make_shared<opset10::Unsqueeze>(computed_is_last_nonzero_int, const_0);
    auto computed_chunks = std::make_shared<opset10::Divide>(input_size, computed_chunk_size, true);
    auto computed_chunks_unsq = std::make_shared<opset10::Unsqueeze>(computed_chunks, const_0);

    // RandomUniform over the integer range [n, n + 1) can only produce n, so this
    // yields a 1-D tensor of `computed_chunks` elements, all equal to the chunk size.
    auto chunk_lengths = std::make_shared<opset10::RandomUniform>(computed_chunks_unsq,
                                                                  computed_chunk_size,
                                                                  computed_chunk_size_incr,
                                                                  element::i64);
    // Append the remainder length when the dimension is not evenly divisible
    // (pads_end is 1 in that case, 0 otherwise).
    auto split_lengths = std::make_shared<opset10::Pad>(chunk_lengths,
                                                        const_0,
                                                        computed_is_last_nonzero_int_unsq,
                                                        computed_last_chunk_size,
                                                        ov::op::PadMode::CONSTANT);
    // Reshape to the statically known output count so VariadicSplit produces
    // exactly list_unpack->get_output_size() outputs.
    auto split_lengths_static = std::make_shared<opset10::Reshape>(split_lengths, const_shape, false);
    auto sliced_chunks = std::make_shared<opset10::VariadicSplit>(input_tensor, dim, split_lengths_static);

    copy_runtime_info({list_unpack, input_node}, sliced_chunks);
    replace_node(list_unpack, sliced_chunks);

    return true;
}
|
||||
|
84
tests/layer_tests/pytorch_tests/test_chunk.py
Normal file
84
tests/layer_tests/pytorch_tests/test_chunk.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
class aten_chunk_2(torch.nn.Module):
    """Module that splits its input into 2 chunks along the given dimension."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input_tensor):
        first, second = torch.chunk(input_tensor, chunks=2, dim=self.dim)
        return first, second
|
||||
|
||||
class aten_chunk_3(torch.nn.Module):
    """Module that splits its input into 3 chunks along the given dimension."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input_tensor):
        first, second, third = torch.chunk(input_tensor, chunks=3, dim=self.dim)
        return first, second, third
|
||||
|
||||
class aten_chunk_4(torch.nn.Module):
    """Module that splits its input into 4 chunks along the given dimension."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input_tensor):
        first, second, third, fourth = torch.chunk(input_tensor, chunks=4, dim=self.dim)
        return first, second, third, fourth
|
||||
|
||||
class TestChunk(PytorchLayerTest):
    """Tests aten::chunk lowering for every dimension of several input shapes.

    For each (input, chunks) pair the number of chunks torch.chunk actually
    produces is computed up front, and the module returning that many outputs
    is selected.
    """

    def _prepare_input(self):
        return (self.input_tensor,)

    @pytest.mark.parametrize("input_tensor", [
        np.random.rand(4, 4),
        np.random.rand(5, 9, 7),
        np.random.rand(10, 13, 11),
        np.random.rand(8, 7, 6, 5, 4),
        np.random.rand(11, 11),
        np.random.rand(7, 7),
    ])
    @pytest.mark.parametrize("chunks", [
        # 1, Does not work for 1 without translate
        2,
        3,
        4,
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_chunk(self, input_tensor, chunks, ie_device, precision, ir_version):
        self.input_tensor = input_tensor
        # Modules with the matching number of unpacked outputs.
        models = {2: aten_chunk_2, 3: aten_chunk_3, 4: aten_chunk_4}

        for dim in range(input_tensor.ndim):
            # torch.chunk rounds the chunk size UP, so it may produce fewer
            # chunks than requested: chunk_size = ceil(dim_size / chunks),
            # output_chunks = ceil(dim_size / chunk_size).
            dim_size = input_tensor.shape[dim]
            chunk_size = dim_size // chunks + (1 if dim_size % chunks else 0)
            output_chunks = dim_size // chunk_size + (1 if dim_size % chunk_size else 0)

            # Explicit guard: previously an unsupported count raised NameError
            # on the first dim and silently reused the stale model afterwards.
            if output_chunks not in models:
                pytest.fail(f"No test module with {output_chunks} outputs "
                            f"(shape={input_tensor.shape}, dim={dim}, chunks={chunks})")

            self._test(models[output_chunks](dim), None, "aten::chunk",
                       ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user