[PT FE] Add aten::tensor_split transformation (#19144)

* Add aten::tensor_split

* Fix formatting

* Reduce number of test cases

* Apply requested review changes
Mateusz Mikolajczyk 2023-08-14 18:31:38 +02:00 committed by GitHub
parent 680333b2db
commit f1d61f72ac
3 changed files with 131 additions and 2 deletions
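
For context, torch.tensor_split accepts either a number of sections or a list of split indices; both modes have to be covered by the transformation. A minimal plain-PyTorch illustration (not part of the commit):

import torch

x = torch.arange(8)
# Integer sections: 8 % 3 = 2, so the first two chunks get ceil(8/3) = 3 elements.
print(torch.tensor_split(x, 3))
# (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))
# List of indices: split points at 1 and 5, i.e. x[:1], x[1:5], x[5:].
print(torch.tensor_split(x, [1, 5]))
# (tensor([0]), tensor([1, 2, 3, 4]), tensor([5, 6, 7]))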


@@ -5,6 +5,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
@@ -68,8 +69,8 @@ OutputVector translate_repeat_interleave(const NodeContext& context) {
} else {
// repeats is either not a Constant or a single-element constant
// Currently we only support the case where repeats contains a single element; otherwise the next Reshape will fail.
-auto repeats_input =
-    context.mark_node(std::make_shared<v1::Reshape>(context.get_input(1), const_1_list, false));
+auto repeats_input = context.mark_node(std::make_shared<v0::Convert>(context.get_input(1), element::i32));
+repeats_input = context.mark_node(std::make_shared<v1::Reshape>(repeats_input, const_1_list, false));
auto repeats = context.mark_node(std::make_shared<v0::Concat>(OutputVector{repeats_input, const_1_list}, 0));
auto shape_perm = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
if (context.input_is_none(2)) {
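
The hunk above inserts an explicit Convert to i32 before the Reshape, presumably so that the following Concat with the i32 constant sees a matching element type: in PyTorch the single-element repeats tensor typically arrives as int64. A plain-PyTorch illustration of that input shape (not part of the commit):

import torch

x = torch.tensor([[1, 2], [3, 4]])
# A single-element repeats tensor (default dtype int64) repeats every element.
print(torch.repeat_interleave(x, torch.tensor([2])))
# tensor([1, 1, 2, 2, 3, 3, 4, 4])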


@@ -111,6 +111,68 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
return true;
}
if (auto tensor_split = cast_fw_node(input_node, "aten::tensor_split")) {
auto rank = tensor_split->input(1).get_partial_shape().rank();
if (rank.is_dynamic()) {
add_exception_to_fw_node(tensor_split, "aten::tensor_split: dynamic rank is not supported.");
return false;
}
auto const_0 = opset10::Constant::create(element::i32, Shape{1}, {0});
auto const_1 = opset10::Constant::create(element::i32, Shape{1}, {1});
auto const_0_scalar = opset10::Constant::create(element::i32, Shape{}, {0});
auto const_1_scalar = opset10::Constant::create(element::i32, Shape{}, {1});
auto const_max = opset10::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
auto const_neg_1 = opset10::Constant::create(element::i32, Shape{1}, {-1});
auto input = tensor_split->get_input_source_output(0);
auto indices_or_sections = tensor_split->get_input_source_output(1);
auto dim = rg.make<opset10::Unsqueeze>(tensor_split->get_input_source_output(2), const_0);
auto list_num_outs = opset10::Constant::create(element::i32, Shape{1}, {list_unpack->get_output_size()});
auto list_num_outs_scalar =
opset10::Constant::create(element::i32, Shape{}, {list_unpack->get_output_size()});
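// Scalar case below: indices_or_sections is the number of sections N.
// The first axis_size % N chunks get ceil(axis_size / N) elements and the
// remaining ones get floor(axis_size / N), matching torch.tensor_split.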
if (rank.get_length() == 0) {
auto input_shape = rg.make<opset10::ShapeOf>(input, element::i32);
auto axis_size = rg.make<opset10::Gather>(input_shape, dim, const_0);
auto minimum_split_size = rg.make<opset10::Divide>(axis_size, indices_or_sections);
auto maximum_split_size = rg.make<opset10::Add>(minimum_split_size, const_1);
auto num_splits_with_max_size = rg.make<opset10::Mod>(axis_size, indices_or_sections);
auto num_splits_with_min_size =
rg.make<opset10::Subtract>(indices_or_sections, num_splits_with_max_size);
auto splits_max_size = rg.make<opset10::Tile>(maximum_split_size, num_splits_with_max_size);
auto splits_min_size = rg.make<opset10::Tile>(minimum_split_size, num_splits_with_min_size);
auto split_sizes = rg.make<opset10::Concat>(OutputVector{splits_max_size, splits_min_size}, 0);
// Reshape is used to make the number of outputs static.
auto split_sizes_known_length = rg.make<opset10::Reshape>(split_sizes, list_num_outs, false);
auto splits = rg.make<opset10::VariadicSplit>(input, dim, split_sizes_known_length);
copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, splits->outputs());
return true;
} else {
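// 1-D case: indices_or_sections is a list of split points. Pad it to
// [0, i_0, ..., i_k, INT32_MAX]; adjacent pairs of that padded list give
// the start/stop of each output slice along dim.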
auto range =
rg.make<opset10::Range>(const_0_scalar, list_num_outs_scalar, const_1_scalar, element::i32);
auto range_plus_1 = rg.make<opset10::Add>(range, const_1);
auto sections = rg.make<opset10::Concat>(OutputVector{const_0, indices_or_sections, const_max}, 0);
auto starts_tensor = rg.make<opset10::Slice>(sections, const_0, const_neg_1, const_1, const_0);
auto starts =
rg.make<opset10::Split>(starts_tensor, const_0_scalar, list_unpack->get_output_size())->outputs();
auto stops_tensor = rg.make<opset10::Slice>(sections, const_1, const_max, const_1, const_0);
auto stops =
rg.make<opset10::Split>(stops_tensor, const_0_scalar, list_unpack->get_output_size())->outputs();
OutputVector outputs{};
for (size_t i = 0; i < list_unpack->get_output_size(); i++) {
auto slice = rg.make<opset10::Slice>(input, starts[i], stops[i], const_1, dim);
outputs.push_back(slice);
}
copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, outputs);
return true;
}
}
if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
const auto input = unbind->get_input_source_output(0);
const auto axis = unbind->get_input_source_output(1);
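
A rough NumPy sketch of the tensor_split decomposition above (illustrative helper names, not the frontend's API):

import numpy as np

def tensor_split_scalar(x, n, dim=0):
    # Scalar branch: the first (size % n) chunks get the larger size.
    size = x.shape[dim]
    min_size, num_max = size // n, size % n
    sizes = [min_size + 1] * num_max + [min_size] * (n - num_max)
    return np.split(x, np.cumsum(sizes)[:-1], axis=dim)

def tensor_split_indices(x, indices, dim=0):
    # 1-D branch: sections padded to [0, *indices, INT32_MAX]; adjacent
    # pairs give the start/stop of each slice along dim.
    sections = [0] + list(indices) + [np.iinfo(np.int32).max]
    slicer = [slice(None)] * x.ndim
    outs = []
    for start, stop in zip(sections[:-1], sections[1:]):
        slicer[dim] = slice(start, stop)
        outs.append(x[tuple(slicer)])
    return outs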


@@ -0,0 +1,66 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Collection
from numbers import Number

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestTensorSplit(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.rand(*self.input_shape),)

    def create_model(self, splits, axis):
        class aten_tensor_split(torch.nn.Module):
            def __init__(self, splits, dim) -> None:
                super().__init__()
                self.splits = splits
                self.dim = dim
                num_outs = None
                if isinstance(splits, Number):
                    num_outs = splits
                elif isinstance(splits, Collection):
                    num_outs = len(splits) + 1
                # Pick the forward with the matching number of outputs so that
                # the traced graph contains a prim::ListUnpack of known size.
                self.forward = getattr(self, f"forward_{num_outs}")

            def forward_2(self, input_tensor):
                a, b = torch.tensor_split(input_tensor, self.splits, dim=self.dim)
                return a, b

            def forward_3(self, input_tensor):
                a, b, c = torch.tensor_split(input_tensor, self.splits, dim=self.dim)
                return a, b, c

            def forward_4(self, input_tensor):
                a, b, c, d = torch.tensor_split(input_tensor, self.splits, dim=self.dim)
                return a, b, c, d

        return aten_tensor_split(splits, axis), None, "aten::tensor_split"
    @pytest.mark.parametrize("input_shape", [(2, 1, 8), (3, 5, 7, 11)])
    @pytest.mark.parametrize(
        "splits",
        [
            # 1,  # not tested: with a single section there is no prim::ListUnpack in the graph
            2,
            3,
            4,
            [2],
            [5],
            [-1],
            [-5],
            [1, 3],
            [1, 3, 5],
            [5, -1, 7],
        ],
    )
    @pytest.mark.parametrize("axis", [0, 1, -1])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_tensor_split(self, input_shape, splits, axis, ie_device, precision, ir_version):
        self.input_shape = input_shape
        self._test(*self.create_model(splits, axis), ie_device, precision, ir_version)
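
As a plain-PyTorch sanity check on the less obvious parametrizations: negative split points behave like negative slice indices, and out-of-order points such as [5, -1, 7] simply yield an empty chunk:

import torch

x = torch.arange(8)
print([t.tolist() for t in torch.tensor_split(x, [5, -1, 7])])
# [[0, 1, 2, 3, 4], [5, 6], [], [7]]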