[PT FE]: support aten::pixel_unshuffle (#20325)
This commit is contained in:
parent
52f8e423f8
commit
ac780c7c16
@ -69,6 +69,45 @@ OutputVector translate_pixel_shuffle(const NodeContext& context) {
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(transpose, shape_after, false))};
|
||||
};
|
||||
|
||||
OutputVector translate_pixel_unshuffle(const NodeContext& context) {
|
||||
// aten::pixel_unshuffle(Tensor self, int upscale_factor) -> Tensor
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto upscale_factor = context.get_input(1);
|
||||
auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto neg_3 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto zero_s = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto one_s = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
Output<Node> shape;
|
||||
Output<Node> rank;
|
||||
std::tie(shape, rank) = get_shape_rank(context, x, true);
|
||||
// 1. Reshape input to [-1, C, H / r, r, W / r, r], where r is upscale factor
|
||||
auto indices = context.mark_node(v0::Constant::create(element::i32, Shape{3}, {-3, -2, -1}));
|
||||
auto dims = context.mark_node(std::make_shared<v8::Gather>(shape, indices, zero_s));
|
||||
auto dims_splitted = context.mark_node(std::make_shared<v1::Split>(dims, zero_s, 3));
|
||||
auto c = dims_splitted->output(0);
|
||||
auto h = dims_splitted->output(1);
|
||||
auto w = dims_splitted->output(2);
|
||||
auto dims_before = context.mark_node(std::make_shared<v8::Slice>(shape, zero, neg_3, one));
|
||||
auto r = context.mark_node(std::make_shared<v0::Unsqueeze>(upscale_factor, zero));
|
||||
auto new_h = context.mark_node(std::make_shared<v1::Divide>(h, upscale_factor, true));
|
||||
auto new_w = context.mark_node(std::make_shared<v1::Divide>(w, upscale_factor, true));
|
||||
auto intermediate_shape =
|
||||
context.mark_node(std::make_shared<v0::Concat>(OutputVector{neg_1, c, new_h, r, new_w, r}, 0));
|
||||
auto x_reshaped = context.mark_node(std::make_shared<v1::Reshape>(x, intermediate_shape, false));
|
||||
// 2. Transpose to [-1, C, r, r, H / r, W / r]
|
||||
auto transpose_order = context.mark_node(v0::Constant::create(element::i32, Shape{6}, {0, 1, 3, 5, 2, 4}));
|
||||
auto x_transposed = context.mark_node(std::make_shared<v1::Transpose>(x_reshaped, transpose_order));
|
||||
// 3. Reshape to [*, C*r*r, H / r, W / r]
|
||||
auto r_sqr = context.mark_node(std::make_shared<v1::Multiply>(r, r));
|
||||
auto new_c = context.mark_node(std::make_shared<v1::Multiply>(c, r_sqr));
|
||||
auto final_shape =
|
||||
context.mark_node(std::make_shared<v0::Concat>(OutputVector{dims_before, new_c, new_h, new_w}, 0));
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(x_transposed, final_shape, false))};
|
||||
};
|
||||
|
||||
OutputVector translate_channel_shuffle(const NodeContext& context) {
|
||||
// aten::channel_shuffle(Tensor self, int groups) -> Tensor
|
||||
num_inputs_check(context, 2, 2);
|
||||
|
@ -125,6 +125,7 @@ OP_CONVERTER(translate_outer);
|
||||
OP_CONVERTER(translate_pad);
|
||||
OP_CONVERTER(translate_pairwise_distance);
|
||||
OP_CONVERTER(translate_pixel_shuffle);
|
||||
OP_CONVERTER(translate_pixel_unshuffle);
|
||||
OP_CONVERTER(translate_pow);
|
||||
OP_CONVERTER(translate_pythonop);
|
||||
OP_CONVERTER(translate_quantize_per_channel);
|
||||
@ -409,6 +410,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::pairwise_distance", op::translate_pairwise_distance},
|
||||
{"aten::permute", op::translate_1to1_match_2_inputs<opset10::Transpose>},
|
||||
{"aten::pixel_shuffle", op::translate_pixel_shuffle},
|
||||
{"aten::pixel_unshuffle", op::translate_pixel_unshuffle},
|
||||
{"aten::prelu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
|
||||
{"aten::pow", op::translate_pow},
|
||||
{"aten::quantize_per_channel", op::translate_quantize_per_channel},
|
||||
|
@ -35,6 +35,34 @@ class TestPixelShuffle(PytorchLayerTest):
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestPixelUnshuffle(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(*self.shape).astype(np.float32),)
|
||||
|
||||
def create_model(self, upscale_factor):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_pixel_unshuffle(torch.nn.Module):
|
||||
def __init__(self, upscale_factor):
|
||||
super(aten_pixel_unshuffle, self).__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def forward(self, x):
|
||||
return F.pixel_unshuffle(x, self.upscale_factor)
|
||||
|
||||
return aten_pixel_unshuffle(upscale_factor), None, "aten::pixel_unshuffle"
|
||||
|
||||
@pytest.mark.parametrize(("upscale_factor,shape"), [(3, [1, 1, 12, 12]),
|
||||
(2, [1, 2, 3, 2, 8, 8]),])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_pixel_unshuffle(self, upscale_factor, shape, ie_device, precision, ir_version):
|
||||
self.shape = shape
|
||||
self._test(*self.create_model(upscale_factor),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestChannelShuffle(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(*self.shape).astype(np.float32),)
|
||||
@ -65,4 +93,4 @@ class TestChannelShuffle(PytorchLayerTest):
|
||||
def test_channel_shuffle(self, groups, shape, ie_device, precision, ir_version):
|
||||
self.shape = shape
|
||||
self._test(*self.create_model(groups),
|
||||
ie_device, precision, ir_version)
|
||||
ie_device, precision, ir_version)
|
||||
|
Loading…
Reference in New Issue
Block a user