[PT FE] Fix aten::chunk for dynamic shapes (#16902)

* [PT FE] Add replacer for chunk+getitem

* [PT FE] Fix missing replaced nodes, fix incorrect chunk size calculation

* [PT FE] Fix incorrect item shape, reduce tests count

* [PT FE] Convert back with frontend

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Piotr Krzemiński
2023-05-01 11:32:10 +02:00
committed by GitHub
parent 52bf9abb8c
commit b7311d8907
3 changed files with 96 additions and 4 deletions

View File

@@ -19,6 +19,8 @@ OutputVector translate_getitem(const NodeContext& context) {
if (std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
FRONT_END_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"),
"special case for aten::__getitem__");
FRONT_END_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::chunk"),
"special case for aten::__getitem__");
const auto&& list_elems = get_list_as_outputs(input);
auto getitem_idx = context.const_input<int64_t>(1);
if (getitem_idx < 0) {

View File

@@ -22,6 +22,7 @@
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "pt_framework_node.hpp"
@@ -127,6 +128,57 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
replace_node(getitem, gather);
return true;
}
// Special case: __getitem__ applied to the output list of aten::chunk.
// Instead of materializing the whole chunk list, emit a single dynamic
// Slice over the chunked dimension so the graph stays dynamic-shape friendly.
if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
auto input_tensor = chunk->get_input_source_output(0);
auto chunks_i32 = chunk->get_input_source_output(1);
auto dim_i32 = chunk->get_input_source_output(2);
auto const_0 = opset10::Constant::create(element::i64, Shape{1}, {0});
auto const_1 = opset10::Constant::create(element::i64, Shape{1}, {1});
auto const_0_nodim = opset10::Constant::create(element::i64, Shape{}, {0});
// chunks / dim / index come in as i32 framework inputs; convert to i64 and
// unsqueeze to rank-1 where Gather/Slice expect 1-D index tensors.
auto getitem_index_i32 = getitem->get_input_source_output(1);
auto getitem_index_i64 = std::make_shared<opset10::Convert>(getitem_index_i32, element::i64);
auto getitem_index = std::make_shared<opset10::Unsqueeze>(getitem_index_i64, const_0);
auto dim_i64 = std::make_shared<opset10::Convert>(dim_i32, element::i64);
auto dim = std::make_shared<opset10::Unsqueeze>(dim_i64, const_0);
auto chunks = std::make_shared<opset10::Convert>(chunks_i32, element::i64);
// input_size = runtime length of the chunked dimension.
auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor);
auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, const_0);
auto input_size = std::make_shared<opset10::Squeeze>(input_dimension);
// computed_chunk_size = ceil(input_size / chunks): floor division, plus one
// when the remainder is non-zero. This matches torch.chunk semantics where
// all chunks are full-sized except a possibly smaller trailing chunk.
auto chunk_size = std::make_shared<opset10::Divide>(input_size, chunks, true);
auto last_chunk_size = std::make_shared<opset10::Mod>(input_size, chunks);
auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, const_0_nodim);
auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i64);
auto computed_chunk_size = std::make_shared<opset10::Add>(chunk_size, is_last_nonzero_int);
// computed_last_chunk_size = size of the trailing (possibly partial) chunk;
// computed_chunks = number of full-sized chunks (floor division).
auto computed_last_chunk_size = std::make_shared<opset10::Mod>(input_size, computed_chunk_size);
// NOTE(review): computed_is_last_nonzero is built but never consumed below —
// looks like dead graph construction; confirm and consider removing.
auto computed_is_last_nonzero = std::make_shared<opset10::Greater>(computed_last_chunk_size, const_0_nodim);
auto computed_chunks = std::make_shared<opset10::Divide>(input_size, computed_chunk_size, true);
// Select the slice length with boolean masks converted to i64:
// index <  computed_chunks -> full computed_chunk_size
// index >= computed_chunks -> trailing computed_last_chunk_size
auto is_slice_normal_size = std::make_shared<opset10::Less>(getitem_index, computed_chunks);
auto is_slice_not_normal_size = std::make_shared<opset10::GreaterEqual>(getitem_index, computed_chunks);
auto is_slice_normal_size_int = std::make_shared<opset10::Convert>(is_slice_normal_size, element::i64);
auto is_slice_not_normal_size_int =
std::make_shared<opset10::Convert>(is_slice_not_normal_size, element::i64);
auto slice_size_lhs = std::make_shared<opset10::Multiply>(is_slice_normal_size_int, computed_chunk_size);
auto slice_size_rhs =
std::make_shared<opset10::Multiply>(is_slice_not_normal_size_int, computed_last_chunk_size);
auto slice_size = std::make_shared<opset10::Add>(slice_size_lhs, slice_size_rhs);
// Slice window: [index * chunk_size, index * chunk_size + slice_size) along dim.
auto slice_begin = std::make_shared<opset10::Multiply>(getitem_index, computed_chunk_size);
auto slice_end = std::make_shared<opset10::Add>(slice_begin, slice_size);
auto sliced_chunk = std::make_shared<opset10::Slice>(input_tensor, slice_begin, slice_end, const_1, dim);
copy_runtime_info({getitem, input_node}, sliced_chunk);
replace_node(getitem, sliced_chunk);
return true;
}
return false;
};

View File

@@ -43,6 +43,19 @@ class aten_chunk_4(torch.nn.Module):
)
return a,b,c,d
class aten_chunk_getitem(torch.nn.Module):
    """Split a tensor with torch.chunk and return one selected chunk.

    Args:
        chunks: number of chunks to split into.
        dim: dimension along which to split.
        idx: index of the chunk to return from the produced list.
    """

    def __init__(self, chunks, dim, idx) -> None:
        super().__init__()
        self.chunks = chunks
        self.dim = dim
        self.idx = idx

    def forward(self, input_tensor):
        chunk_list = torch.chunk(input_tensor, chunks=self.chunks, dim=self.dim)
        return chunk_list[self.idx]
class TestChunk(PytorchLayerTest):
    def _prepare_input(self):
        # self.input_tensor is assigned by each test_* method before _test runs.
        return (self.input_tensor,)
@@ -52,11 +65,10 @@ class TestChunk(PytorchLayerTest):
np.random.rand(5, 9, 7),
np.random.rand(10, 13, 11),
np.random.rand(8, 7, 6, 5, 4),
np.random.rand(11, 11),
np.random.rand(7, 7),
])
@pytest.mark.parametrize("chunks", [
# 1, Does not work for 1 without translate
# Does not work for 1 - no list_unpack present in the graph
# 1,
2,
3,
4
@@ -81,4 +93,30 @@ class TestChunk(PytorchLayerTest):
cls = aten_chunk_4
self._test(cls(dim), None, "aten::chunk",
ie_device, precision, ir_version)
ie_device, precision, ir_version, dynamic_shapes = False, freeze_model=True, trace_model=True)
@pytest.mark.parametrize("input_tensor", [
    np.random.rand(4, 4),
    np.random.rand(10, 13, 11),
    np.random.rand(8, 7, 6, 5, 4),
])
@pytest.mark.parametrize("chunks", [
    2,
    3,
    4
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_chunk_getitem(self, input_tensor, chunks, ie_device, precision, ir_version):
    """Exercise chunk+__getitem__ along every axis, indexing the first,
    second and last chunk that torch.chunk actually produces."""
    self.input_tensor = input_tensor
    for axis in range(len(input_tensor.shape)):
        dim_len = input_tensor.shape[axis]
        # torch.chunk uses ceil(dim_len / chunks) elements per chunk, so the
        # produced chunk count is ceil(dim_len / per_chunk) and may be fewer
        # than the requested `chunks`.
        per_chunk = -(-dim_len // chunks)
        produced = -(-dim_len // per_chunk)
        for idx in [0, 1, produced - 1]:
            self._test(aten_chunk_getitem(chunks, axis, idx), None, "aten::chunk",
                       ie_device, precision, ir_version)