[PT FE]: support aten::unflatten (#17736)

* [PT FE]: support aten::unflatten

* Update src/frontends/pytorch/src/utils.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

Ekaterina Aidova 2023-05-26 19:27:05 +04:00 committed by GitHub
parent 84f46bd048
commit b2aaa10ef6
5 changed files with 98 additions and 0 deletions


@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <limits>

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_unflatten(const NodeContext& context) {
    // aten::unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a)
    num_inputs_check(context, 3, 3);
    auto input = context.get_input(0);
    auto dim = context.get_input(1);
    auto sizes = context.get_input(2);
    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
    auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    // Resolve a possibly negative dim against the input rank.
    dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
    dim = normalize_axis(context, dim, input);
    sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
    auto max_int = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int>::max()}));
    auto dim_plus_one = context.mark_node(std::make_shared<v1::Add>(dim, one_1d));
    // Split the input shape into the dimensions before and after dim, then
    // splice `sizes` in place of the dimension being unflattened.
    auto head_part_rank = context.mark_node(std::make_shared<v8::Slice>(input_shape, zero_1d, dim, one_1d));
    auto tail_part_rank = context.mark_node(std::make_shared<v8::Slice>(input_shape, dim_plus_one, max_int, one_1d));
    auto new_shape =
        context.mark_node(std::make_shared<v0::Concat>(OutputVector{head_part_rank, sizes, tail_part_rank}, 0));
    return {context.mark_node(std::make_shared<v1::Reshape>(input, new_shape, false))};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
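For context, the converter above mirrors the semantics of PyTorch's Tensor.unflatten: the single dimension dim is replaced by the given sizes, one of which may be -1 and is then inferred. A minimal PyTorch sketch of the behavior being translated:

import torch

x = torch.arange(24).reshape(6, 4)
y = x.unflatten(0, (2, 3))    # dim 0 of size 6 becomes (2, 3)
assert y.shape == (2, 3, 4)
z = x.unflatten(-1, (2, -1))  # the -1 entry is inferred: 4 -> (2, 2)
assert z.shape == (6, 2, 2)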


@@ -126,6 +126,7 @@ OP_CONVERTER(translate_topk);
OP_CONVERTER(translate_transpose);
OP_CONVERTER(translate_tril);
OP_CONVERTER(translate_triu);
OP_CONVERTER(translate_unflatten);
OP_CONVERTER(translate_unfold);
OP_CONVERTER(translate_upsample_bicubic2d);
OP_CONVERTER(translate_upsample_bilinear2d);
@@ -338,6 +339,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::triu", op::translate_triu},
{"aten::type_as",
op::translate_1to1_match_2_inputs<opset10::ConvertLike>}, // TODO: overflow semantics is different
{"aten::unflatten", op::translate_unflatten},
{"aten::unfold", op::translate_unfold},
{"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
{"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},


@@ -116,6 +116,17 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
    return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
};

std::shared_ptr<Node> normalize_axis(const NodeContext& context,
                                     const Output<Node>& axis,
                                     const Output<Node>& input_node) {
    Output<Node> rank;
    std::tie(std::ignore, rank) = get_shape_rank(context, input_node);
    // For a negative axis, axis + rank is the equivalent non-negative index;
    // a non-negative axis satisfies axis + rank >= rank and is kept as-is.
    auto axis_rank = context.mark_node(std::make_shared<opset10::Add>(axis, rank));
    auto is_less = context.mark_node(std::make_shared<opset10::Less>(axis_rank, rank));
    auto new_axis = context.mark_node(std::make_shared<opset10::Select>(is_less, axis_rank, axis));
    return new_axis;
}

std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x) {
    auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, element::i32));
    auto axes = context.mark_node(opset10::Constant::create(element::i32, Shape({1}), {0}));
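The new normalize_axis helper builds the usual negative-axis convention into the graph, so it also works when the rank is only known at runtime. For illustration only, the scalar equivalent in plain Python (a sketch, not part of the change) is:

def normalize_axis(axis: int, rank: int) -> int:
    # Negative axes count from the end: -1 refers to the last dimension.
    return axis + rank if axis < 0 else axis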


@@ -38,6 +38,10 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<N
std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id);

std::shared_ptr<Node> normalize_axis(const NodeContext& context,
                                     const Output<Node>& axis,
                                     const Output<Node>& input_node);

std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);

element::Type convert_dtype(int64_t dtype_value);


@@ -0,0 +1,35 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestUnflatten(PytorchLayerTest):
    def _prepare_input(self, dtype):
        return (np.random.uniform(0, 50, (6, 3, 4)).astype(dtype),)

    def create_model(self, dim, shape):
        import torch

        class aten_unflatten(torch.nn.Module):
            def __init__(self, dim, shape):
                super(aten_unflatten, self).__init__()
                self.dim = dim
                self.shape = shape

            def forward(self, x):
                return x.unflatten(self.dim, self.shape)

        ref_net = None

        return aten_unflatten(dim, shape), ref_net, "aten::unflatten"

    @pytest.mark.parametrize(("dim", "shape"), [(0, [2, 1, 3]), (1, [1, 3]), (2, (2, -1)), (-1, (2, 2)), (-2, (-1, 1))])
    @pytest.mark.parametrize("dtype", ["float32", "int32"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_unflatten(self, dim, shape, dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(dim, shape), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"dtype": dtype})
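As a sanity check on the parametrizations, each (dim, shape) pair factors the size of the targeted dimension of the (6, 3, 4) input, with -1 inferred where present; for example, in plain PyTorch:

import torch

x = torch.zeros(6, 3, 4)
assert x.unflatten(0, [2, 1, 3]).shape == (2, 1, 3, 3, 4)
assert x.unflatten(2, (2, -1)).shape == (6, 3, 2, 2)
assert x.unflatten(-2, (-1, 1)).shape == (6, 3, 1, 4)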