support aten::channel_shuffle (#20240)
* support aten::channel_shuffle * remove getting rank
This commit is contained in:
parent
9adfaca1a8
commit
85814ff8a0
@ -6,6 +6,7 @@
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
@ -15,6 +16,7 @@
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -67,6 +69,35 @@ OutputVector translate_pixel_shuffle(const NodeContext& context) {
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(transpose, shape_after, false))};
|
||||
};
|
||||
|
||||
OutputVector translate_channel_shuffle(const NodeContext& context) {
|
||||
// aten::channel_shuffle(Tensor self, int groups) -> Tensor
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto groups = context.get_input(1);
|
||||
auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(x, element::i32));
|
||||
// PyTorch realization uses assumption that channels dim is always 1
|
||||
auto indices = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 1}));
|
||||
auto dims = context.mark_node(std::make_shared<v8::Gather>(shape, indices, zero));
|
||||
auto dims_splitted = context.mark_node(std::make_shared<v1::Split>(dims, zero, 2));
|
||||
auto c = dims_splitted->output(1);
|
||||
auto n = dims_splitted->output(0);
|
||||
groups = context.mark_node(std::make_shared<v0::Convert>(groups, element::i32));
|
||||
auto k = context.mark_node(std::make_shared<v1::Divide>(c, groups, true));
|
||||
auto g = context.mark_node(std::make_shared<v0::Unsqueeze>(groups, zero));
|
||||
// 1. Reshape input [N, G, K=C/G, -1]
|
||||
auto reshape_indices = context.mark_node(std::make_shared<v0::Concat>(OutputVector{n, g, k, neg_1}, 0));
|
||||
x = context.mark_node(std::make_shared<v1::Reshape>(x, reshape_indices, false));
|
||||
// 2. Transpose to [N, K, G, -1]
|
||||
auto permute_indices = context.mark_node(v0::Constant::create(element::i32, Shape{4}, {0, 2, 1, 3}));
|
||||
auto y = context.mark_node(std::make_shared<v1::Transpose>(x, permute_indices));
|
||||
// 3. Reshape back to original shape
|
||||
auto result = context.mark_node(std::make_shared<v1::Reshape>(y, shape, false));
|
||||
return {result};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -37,6 +37,7 @@ OP_CONVERTER(translate_bitwise_not);
|
||||
OP_CONVERTER(translate_bitwise_or);
|
||||
OP_CONVERTER(translate_cat);
|
||||
OP_CONVERTER(translate_cdist);
|
||||
OP_CONVERTER(translate_channel_shuffle);
|
||||
OP_CONVERTER(translate_clamp);
|
||||
OP_CONVERTER(translate_constant);
|
||||
OP_CONVERTER(translate_conv_transposend);
|
||||
@ -263,6 +264,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::cdist", op::translate_cdist},
|
||||
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
|
||||
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
|
||||
{"aten::channel_shuffle", op::translate_channel_shuffle},
|
||||
{"aten::clamp", op::translate_clamp},
|
||||
{"aten::clamp_max", op::translate_1to1_match_2_inputs<opset10::Minimum>},
|
||||
{"aten::clamp_min", op::translate_1to1_match_2_inputs<opset10::Maximum>},
|
||||
|
@ -7,7 +7,7 @@ import pytest
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestOneHot(PytorchLayerTest):
|
||||
class TestPixelShuffle(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(*self.shape).astype(np.float32),)
|
||||
|
||||
@ -15,21 +15,54 @@ class TestOneHot(PytorchLayerTest):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_one_hot(torch.nn.Module):
|
||||
class aten_pixel_shuffle(torch.nn.Module):
|
||||
def __init__(self, upscale_factor):
|
||||
super(aten_one_hot, self).__init__()
|
||||
super(aten_pixel_shuffle, self).__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def forward(self, x):
|
||||
return F.pixel_shuffle(x, self.upscale_factor)
|
||||
|
||||
return aten_one_hot(upscale_factor), None, "aten::pixel_shuffle"
|
||||
return aten_pixel_shuffle(upscale_factor), None, "aten::pixel_shuffle"
|
||||
|
||||
@pytest.mark.parametrize(("upscale_factor,shape"), [(3, [1, 9, 4, 4]),
|
||||
(2, [1, 2, 3, 8, 4, 4]),])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_one_hot(self, upscale_factor, shape, ie_device, precision, ir_version):
|
||||
def test_pixel_shuffle(self, upscale_factor, shape, ie_device, precision, ir_version):
|
||||
self.shape = shape
|
||||
self._test(*self.create_model(upscale_factor),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestChannelShuffle(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(*self.shape).astype(np.float32),)
|
||||
|
||||
def create_model(self, groups):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_channel_shuffle(torch.nn.Module):
|
||||
def __init__(self, upscale_factor):
|
||||
super(aten_channel_shuffle, self).__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def forward(self, x):
|
||||
return F.channel_shuffle(x, self.upscale_factor)
|
||||
|
||||
return aten_channel_shuffle(groups), None, "aten::channel_shuffle"
|
||||
|
||||
@pytest.mark.parametrize(("groups,shape"), [
|
||||
(3, [1, 9, 4, 4]),
|
||||
(2, [1, 8, 8, 4, 4]),
|
||||
(4, [4, 4, 2]),
|
||||
(5, [4, 10, 2, 10, 1, 1]),
|
||||
(1, [2, 3, 4])
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_channel_shuffle(self, groups, shape, ie_device, precision, ir_version):
|
||||
self.shape = shape
|
||||
self._test(*self.create_model(groups),
|
||||
ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user