[PT FE]: support aten::unflatten (#17736)
* [PT FE]: support aten::unflatten * Update src/frontends/pytorch/src/utils.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
84f46bd048
commit
b2aaa10ef6
46
src/frontends/pytorch/src/op/unflatten.cpp
Normal file
46
src/frontends/pytorch/src/op/unflatten.cpp
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_unflatten(const NodeContext& context) {
|
||||
// aten::unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a)
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto input = context.get_input(0);
|
||||
auto dim = context.get_input(1);
|
||||
auto sizes = context.get_input(2);
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
|
||||
dim = normalize_axis(context, dim, input);
|
||||
sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
|
||||
auto max_int = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int>::max()}));
|
||||
auto dim_plus_one = context.mark_node(std::make_shared<v1::Add>(dim, one_1d));
|
||||
auto head_part_rank = context.mark_node(std::make_shared<v8::Slice>(input_shape, zero_1d, dim, one_1d));
|
||||
auto tail_part_rank = context.mark_node(std::make_shared<v8::Slice>(input_shape, dim_plus_one, max_int, one_1d));
|
||||
auto new_shape =
|
||||
context.mark_node(std::make_shared<v0::Concat>(OutputVector{head_part_rank, sizes, tail_part_rank}, 0));
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(input, new_shape, false))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -126,6 +126,7 @@ OP_CONVERTER(translate_topk);
|
||||
OP_CONVERTER(translate_transpose);
|
||||
OP_CONVERTER(translate_tril);
|
||||
OP_CONVERTER(translate_triu);
|
||||
OP_CONVERTER(translate_unflatten);
|
||||
OP_CONVERTER(translate_unfold);
|
||||
OP_CONVERTER(translate_upsample_bicubic2d);
|
||||
OP_CONVERTER(translate_upsample_bilinear2d);
|
||||
@ -338,6 +339,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::triu", op::translate_triu},
|
||||
{"aten::type_as",
|
||||
op::translate_1to1_match_2_inputs<opset10::ConvertLike>}, // TODO: overflow semantics is different
|
||||
{"aten::unflatten", op::translate_unflatten},
|
||||
{"aten::unfold", op::translate_unfold},
|
||||
{"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
|
||||
{"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
|
||||
|
@ -116,6 +116,17 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
|
||||
return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
|
||||
};
|
||||
|
||||
std::shared_ptr<Node> normalize_axis(const NodeContext& context,
|
||||
const Output<Node>& axis,
|
||||
const Output<Node>& input_node) {
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, input_node);
|
||||
auto axis_rank = context.mark_node(std::make_shared<opset10::Add>(axis, rank));
|
||||
auto is_less = context.mark_node(std::make_shared<opset10::Less>(axis_rank, rank));
|
||||
auto new_axis = context.mark_node(std::make_shared<opset10::Select>(is_less, axis_rank, axis));
|
||||
return new_axis;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x) {
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, element::i32));
|
||||
auto axes = context.mark_node(opset10::Constant::create(element::i32, Shape({1}), {0}));
|
||||
|
@ -38,6 +38,10 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<N
|
||||
|
||||
std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id);
|
||||
|
||||
std::shared_ptr<Node> normalize_axis(const NodeContext& context,
|
||||
const Output<Node>& axis,
|
||||
const Output<Node>& input_node);
|
||||
|
||||
std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);
|
||||
|
||||
element::Type convert_dtype(int64_t dtype_value);
|
||||
|
35
tests/layer_tests/pytorch_tests/test_unflatten.py
Normal file
35
tests/layer_tests/pytorch_tests/test_unflatten.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestUnflatten(PytorchLayerTest):
|
||||
def _prepare_input(self, dtype):
|
||||
return (np.random.uniform(0, 50, (6, 3, 4)).astype(dtype),)
|
||||
|
||||
def create_model(self, dim, shape):
|
||||
import torch
|
||||
|
||||
class aten_unflatten(torch.nn.Module):
|
||||
def __init__(self, dim, shape):
|
||||
super(aten_unflatten, self).__init__()
|
||||
self.dim = dim
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x):
|
||||
return x.unflatten(self.dim, self.shape)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_unflatten(dim, shape), ref_net, "aten::unflatten"
|
||||
|
||||
@pytest.mark.parametrize(("dim", "shape"), [(0, [2, 1, 3]), (1, [1, 3]), (2, (2, -1)), (-1, (2, 2)), (-2, (-1, 1))])
|
||||
@pytest.mark.parametrize("dtype", ["float32", "int32"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_unflatten(self, dim, shape, dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype})
|
Loading…
Reference in New Issue
Block a user