[PT FE] Add support for aten::pixel_shuffle (#20124)
* [PT FE] Add support for aten::pixel_shuffle * Add comments * Update src/frontends/pytorch/src/op/pixel_shuffle.cpp
This commit is contained in:
parent
b73b2502b1
commit
84d98d8bf7
73
src/frontends/pytorch/src/op/pixel_shuffle.cpp
Normal file
73
src/frontends/pytorch/src/op/pixel_shuffle.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_pixel_shuffle(const NodeContext& context) {
|
||||
// aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto upscale_factor = context.get_input(1);
|
||||
auto neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto neg_3 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto zero_s = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto one_s = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
Output<Node> shape;
|
||||
Output<Node> rank;
|
||||
std::tie(shape, rank) = get_shape_rank(context, x, true);
|
||||
// 1. Reshape input to [*, -1, r, r, H, W], where r is upscale factor
|
||||
auto indices = context.mark_node(v0::Constant::create(element::i32, Shape{3}, {-3, -2, -1}));
|
||||
auto dims = context.mark_node(std::make_shared<v8::Gather>(shape, indices, zero_s));
|
||||
auto dims_splitted = context.mark_node(std::make_shared<v1::Split>(dims, zero_s, 3));
|
||||
auto c = dims_splitted->output(0);
|
||||
auto h = dims_splitted->output(1);
|
||||
auto w = dims_splitted->output(2);
|
||||
auto dims_before = context.mark_node(std::make_shared<v8::Slice>(shape, zero, neg_3, one));
|
||||
auto upscale_factor_1d = context.mark_node(std::make_shared<v1::Reshape>(upscale_factor, neg_1, false));
|
||||
auto intermediate_shape = context.mark_node(
|
||||
std::make_shared<v0::Concat>(OutputVector{dims_before, neg_1, upscale_factor_1d, upscale_factor_1d, h, w}, 0));
|
||||
auto reshape = context.mark_node(std::make_shared<v1::Reshape>(x, intermediate_shape, false));
|
||||
// 2. Transpose tensor to [*, C, r, H, r, W]
|
||||
auto dims_before_len = context.mark_node(std::make_shared<v3::ShapeOf>(dims_before, element::i32));
|
||||
auto dims_before_len_s = context.mark_node(std::make_shared<v0::Squeeze>(dims_before_len, zero));
|
||||
auto order_begin = context.mark_node(std::make_shared<v4::Range>(zero_s, dims_before_len_s, one_s, element::i32));
|
||||
auto order_end_neg = context.mark_node(
|
||||
v0::Constant::create(element::i32, Shape{5}, {-3, 0, -2, 1, -1})); // +2 because rank is expanded
|
||||
auto order_end = context.mark_node(std::make_shared<v1::Add>(order_end_neg, rank));
|
||||
auto order = context.mark_node(std::make_shared<v0::Concat>(OutputVector{order_begin, order_end}, 0));
|
||||
auto transpose = context.mark_node(std::make_shared<v1::Transpose>(reshape, order));
|
||||
// 3. Reshape to [*, -1, r * H, r * W]
|
||||
auto new_h = context.mark_node(std::make_shared<v1::Multiply>(h, upscale_factor));
|
||||
auto new_w = context.mark_node(std::make_shared<v1::Multiply>(w, upscale_factor));
|
||||
auto shape_after =
|
||||
context.mark_node(std::make_shared<v0::Concat>(OutputVector{dims_before, neg_1, new_h, new_w}, 0));
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(transpose, shape_after, false))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -120,6 +120,7 @@ OP_CONVERTER(translate_or);
|
||||
OP_CONVERTER(translate_outer);
|
||||
OP_CONVERTER(translate_pad);
|
||||
OP_CONVERTER(translate_pairwise_distance);
|
||||
OP_CONVERTER(translate_pixel_shuffle);
|
||||
OP_CONVERTER(translate_pow);
|
||||
OP_CONVERTER(translate_pythonop);
|
||||
OP_CONVERTER(translate_quantize_per_channel);
|
||||
@ -390,6 +391,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::pad", op::translate_pad},
|
||||
{"aten::pairwise_distance", op::translate_pairwise_distance},
|
||||
{"aten::permute", op::translate_1to1_match_2_inputs<opset10::Transpose>},
|
||||
{"aten::pixel_shuffle", op::translate_pixel_shuffle},
|
||||
{"aten::prelu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
|
||||
{"aten::pow", op::translate_pow},
|
||||
{"aten::quantize_per_channel", op::translate_quantize_per_channel},
|
||||
|
35
tests/layer_tests/pytorch_tests/test_pixel_shuffle.py
Normal file
35
tests/layer_tests/pytorch_tests/test_pixel_shuffle.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestOneHot(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(*self.shape).astype(np.float32),)
|
||||
|
||||
def create_model(self, upscale_factor):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_one_hot(torch.nn.Module):
|
||||
def __init__(self, upscale_factor):
|
||||
super(aten_one_hot, self).__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def forward(self, x):
|
||||
return F.pixel_shuffle(x, self.upscale_factor)
|
||||
|
||||
return aten_one_hot(upscale_factor), None, "aten::pixel_shuffle"
|
||||
|
||||
@pytest.mark.parametrize(("upscale_factor,shape"), [(3, [1, 9, 4, 4]),
|
||||
(2, [1, 2, 3, 8, 4, 4]),])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_one_hot(self, upscale_factor, shape, ie_device, precision, ir_version):
|
||||
self.shape = shape
|
||||
self._test(*self.create_model(upscale_factor),
|
||||
ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user