[PT FE] Fix aten::chunk for dynamic shapes (#16902)
* [PT FE] Add replacer for chunk+getitem * [PT FE] Fix missing replaced nodes, fix incorrect chunk size calculation * [PT FE] Fix incorrect item shape, reduce tests count * [PT FE] Convert back with frontend --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
@@ -19,6 +19,8 @@ OutputVector translate_getitem(const NodeContext& context) {
|
||||
if (std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"),
|
||||
"special case for aten::__getitem__");
|
||||
FRONT_END_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::chunk"),
|
||||
"special case for aten::__getitem__");
|
||||
const auto&& list_elems = get_list_as_outputs(input);
|
||||
auto getitem_idx = context.const_input<int64_t>(1);
|
||||
if (getitem_idx < 0) {
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/op/variadic_split.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
@@ -127,6 +128,57 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
|
||||
replace_node(getitem, gather);
|
||||
return true;
|
||||
}
|
||||
if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
    // Replace the aten::chunk + aten::__getitem__ pair with a single dynamic
    // Slice so the pattern works for dynamic shapes.
    // torch.chunk(t, chunks, dim) splits `dim` into pieces of
    // ceil(dim_size / chunks) elements, which means the number of chunks
    // actually produced can be smaller than the requested `chunks`; all of
    // that arithmetic is reproduced below as graph operations.
    auto input_tensor = chunk->get_input_source_output(0);
    auto chunks_i32 = chunk->get_input_source_output(1);
    auto dim_i32 = chunk->get_input_source_output(2);

    auto const_0 = opset10::Constant::create(element::i64, Shape{1}, {0});
    auto const_1 = opset10::Constant::create(element::i64, Shape{1}, {1});
    auto const_0_nodim = opset10::Constant::create(element::i64, Shape{}, {0});

    // NOTE(review): the __getitem__ index is not normalized here, so a
    // negative chunk index would produce a wrong slice — confirm callers
    // only feed non-negative indices.
    auto getitem_index_i32 = getitem->get_input_source_output(1);
    auto getitem_index_i64 = std::make_shared<opset10::Convert>(getitem_index_i32, element::i64);
    auto getitem_index = std::make_shared<opset10::Unsqueeze>(getitem_index_i64, const_0);
    auto dim_i64 = std::make_shared<opset10::Convert>(dim_i32, element::i64);
    auto dim = std::make_shared<opset10::Unsqueeze>(dim_i64, const_0);
    auto chunks = std::make_shared<opset10::Convert>(chunks_i32, element::i64);

    // Scalar size of the dimension being chunked, taken from the runtime shape.
    auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor);
    auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, const_0);
    auto input_size = std::make_shared<opset10::Squeeze>(input_dimension);

    // computed_chunk_size = ceil(input_size / chunks):
    // floor division, plus one when the division leaves a remainder.
    auto chunk_size = std::make_shared<opset10::Divide>(input_size, chunks, true);
    auto last_chunk_size = std::make_shared<opset10::Mod>(input_size, chunks);
    auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, const_0_nodim);
    auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i64);
    auto computed_chunk_size = std::make_shared<opset10::Add>(chunk_size, is_last_nonzero_int);

    // Size of the trailing (possibly smaller) chunk and the number of
    // full-size chunks actually produced with computed_chunk_size.
    // (The unused `computed_is_last_nonzero` Greater node from the original
    // code was removed — it was created but never consumed.)
    auto computed_last_chunk_size = std::make_shared<opset10::Mod>(input_size, computed_chunk_size);
    auto computed_chunks = std::make_shared<opset10::Divide>(input_size, computed_chunk_size, true);

    // Pick the slice length: full chunk size for indices below
    // computed_chunks, remainder size for the trailing partial chunk.
    // The boolean masks are converted to i64 and used as multiplicative
    // selectors so everything stays inside the graph.
    auto is_slice_normal_size = std::make_shared<opset10::Less>(getitem_index, computed_chunks);
    auto is_slice_not_normal_size = std::make_shared<opset10::GreaterEqual>(getitem_index, computed_chunks);
    auto is_slice_normal_size_int = std::make_shared<opset10::Convert>(is_slice_normal_size, element::i64);
    auto is_slice_not_normal_size_int =
        std::make_shared<opset10::Convert>(is_slice_not_normal_size, element::i64);

    auto slice_size_lhs = std::make_shared<opset10::Multiply>(is_slice_normal_size_int, computed_chunk_size);
    auto slice_size_rhs =
        std::make_shared<opset10::Multiply>(is_slice_not_normal_size_int, computed_last_chunk_size);
    auto slice_size = std::make_shared<opset10::Add>(slice_size_lhs, slice_size_rhs);

    // The requested chunk occupies [idx * chunk_size, idx * chunk_size + size).
    auto slice_begin = std::make_shared<opset10::Multiply>(getitem_index, computed_chunk_size);
    auto slice_end = std::make_shared<opset10::Add>(slice_begin, slice_size);

    auto sliced_chunk = std::make_shared<opset10::Slice>(input_tensor, slice_begin, slice_end, const_1, dim);

    copy_runtime_info({getitem, input_node}, sliced_chunk);
    replace_node(getitem, sliced_chunk);

    return true;
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
@@ -43,6 +43,19 @@ class aten_chunk_4(torch.nn.Module):
|
||||
)
|
||||
return a,b,c,d
|
||||
|
||||
class aten_chunk_getitem(torch.nn.Module):
    """Module that chunks the input tensor and returns one chunk by index."""

    def __init__(self, chunks, dim, idx) -> None:
        super().__init__()
        self.chunks = chunks
        self.dim = dim
        self.idx = idx

    def forward(self, input_tensor):
        # Split along self.dim into self.chunks pieces, keep piece self.idx.
        pieces = torch.chunk(input_tensor, chunks=self.chunks, dim=self.dim)
        return pieces[self.idx]
|
||||
|
||||
class TestChunk(PytorchLayerTest):
|
||||
def _prepare_input(self):
    # Test-framework hook: return the model inputs as a tuple. The tensor is
    # stored on the instance (self.input_tensor) by each test method before
    # it calls self._test.
    return (self.input_tensor,)
|
||||
@@ -52,11 +65,10 @@ class TestChunk(PytorchLayerTest):
|
||||
np.random.rand(5, 9, 7),
|
||||
np.random.rand(10, 13, 11),
|
||||
np.random.rand(8, 7, 6, 5, 4),
|
||||
np.random.rand(11, 11),
|
||||
np.random.rand(7, 7),
|
||||
])
|
||||
@pytest.mark.parametrize("chunks", [
|
||||
# 1, Does not work for 1 without translate
|
||||
# Does not work for 1 - no list_unpack present in the graph
|
||||
# 1,
|
||||
2,
|
||||
3,
|
||||
4
|
||||
@@ -81,4 +93,30 @@ class TestChunk(PytorchLayerTest):
|
||||
cls = aten_chunk_4
|
||||
|
||||
self._test(cls(dim), None, "aten::chunk",
|
||||
ie_device, precision, ir_version)
|
||||
ie_device, precision, ir_version, dynamic_shapes = False, freeze_model=True, trace_model=True)
|
||||
|
||||
@pytest.mark.parametrize("input_tensor", [
    np.random.rand(4, 4),
    np.random.rand(10, 13, 11),
    np.random.rand(8, 7, 6, 5, 4),
])
@pytest.mark.parametrize("chunks", [
    2,
    3,
    4
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_chunk_getitem(self, input_tensor, chunks, ie_device, precision, ir_version):
    """Check aten::chunk followed by __getitem__ for every dimension and for
    the first, second and last produced chunk index."""
    self.input_tensor = input_tensor
    for dim in range(len(input_tensor.shape)):
        # Effective chunk size used by torch: ceil(dim_size / chunks).
        chunk_size = input_tensor.shape[dim] // chunks
        chunk_size += 1 if input_tensor.shape[dim] % chunks > 0 else 0

        # Number of chunks actually produced with that chunk size (may be
        # fewer than the requested `chunks`).
        output_chunks = input_tensor.shape[dim] // chunk_size
        output_chunks += 1 if input_tensor.shape[dim] % chunk_size > 0 else 0

        # Deduplicate indices: when output_chunks == 2 the original list
        # [0, 1, output_chunks - 1] ran index 1 twice.
        for idx in sorted({0, 1, output_chunks - 1}):
            self._test(aten_chunk_getitem(chunks, dim, idx), None, "aten::chunk",
                       ie_device, precision, ir_version)
|
||||
|
||||
Reference in New Issue
Block a user