[PT FE] Add aten::tensor_split transformation (#19144)
* Add aten::tensor_split * Fix formating * Reduce number of test cases * fix requested changes
This commit is contained in:
parent
680333b2db
commit
f1d61f72ac
@ -5,6 +5,7 @@
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
@ -68,8 +69,8 @@ OutputVector translate_repeat_interleave(const NodeContext& context) {
|
||||
} else {
|
||||
// repeats is not Constant or single element constant
|
||||
// Curently we support only case when repeats contains only one element. Otherwise next Reshape will fail.
|
||||
auto repeats_input =
|
||||
context.mark_node(std::make_shared<v1::Reshape>(context.get_input(1), const_1_list, false));
|
||||
auto repeats_input = context.mark_node(std::make_shared<v0::Convert>(context.get_input(1), element::i32));
|
||||
repeats_input = context.mark_node(std::make_shared<v1::Reshape>(repeats_input, const_1_list, false));
|
||||
auto repeats = context.mark_node(std::make_shared<v0::Concat>(OutputVector{repeats_input, const_1_list}, 0));
|
||||
auto shape_perm = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
|
||||
if (context.input_is_none(2)) {
|
||||
|
@ -111,6 +111,68 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto tensor_split = cast_fw_node(input_node, "aten::tensor_split")) {
|
||||
auto rank = tensor_split->input(1).get_partial_shape().rank();
|
||||
if (rank.is_dynamic()) {
|
||||
add_exception_to_fw_node(tensor_split, "aten::tensor_split: dynamic rank is not supported.");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto const_0 = opset10::Constant::create(element::i32, Shape{1}, {0});
|
||||
auto const_1 = opset10::Constant::create(element::i32, Shape{1}, {1});
|
||||
auto const_0_scalar = opset10::Constant::create(element::i32, Shape{}, {0});
|
||||
auto const_1_scalar = opset10::Constant::create(element::i32, Shape{}, {1});
|
||||
auto const_max = opset10::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
|
||||
auto const_neg_1 = opset10::Constant::create(element::i32, Shape{1}, {-1});
|
||||
|
||||
auto input = tensor_split->get_input_source_output(0);
|
||||
auto indices_or_sections = tensor_split->get_input_source_output(1);
|
||||
auto dim = rg.make<opset10::Unsqueeze>(tensor_split->get_input_source_output(2), const_0);
|
||||
auto list_num_outs = opset10::Constant::create(element::i32, Shape{1}, {list_unpack->get_output_size()});
|
||||
auto list_num_outs_scalar =
|
||||
opset10::Constant::create(element::i32, Shape{}, {list_unpack->get_output_size()});
|
||||
|
||||
if (rank.get_length() == 0) {
|
||||
auto input_shape = rg.make<opset10::ShapeOf>(input, element::i32);
|
||||
auto axis_size = rg.make<opset10::Gather>(input_shape, dim, const_0);
|
||||
auto minimum_split_size = rg.make<opset10::Divide>(axis_size, indices_or_sections);
|
||||
auto maximum_split_size = rg.make<opset10::Add>(minimum_split_size, const_1);
|
||||
auto num_splits_with_max_size = rg.make<opset10::Mod>(axis_size, indices_or_sections);
|
||||
auto num_splits_with_min_size =
|
||||
rg.make<opset10::Subtract>(indices_or_sections, num_splits_with_max_size);
|
||||
auto splits_max_size = rg.make<opset10::Tile>(maximum_split_size, num_splits_with_max_size);
|
||||
auto splits_min_size = rg.make<opset10::Tile>(minimum_split_size, num_splits_with_min_size);
|
||||
|
||||
auto split_sizes = rg.make<opset10::Concat>(OutputVector{splits_max_size, splits_min_size}, 0);
|
||||
// Reshape is used to make number of outputs static.
|
||||
auto split_sizes_known_lenght = rg.make<opset10::Reshape>(split_sizes, list_num_outs, false);
|
||||
auto splits = rg.make<opset10::VariadicSplit>(input, dim, split_sizes_known_lenght);
|
||||
copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
|
||||
replace_node(list_unpack, splits->outputs());
|
||||
return true;
|
||||
} else {
|
||||
auto range =
|
||||
rg.make<opset10::Range>(const_0_scalar, list_num_outs_scalar, const_1_scalar, element::i32);
|
||||
auto range_plus_1 = rg.make<opset10::Add>(range, const_1);
|
||||
auto sections = rg.make<opset10::Concat>(OutputVector{const_0, indices_or_sections, const_max}, 0);
|
||||
|
||||
auto starts_tensor = rg.make<opset10::Slice>(sections, const_0, const_neg_1, const_1, const_0);
|
||||
auto starts =
|
||||
rg.make<opset10::Split>(starts_tensor, const_0_scalar, list_unpack->get_output_size())->outputs();
|
||||
auto stops_tensor = rg.make<opset10::Slice>(sections, const_1, const_max, const_1, const_0);
|
||||
auto stops =
|
||||
rg.make<opset10::Split>(stops_tensor, const_0_scalar, list_unpack->get_output_size())->outputs();
|
||||
OutputVector outputs{};
|
||||
for (size_t i = 0; i < list_unpack->get_output_size(); i++) {
|
||||
auto slice = rg.make<opset10::Slice>(input, starts[i], stops[i], const_1, dim);
|
||||
outputs.push_back(slice);
|
||||
}
|
||||
copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
|
||||
replace_node(list_unpack, outputs);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
|
||||
const auto input = unbind->get_input_source_output(0);
|
||||
const auto axis = unbind->get_input_source_output(1);
|
||||
|
66
tests/layer_tests/pytorch_tests/test_tensor_split.py
Normal file
66
tests/layer_tests/pytorch_tests/test_tensor_split.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Collection
|
||||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestTensorSplit(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.rand(*self.input_shape),)
|
||||
|
||||
def create_model(self, splits, axis):
|
||||
class aten_tensor_split(torch.nn.Module):
|
||||
def __init__(self, splits, dim) -> None:
|
||||
super().__init__()
|
||||
self.splits = splits
|
||||
self.dim = dim
|
||||
num_outs = None
|
||||
if isinstance(splits, Number):
|
||||
num_outs = splits
|
||||
elif isinstance(splits, Collection):
|
||||
num_outs = len(splits) + 1
|
||||
self.forward = getattr(self, f"forward_{num_outs}")
|
||||
|
||||
def forward_2(self, input_tensor):
|
||||
a, b = torch.tensor_split(input_tensor, self.splits, dim=self.dim)
|
||||
return a, b
|
||||
|
||||
def forward_3(self, input_tensor):
|
||||
a, b, c = torch.tensor_split(input_tensor, self.splits, dim=self.dim)
|
||||
return a, b, c
|
||||
|
||||
def forward_4(self, input_tensor):
|
||||
a, b, c, d = torch.tensor_split(input_tensor, self.splits, dim=self.dim)
|
||||
return a, b, c, d
|
||||
|
||||
return aten_tensor_split(splits, axis), None, "aten::tensor_split"
|
||||
|
||||
@pytest.mark.parametrize("input_shape", [(2, 1, 8), (3, 5, 7, 11)])
|
||||
@pytest.mark.parametrize(
|
||||
"splits",
|
||||
[
|
||||
# 1, Does not work for 1 - no list_unpack present in the graph
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
[2],
|
||||
[5],
|
||||
[-1],
|
||||
[-5],
|
||||
[1, 3],
|
||||
[1, 3, 5],
|
||||
[5, -1, 7],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("axis", [0, 1, -1])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_tensor_split(self, input_shape, splits, axis, ie_device, precision, ir_version):
|
||||
self.input_shape = input_shape
|
||||
self._test(*self.create_model(splits, axis), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user