[PT FE] Add aten::repeat_interleave operator (#15274)

Leonard Sikorski 2023-02-01 11:45:04 +01:00 committed by GitHub
parent da9470864c
commit cab559b478
3 changed files with 171 additions and 0 deletions
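For readers unfamiliar with the operator: torch.repeat_interleave repeats individual elements of a tensor, flattening the input when no dim is given. A short PyTorch refresher of the three cases the converter below handles (illustrative only, not part of the commit itself):

import torch

x = torch.tensor([[1, 2], [3, 4]])

# scalar repeats, dim=None: the input is flattened first
torch.repeat_interleave(x, 2)                             # tensor([1, 1, 2, 2, 3, 3, 4, 4])

# scalar repeats along a dimension
torch.repeat_interleave(x, 2, dim=0)                      # tensor([[1, 2], [1, 2], [3, 4], [3, 4]])

# per-element repeats tensor along a dimension
torch.repeat_interleave(x, torch.tensor([1, 2]), dim=0)   # tensor([[1, 2], [3, 4], [3, 4]])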


@@ -0,0 +1,93 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"

#include "pt_framework_node.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
namespace {
OutputVector generate_indices_from_repeats_tensor(const std::vector<int32_t>& repeats, NodeContext& context) {
    OutputVector all_indices;
    for (size_t i = 0; i < repeats.size(); i++) {
        Shape indices_shape{static_cast<size_t>(repeats.at(i))};
        std::vector<int32_t> indices_vec(repeats.at(i), static_cast<int32_t>(i));
        auto indices = context.mark_node(opset10::Constant::create(element::i32, indices_shape, indices_vec));
        all_indices.push_back(indices);
    }
    return all_indices;
}
}  // namespace

OutputVector translate_repeat_interleave(NodeContext& context) {
    // constants
    auto const_0 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {0}));
    auto const_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {1}));
    auto const_1_list = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {1}));
    auto const_neg_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-1}));

    // inputs
    auto input = context.get_input(0);
    std::shared_ptr<ov::Node> result;
    auto repeats_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
    auto repeats_fw_node = std::dynamic_pointer_cast<opset10::Constant>(repeats_ext_node);
    if (repeats_fw_node && repeats_fw_node->cast_vector<int32_t>().size() > 1) {
        // repeats is a Constant with more than one element
        auto repeats = repeats_fw_node->cast_vector<int32_t>();
        if (context.input_is_none(2)) {
            // case (repeats=tensor, dim=None)
            auto flat_shape = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-1}));
            auto reshape = context.mark_node(std::make_shared<opset10::Reshape>(input, flat_shape, false));
            OutputVector all_indices = generate_indices_from_repeats_tensor(repeats, context);
            auto concat = context.mark_node(std::make_shared<opset10::Concat>(all_indices, 0));
            result = std::make_shared<opset10::Gather>(reshape, concat, const_0);
        } else {
            // case (repeats=tensor, dim=number)
            auto dimension = context.get_input(2);
            OutputVector all_indices = generate_indices_from_repeats_tensor(repeats, context);
            auto concat = context.mark_node(std::make_shared<opset10::Concat>(all_indices, 0));
            result = std::make_shared<opset10::Gather>(input, concat, dimension);
        }
    } else {
        // repeats is not a Constant, or is a single-element Constant.
        // Currently only the case where repeats contains a single element is supported;
        // otherwise the Reshape below will fail.
        auto repeats_input =
            context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(1), const_1_list, false));
        auto repeats =
            context.mark_node(std::make_shared<opset10::Concat>(OutputVector{repeats_input, const_1_list}, 0));
        auto shape_perm = context.mark_node(opset10::Constant::create(element::i32, Shape{2}, {1, 0}));
        if (context.input_is_none(2)) {
            // case (repeats=number, dim=None)
            auto flat_shape = context.mark_node(opset10::Constant::create(element::i32, Shape{2}, {1, -1}));
            auto reshape = context.mark_node(std::make_shared<opset10::Reshape>(input, flat_shape, false));
            auto tile = context.mark_node(std::make_shared<opset10::Tile>(reshape, repeats));
            auto transpose = context.mark_node(std::make_shared<opset10::Transpose>(tile, shape_perm));
            result = std::make_shared<opset10::Reshape>(transpose, const_neg_1, false);
        } else {
            // case (repeats=number, dim=number)
            auto dimension = context.get_input(2);
            auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input, element::i32));
            auto input_dim_size = context.mark_node(std::make_shared<opset10::Gather>(input_shape, dimension, const_0));
            auto range =
                context.mark_node(std::make_shared<opset10::Range>(const_0, input_dim_size, const_1, element::i32));
            auto range_unsqueezed = context.mark_node(std::make_shared<opset10::Unsqueeze>(range, const_0));
            auto tile = context.mark_node(std::make_shared<opset10::Tile>(range_unsqueezed, repeats));
            auto transpose = context.mark_node(std::make_shared<opset10::Transpose>(tile, shape_perm));
            auto flatten = context.mark_node(std::make_shared<opset10::Reshape>(transpose, const_neg_1, false));
            result = std::make_shared<opset10::Gather>(input, flatten, dimension);
        }
    }
    return {context.mark_node(result)};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
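Why the two decompositions above are equivalent to repeat_interleave: both branches reduce the op to a Gather over precomputed indices. A minimal NumPy sketch of the two index-building strategies (my own illustration; the names are not part of the OpenVINO API):

import numpy as np

x = np.array([[1, 2], [3, 4]])

# Constant repeats tensor [1, 2] along dim 0: build the index list [0, 1, 1]
# (mirrors generate_indices_from_repeats_tensor) and gather rows with it.
repeats = [1, 2]
indices = np.concatenate([np.full(r, i, dtype=np.int32) for i, r in enumerate(repeats)])
print(np.take(x, indices, axis=0))   # [[1 2] [3 4] [3 4]]

# Scalar repeats r=2 along dim 0 (the non-constant branch): tile a [1, n] range
# r times, transpose, and flatten to get the interleaved indices [0, 0, 1, 1],
# exactly what the Tile -> Transpose -> Reshape chain computes.
n, r = x.shape[0], 2
flat = np.tile(np.arange(n)[None, :], (r, 1)).T.reshape(-1)
print(np.take(x, flat, axis=0))      # [[1 2] [1 2] [3 4] [3 4]]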


@@ -81,6 +81,7 @@ OP_CONVERTER(translate_reciprocal);
OP_CONVERTER(translate_relu6);
OP_CONVERTER(translate_remainder);
OP_CONVERTER(translate_repeat);
OP_CONVERTER(translate_repeat_interleave);
OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_reshape_as);
OP_CONVERTER(translate_rsub);
@@ -249,6 +250,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::relu6", op::translate_relu6},
{"aten::remainder", op::translate_remainder},
{"aten::repeat", op::translate_repeat},
{"aten::repeat_interleave", op::translate_repeat_interleave},
{"aten::reshape", op::translate_reshape},
{"aten::reshape_as", op::translate_reshape_as},
{"aten::rsub", op::translate_rsub},


@@ -0,0 +1,76 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import random

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest

@pytest.mark.parametrize('input_data', ({'repeats': 1, 'dim': 0},
                                        {'repeats': 2, 'dim': 2},
                                        {'repeats': [2, 3], 'dim': 1},
                                        {'repeats': [3, 2, 1], 'dim': 3},
                                        {'repeats': 2, 'dim': None},
                                        {'repeats': [random.randint(1, 5) for _ in range(36)], 'dim': None}))
class TestRepeatInterleaveConstRepeats(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(2, 2, 3, 3),)

    def create_model_const_repeat(self, repeats, dim):
        class aten_repeat_interleave_const_repeat(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.repeats = torch.tensor(repeats, dtype=torch.int)
                self.dim = dim

            def forward(self, input_tensor):
                return input_tensor.repeat_interleave(self.repeats, self.dim)

        ref_net = None
        return aten_repeat_interleave_const_repeat(), ref_net, "aten::repeat_interleave"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_repeat_interleave_const_repeats(self, ie_device, precision, ir_version, input_data):
        repeats = input_data['repeats']
        dim = input_data['dim']
        self._test(*self.create_model_const_repeat(repeats, dim),
                   ie_device, precision, ir_version)

@pytest.mark.parametrize('input_data', ({'repeats': np.array([1]).astype(np.int32), 'dim': 0},
                                        {'repeats': np.array(1).astype(np.int32), 'dim': 1},
                                        {'repeats': np.array([2]).astype(np.int32), 'dim': 2},
                                        {'repeats': np.array(2).astype(np.int32), 'dim': 1},
                                        {'repeats': np.array([3]).astype(np.int32), 'dim': None}))
class TestRepeatInterleaveNonConstRepeats(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(2, 2, 3, 3), self.repeats)

    def create_model_non_const_repeat(self, dim):
        class aten_repeat_interleave_non_const_repeat(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dim = dim

            def forward(self, input_tensor, repeats):
                return input_tensor.repeat_interleave(repeats, self.dim)

        ref_net = None
        return aten_repeat_interleave_non_const_repeat(), ref_net, "aten::repeat_interleave"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_repeat_interleave_non_const_repeats(self, ie_device, precision, ir_version, input_data):
        self.repeats = input_data['repeats']
        dim = input_data['dim']
        self._test(*self.create_model_non_const_repeat(dim),
                   ie_device, precision, ir_version)
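Two observations on the tests: the 36-element repeats list in the constant-repeats suite is drawn with random.randint at collection time, so that case exercises a different repeat pattern on every run, and the non-constant suite feeds repeats as a real graph input via _prepare_input. A minimal standalone check of the shape the non-constant cases imply (my own snippet, not part of the test harness):

import numpy as np
import torch

x = torch.from_numpy(np.random.randn(2, 2, 3, 3))
out = x.repeat_interleave(torch.tensor(2), dim=1)   # scalar repeats passed as a tensor
assert out.shape == (2, 4, 3, 3)                    # dim 1 grows from 2 to 4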