support aten::channel_shuffle (#20240)

* support aten::channel_shuffle

* remove getting rank
Ekaterina Aidova 2023-10-10 10:16:26 +04:00 committed by Alexander Nesterov
parent 9adfaca1a8
commit 85814ff8a0
3 changed files with 71 additions and 5 deletions
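
For context: channel_shuffle (popularized by the ShuffleNet architectures) regroups the channels of an [N, C, ...] tensor by splitting C into G groups of K = C/G channels and interleaving them. A minimal PyTorch illustration of the expected semantics (not part of this commit):

import torch
import torch.nn.functional as F

x = torch.arange(6).reshape(1, 6, 1, 1).float()  # channels [0, 1, 2, 3, 4, 5]
y = F.channel_shuffle(x, 2)                      # channels become [0, 3, 1, 4, 2, 5]
print(y.flatten().tolist())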


@@ -6,6 +6,7 @@
 #include "openvino/op/add.hpp"
 #include "openvino/op/concat.hpp"
 #include "openvino/op/constant.hpp"
+#include "openvino/op/divide.hpp"
 #include "openvino/op/gather.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/range.hpp"
@@ -15,6 +16,7 @@
 #include "openvino/op/split.hpp"
 #include "openvino/op/squeeze.hpp"
 #include "openvino/op/transpose.hpp"
+#include "openvino/op/unsqueeze.hpp"
 #include "utils.hpp"
 
 namespace ov {
@@ -67,6 +69,35 @@ OutputVector translate_pixel_shuffle(const NodeContext& context) {
     return {context.mark_node(std::make_shared<v1::Reshape>(transpose, shape_after, false))};
 };
 
+OutputVector translate_channel_shuffle(const NodeContext& context) {
+    // aten::channel_shuffle(Tensor self, int groups) -> Tensor
+    num_inputs_check(context, 2, 2);
+    auto x = context.get_input(0);
+    auto groups = context.get_input(1);
+    auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+    auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(x, element::i32));
+    // The PyTorch implementation assumes that the channels dim is always 1
+    auto indices = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 1}));
+    auto dims = context.mark_node(std::make_shared<v8::Gather>(shape, indices, zero));
+    auto dims_splitted = context.mark_node(std::make_shared<v1::Split>(dims, zero, 2));
+    auto c = dims_splitted->output(1);
+    auto n = dims_splitted->output(0);
+    groups = context.mark_node(std::make_shared<v0::Convert>(groups, element::i32));
+    auto k = context.mark_node(std::make_shared<v1::Divide>(c, groups, true));
+    auto g = context.mark_node(std::make_shared<v0::Unsqueeze>(groups, zero));
+    // 1. Reshape input to [N, G, K=C/G, -1]
+    auto reshape_indices = context.mark_node(std::make_shared<v0::Concat>(OutputVector{n, g, k, neg_1}, 0));
+    x = context.mark_node(std::make_shared<v1::Reshape>(x, reshape_indices, false));
+    // 2. Transpose to [N, K, G, -1]
+    auto permute_indices = context.mark_node(v0::Constant::create(element::i32, Shape{4}, {0, 2, 1, 3}));
+    auto y = context.mark_node(std::make_shared<v1::Transpose>(x, permute_indices));
+    // 3. Reshape back to the original shape
+    auto result = context.mark_node(std::make_shared<v1::Reshape>(y, shape, false));
+    return {result};
+};
+
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
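
The new converter lowers aten::channel_shuffle to generic Reshape/Transpose/Reshape operations rather than a dedicated primitive. A NumPy sketch of the same three-step decomposition, assuming an [N, C, ...] input with C divisible by groups (channel_shuffle_ref is an illustrative helper, not code from this commit):

import numpy as np

def channel_shuffle_ref(x, groups):
    n, c = x.shape[0], x.shape[1]
    y = x.reshape(n, groups, c // groups, -1)  # 1. reshape to [N, G, K=C/G, -1]
    y = y.transpose(0, 2, 1, 3)                # 2. transpose to [N, K, G, -1]
    return y.reshape(x.shape)                  # 3. reshape back to the original shape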


@@ -37,6 +37,7 @@ OP_CONVERTER(translate_bitwise_not);
 OP_CONVERTER(translate_bitwise_or);
 OP_CONVERTER(translate_cat);
 OP_CONVERTER(translate_cdist);
+OP_CONVERTER(translate_channel_shuffle);
 OP_CONVERTER(translate_clamp);
 OP_CONVERTER(translate_constant);
 OP_CONVERTER(translate_conv_transposend);
@@ -263,6 +264,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::cdist", op::translate_cdist},
         {"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
         {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
+        {"aten::channel_shuffle", op::translate_channel_shuffle},
         {"aten::clamp", op::translate_clamp},
         {"aten::clamp_max", op::translate_1to1_match_2_inputs<opset10::Minimum>},
         {"aten::clamp_min", op::translate_1to1_match_2_inputs<opset10::Maximum>},


@@ -7,7 +7,7 @@ import pytest
 from pytorch_layer_test_class import PytorchLayerTest
 
-class TestOneHot(PytorchLayerTest):
+class TestPixelShuffle(PytorchLayerTest):
     def _prepare_input(self):
         return (np.random.randn(*self.shape).astype(np.float32),)
@@ -15,21 +15,54 @@ class TestOneHot(PytorchLayerTest):
         import torch
         import torch.nn.functional as F
 
-        class aten_one_hot(torch.nn.Module):
+        class aten_pixel_shuffle(torch.nn.Module):
             def __init__(self, upscale_factor):
-                super(aten_one_hot, self).__init__()
+                super(aten_pixel_shuffle, self).__init__()
                 self.upscale_factor = upscale_factor
 
             def forward(self, x):
                 return F.pixel_shuffle(x, self.upscale_factor)
 
-        return aten_one_hot(upscale_factor), None, "aten::pixel_shuffle"
+        return aten_pixel_shuffle(upscale_factor), None, "aten::pixel_shuffle"
 
     @pytest.mark.parametrize(("upscale_factor,shape"), [(3, [1, 9, 4, 4]),
                                                         (2, [1, 2, 3, 8, 4, 4]),])
     @pytest.mark.nightly
     @pytest.mark.precommit
-    def test_one_hot(self, upscale_factor, shape, ie_device, precision, ir_version):
+    def test_pixel_shuffle(self, upscale_factor, shape, ie_device, precision, ir_version):
         self.shape = shape
         self._test(*self.create_model(upscale_factor),
                    ie_device, precision, ir_version)
+
+
+class TestChannelShuffle(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(*self.shape).astype(np.float32),)
+
+    def create_model(self, groups):
+        import torch
+        import torch.nn.functional as F
+
+        class aten_channel_shuffle(torch.nn.Module):
+            def __init__(self, upscale_factor):
+                super(aten_channel_shuffle, self).__init__()
+                self.upscale_factor = upscale_factor
+
+            def forward(self, x):
+                return F.channel_shuffle(x, self.upscale_factor)
+
+        return aten_channel_shuffle(groups), None, "aten::channel_shuffle"
+
+    @pytest.mark.parametrize(("groups,shape"), [
+        (3, [1, 9, 4, 4]),
+        (2, [1, 8, 8, 4, 4]),
+        (4, [4, 4, 2]),
+        (5, [4, 10, 2, 10, 1, 1]),
+        (1, [2, 3, 4])
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_channel_shuffle(self, groups, shape, ie_device, precision, ir_version):
+        self.shape = shape
+        self._test(*self.create_model(groups),
+                   ie_device, precision, ir_version)
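
As a quick sanity check of what these parametrized cases exercise: for C=4 and groups=2 the channel order [0, 1, 2, 3] becomes [0, 2, 1, 3], and groups=1 (the last case) is the identity. An illustrative check, not part of the committed tests:

import torch
import torch.nn.functional as F

x = torch.arange(4).reshape(1, 4, 1, 1).float()
assert F.channel_shuffle(x, 2).flatten().tolist() == [0.0, 2.0, 1.0, 3.0]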