[PT FE] Add aten::repeat_interleave operator (#15274)
This commit is contained in:
parent
da9470864c
commit
cab559b478
93
src/frontends/pytorch/src/op/repeat_interleave.cpp
Normal file
93
src/frontends/pytorch/src/op/repeat_interleave.cpp
Normal file
@ -0,0 +1,93 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
namespace {
|
||||
OutputVector generate_indices_from_repeats_tensor(std::vector<int32_t> repeats, NodeContext& context) {
|
||||
OutputVector all_indices;
|
||||
for (int i = 0; i < repeats.size(); i++) {
|
||||
Shape indices_shape{static_cast<size_t>(repeats.at(i))};
|
||||
std::vector<int32_t> indices_vec(repeats.at(i), i);
|
||||
auto indices = context.mark_node(opset10::Constant::create(element::i32, indices_shape, indices_vec));
|
||||
all_indices.push_back(indices);
|
||||
}
|
||||
return all_indices;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_repeat_interleave(NodeContext& context) {
|
||||
// constants
|
||||
auto const_0 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto const_1_list = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_neg_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
|
||||
// inputs
|
||||
auto input = context.get_input(0);
|
||||
std::shared_ptr<ov::Node> result;
|
||||
|
||||
auto repeats_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
|
||||
auto repeats_fw_node = std::dynamic_pointer_cast<opset10::Constant>(repeats_ext_node);
|
||||
if (repeats_fw_node && repeats_fw_node->cast_vector<int32_t>().size() > 1) {
|
||||
// repeats is Constant with more then 1 element
|
||||
auto repeats = repeats_fw_node->cast_vector<int32_t>();
|
||||
if (context.input_is_none(2)) {
|
||||
// case (repeats=tensor, dim=None)
|
||||
auto flat_shape = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto reshape = context.mark_node(std::make_shared<opset10::Reshape>(input, flat_shape, false));
|
||||
OutputVector all_indices = generate_indices_from_repeats_tensor(repeats, context);
|
||||
auto concat = context.mark_node(std::make_shared<opset10::Concat>(all_indices, 0));
|
||||
result = std::make_shared<opset10::Gather>(reshape, concat, const_0);
|
||||
} else {
|
||||
// case (repeats=tensor, dim=number)
|
||||
auto dimension = context.get_input(2);
|
||||
OutputVector all_indices = generate_indices_from_repeats_tensor(repeats, context);
|
||||
auto concat = context.mark_node(std::make_shared<opset10::Concat>(all_indices, 0));
|
||||
result = std::make_shared<opset10::Gather>(input, concat, dimension);
|
||||
}
|
||||
} else {
|
||||
// repeats is not Constant or single element constant
|
||||
// Curently we support only case when repeats contains only one element. Otherwise next Reshape will fail.
|
||||
auto repeats_input =
|
||||
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(1), const_1_list, false));
|
||||
auto repeats =
|
||||
context.mark_node(std::make_shared<opset10::Concat>(OutputVector{repeats_input, const_1_list}, 0));
|
||||
auto shape_perm = context.mark_node(opset10::Constant::create(element::i32, Shape{2}, {1, 0}));
|
||||
if (context.input_is_none(2)) {
|
||||
// case (repeats=number, dim=None)
|
||||
auto flat_shape = context.mark_node(opset10::Constant::create(element::i32, Shape{2}, {1, -1}));
|
||||
auto reshape = context.mark_node(std::make_shared<opset10::Reshape>(input, flat_shape, false));
|
||||
auto tile = context.mark_node(std::make_shared<opset10::Tile>(reshape, repeats));
|
||||
auto transpose = context.mark_node(std::make_shared<opset10::Transpose>(tile, shape_perm));
|
||||
result = std::make_shared<opset10::Reshape>(transpose, const_neg_1, false);
|
||||
} else {
|
||||
// case (repeats=number, dim=number)
|
||||
auto dimension = context.get_input(2);
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input, element::i32));
|
||||
auto input_dim_size = context.mark_node(std::make_shared<opset10::Gather>(input_shape, dimension, const_0));
|
||||
auto range =
|
||||
context.mark_node(std::make_shared<opset10::Range>(const_0, input_dim_size, const_1, element::i32));
|
||||
auto range_unsqeezed = context.mark_node(std::make_shared<opset10::Unsqueeze>(range, const_0));
|
||||
auto tile = context.mark_node(std::make_shared<opset10::Tile>(range_unsqeezed, repeats));
|
||||
auto transpose = context.mark_node(std::make_shared<opset10::Transpose>(tile, shape_perm));
|
||||
auto flatten = context.mark_node(std::make_shared<opset10::Reshape>(transpose, const_neg_1, false));
|
||||
result = std::make_shared<opset10::Gather>(input, flatten, dimension);
|
||||
}
|
||||
}
|
||||
|
||||
return {context.mark_node(result)};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -81,6 +81,7 @@ OP_CONVERTER(translate_reciprocal);
|
||||
OP_CONVERTER(translate_relu6);
|
||||
OP_CONVERTER(translate_remainder);
|
||||
OP_CONVERTER(translate_repeat);
|
||||
OP_CONVERTER(translate_repeat_interleave);
|
||||
OP_CONVERTER(translate_reshape);
|
||||
OP_CONVERTER(translate_reshape_as);
|
||||
OP_CONVERTER(translate_rsub);
|
||||
@ -249,6 +250,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::relu6", op::translate_relu6},
|
||||
{"aten::remainder", op::translate_remainder},
|
||||
{"aten::repeat", op::translate_repeat},
|
||||
{"aten::repeat_interleave", op::translate_repeat_interleave},
|
||||
{"aten::reshape", op::translate_reshape},
|
||||
{"aten::reshape_as", op::translate_reshape_as},
|
||||
{"aten::rsub", op::translate_rsub},
|
||||
|
76
tests/layer_tests/pytorch_tests/test_repeat_interleave.py
Normal file
76
tests/layer_tests/pytorch_tests/test_repeat_interleave.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.mark.parametrize('input_data', ({'repeats': 1, 'dim': 0},
|
||||
{'repeats': 2, 'dim': 2},
|
||||
{'repeats': [2, 3], 'dim': 1},
|
||||
{'repeats': [3, 2, 1], 'dim': 3},
|
||||
{'repeats': [3, 2, 1], 'dim': 3},
|
||||
{'repeats': 2, 'dim': None},
|
||||
{'repeats': [random.randint(1, 5) for _ in range(36)], 'dim': None}))
|
||||
class TestRepeatInterleaveConstRepeats(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(2, 2, 3, 3),)
|
||||
|
||||
def create_model_const_repeat(self, repeats, dim):
|
||||
class aten_repeat_interleave_const_repeat(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.repeats = torch.tensor(repeats, dtype=torch.int)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, input_tensor):
|
||||
return input_tensor.repeat_interleave(self.repeats, self.dim)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_repeat_interleave_const_repeat(), ref_net, "aten::repeat_interleave"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_repeat_interleave_const_repeats(self, ie_device, precision, ir_version, input_data):
|
||||
repeats = input_data['repeats']
|
||||
dim = input_data['dim']
|
||||
self._test(*self.create_model_const_repeat(repeats, dim),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.parametrize('input_data', ({'repeats': np.array([1]).astype(np.int32), 'dim': 0},
|
||||
{'repeats': np.array(1).astype(np.int32), 'dim': 1},
|
||||
{'repeats': np.array([2]).astype(np.int32), 'dim': 2},
|
||||
{'repeats': np.array(2).astype(np.int32), 'dim': 1},
|
||||
{'repeats': np.array([3]).astype(np.int32), 'dim': None}))
|
||||
class TestRepeatInterleaveNonConstRepeats(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(2, 2, 3, 3), self.repeats)
|
||||
|
||||
def create_model_non_const_repeat(self, dim):
|
||||
class aten_repeat_interleave_non_const_repeat(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, input_tensor, repeats):
|
||||
return input_tensor.repeat_interleave(repeats, self.dim)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_repeat_interleave_non_const_repeat(), ref_net, "aten::repeat_interleave"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_repeat_interleave_non_const_repeats(self, ie_device, precision, ir_version, input_data):
|
||||
self.repeats = input_data['repeats']
|
||||
dim = input_data['dim']
|
||||
self._test(*self.create_model_non_const_repeat(dim),
|
||||
ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user