diff --git a/src/frontends/pytorch/src/op/unflatten.cpp b/src/frontends/pytorch/src/op/unflatten.cpp
new file mode 100644
index 00000000000..eff0a5130cc
--- /dev/null
+++ b/src/frontends/pytorch/src/op/unflatten.cpp
@@ -0,0 +1,46 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_unflatten(const NodeContext& context) {
+    // aten::unflatten.int(Tensor(a) self, int dim, int[] sizes) -> Tensor(a)
+    num_inputs_check(context, 3, 3);
+    auto input = context.get_input(0);
+    auto dim = context.get_input(1);
+    auto sizes = context.get_input(2);
+    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
+    auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
+    auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
+    dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
+    dim = normalize_axis(context, dim, input);
+    sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
+    auto max_int = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()}));
+    auto dim_plus_one = context.mark_node(std::make_shared<v1::Add>(dim, one_1d));
+    auto head_part_rank = context.mark_node(std::make_shared<v8::Slice>(input_shape, zero_1d, dim, one_1d));
+    auto tail_part_rank = context.mark_node(std::make_shared<v8::Slice>(input_shape, dim_plus_one, max_int, one_1d));
+    auto new_shape =
+        context.mark_node(std::make_shared<v0::Concat>(OutputVector{head_part_rank, sizes, tail_part_rank}, 0));
+    return {context.mark_node(std::make_shared<v1::Reshape>(input, new_shape, false))};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index b0d6def4467..9d13ee0d046 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -126,6 +126,7 @@ OP_CONVERTER(translate_topk);
 OP_CONVERTER(translate_transpose);
 OP_CONVERTER(translate_tril);
 OP_CONVERTER(translate_triu);
+OP_CONVERTER(translate_unflatten);
 OP_CONVERTER(translate_unfold);
 OP_CONVERTER(translate_upsample_bicubic2d);
 OP_CONVERTER(translate_upsample_bilinear2d);
@@ -338,6 +339,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::triu", op::translate_triu},
         {"aten::type_as",
          op::translate_1to1_match_2_inputs<opset10::ConvertLike>},  // TODO: overflow semantics is different
+        {"aten::unflatten", op::translate_unflatten},
         {"aten::unfold", op::translate_unfold},
         {"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
         {"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index ff55c275ce4..dcfcce0d3c2 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -116,6 +116,17 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
     return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
 };
 
+std::shared_ptr<Node> normalize_axis(const NodeContext& context,
+                                     const Output<Node>& axis,
+                                     const Output<Node>& input_node) {
+    Output<Node> rank;
+    std::tie(std::ignore, rank) = get_shape_rank(context, input_node);
+    auto axis_rank = context.mark_node(std::make_shared<opset10::Add>(axis, rank));
+    auto is_less = context.mark_node(std::make_shared<opset10::Less>(axis_rank, rank));
+    auto new_axis = context.mark_node(std::make_shared<opset10::Select>(is_less, axis_rank, axis));
+    return new_axis;
+}
+
 std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x) {
     auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, element::i32));
     auto axes = context.mark_node(opset10::Constant::create(element::i32, Shape({1}), {0}));
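Note on the hunk above: the Add/Less/Select pattern in normalize_axis implements PyTorch's negative-axis convention without data-dependent control flow. The axis is shifted by the rank, and the shifted value is kept only when it is still below the rank, which is true exactly when the original axis was negative. A minimal Python sketch of the same arithmetic, with plain integers standing in for the i32 graph nodes:

    def normalize_axis(axis, rank):
        # Select(Less(axis + rank, rank), axis + rank, axis): a negative axis
        # is shifted into range, a non-negative axis passes through unchanged.
        shifted = axis + rank
        return shifted if shifted < rank else axis

    assert normalize_axis(-1, 3) == 2  # last dim of a rank-3 tensor
    assert normalize_axis(1, 3) == 1   # non-negative axes are untouched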
diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp
index dd603169070..aea3fd505c5 100644
--- a/src/frontends/pytorch/src/utils.hpp
+++ b/src/frontends/pytorch/src/utils.hpp
@@ -38,6 +38,10 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<Node>& kernel, int64_t groups);
 
 std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id);
 
+std::shared_ptr<Node> normalize_axis(const NodeContext& context,
+                                     const Output<Node>& axis,
+                                     const Output<Node>& input_node);
+
 std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);
 
 element::Type convert_dtype(int64_t dtype_value);
diff --git a/tests/layer_tests/pytorch_tests/test_unflatten.py b/tests/layer_tests/pytorch_tests/test_unflatten.py
new file mode 100644
index 00000000000..e260b125e11
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_unflatten.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestUnflatten(PytorchLayerTest):
+    def _prepare_input(self, dtype):
+        return (np.random.uniform(0, 50, (6, 3, 4)).astype(dtype),)
+
+    def create_model(self, dim, shape):
+        import torch
+
+        class aten_unflatten(torch.nn.Module):
+            def __init__(self, dim, shape):
+                super(aten_unflatten, self).__init__()
+                self.dim = dim
+                self.shape = shape
+
+            def forward(self, x):
+                return x.unflatten(self.dim, self.shape)
+
+        ref_net = None
+
+        return aten_unflatten(dim, shape), ref_net, "aten::unflatten"
+
+    @pytest.mark.parametrize(("dim", "shape"), [(0, [2, 1, 3]), (1, [1, 3]), (2, (2, -1)), (-1, (2, 2)), (-2, (-1, 1))])
+    @pytest.mark.parametrize("dtype", ["float32", "int32"])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_unflatten(self, dim, shape, dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype})
\ No newline at end of file
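For context, translate_unflatten builds the Reshape target shape as Concat(shape[:dim], sizes, shape[dim+1:]): the two Slice nodes take the head and tail of the input shape, with INT_MAX serving as an open-ended stop index for the tail. A rough PyTorch sketch of that decomposition; the (6, 3, 4) input and the dim/sizes values mirror the test above, and dim is assumed already non-negative, as normalize_axis guarantees inside the converter:

    import torch

    x = torch.arange(72).reshape(6, 3, 4)
    dim, sizes = 1, (1, 3)
    # new shape = shape[:dim] + sizes + shape[dim + 1:], mirroring the
    # Slice/Concat graph the converter emits before the final Reshape.
    new_shape = x.shape[:dim] + torch.Size(sizes) + x.shape[dim + 1:]
    assert torch.equal(x.reshape(new_shape), x.unflatten(dim, sizes))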