[PT FE] Add translation for aten::__range_length and aten::__derive_index (#17618)

* Add operators and tests

* Fix op kind

* Merge tests

* Fix freeze issue as separate bug

* Fix indent

* Fix print placement

* Fix dtype
This commit is contained in:
Mateusz Mikolajczyk 2023-05-29 15:34:08 +02:00 committed by GitHub
parent a6d3f9d093
commit a1a753bb03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 130 additions and 0 deletions

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_derive_index(const NodeContext& context) {
// aten::__derive_index(int index, int start, int step) -> int
num_inputs_check(context, 3, 3);
auto index = context.get_input(0);
auto start = context.get_input(1);
auto step = context.get_input(2);
auto index_step = context.mark_node(std::make_shared<v1::Multiply>(index, step));
return {context.mark_node(std::make_shared<v1::Add>(start, index_step))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/ceiling.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_range_length(const NodeContext& context) {
// aten::__range_length(int lo, int hi, int step) -> int
num_inputs_check(context, 3, 3);
auto lo = context.get_input(0);
auto hi = context.get_input(1);
auto step = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), ov::element::f32));
auto length = context.mark_node(std::make_shared<v1::Subtract>(hi, lo));
auto length_f32 = context.mark_node(std::make_shared<v0::Convert>(length, ov::element::f32));
auto num_steps = context.mark_node(std::make_shared<v1::Divide>(length_f32, step, false, AutoBroadcastType::NUMPY));
auto ceil = context.mark_node(std::make_shared<v0::Ceiling>(num_steps));
auto ceil_int = context.mark_node(std::make_shared<v0::Convert>(ceil, ov::element::i32));
return {context.mark_node(std::make_shared<v0::Relu>(ceil_int))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -38,6 +38,7 @@ OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_cumsum);
OP_CONVERTER(translate_deform_conv);
OP_CONVERTER(translate_derive_index);
OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_elu);
@ -98,6 +99,7 @@ OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pow);
OP_CONVERTER(translate_pythonop);
OP_CONVERTER(translate_range_length);
OP_CONVERTER(translate_reciprocal);
OP_CONVERTER(translate_relu6);
OP_CONVERTER(translate_remainder);
@ -147,9 +149,11 @@ OP_CONVERTER(translate_zeros_like);
const std::map<std::string, CreatorFunction> get_supported_ops() {
return {
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
{"aten::__derive_index", op::translate_derive_index},
{"aten::__getitem__", op::translate_getitem},
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
{"aten::__or__", op::translate_1to1_match_2_inputs<opset10::LogicalOr>},
{"aten::__range_length", op::translate_range_length},
{"aten::_convolution", op::translate_convolution},
{"aten::_convolution_mode", op::translate_convolution_mode},
{"aten::_set_item", op::translate_set_item},

View File

@ -0,0 +1,58 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
@pytest.mark.parametrize(
"start, stop, step",
[
[1, 32, 1],
[1, 32, 2],
[1, 32, 10],
[1, 32, -1],
[1, 32, -2],
[1, 32, -10],
[32, 1, -1],
[32, 1, -2],
[32, 1, -10],
[32, -31, -1],
[32, -31, -2],
[32, -31, -10],
],
)
class TestDeriveIndexRangeLength(PytorchLayerTest):
def _prepare_input(self):
input_data = np.array([self.start, self.stop, self.step])
return (input_data,)
def create_model(self):
class prim_derive_index_range_length(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
start = int(x[0])
stop = int(x[1])
step = int(x[2])
accumulator = 0
for idx in range(start, stop, step):
accumulator = idx
return accumulator
ref_net = None
return prim_derive_index_range_length(), ref_net, ["aten::__range_length", "aten::__derive_index"]
@pytest.mark.nightly
@pytest.mark.precommit
def test_derive_index_range_length(self, ie_device, precision, ir_version, start, stop, step):
self.start = start
self.stop = stop
self.step = step
if ((stop - start) / step) < 0:
pytest.xfail("Failed due to prim::Loop translation not supporting 0 iterations. Ticket: 110808")
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False, trace_model=False)