[PT FE] Add translation for aten::__range_length and aten::__derive_index (#17618)
* Add operator implementations and tests for aten::__range_length and aten::__derive_index
* Fix op kind
* Merge the two tests into one parametrized test
* Track the freeze issue as a separate bug
* Fix indentation and print placement
* Fix dtype
This commit is contained in:
parent
a6d3f9d093
commit
a1a753bb03
30
src/frontends/pytorch/src/op/derive_index.cpp
Normal file
30
src/frontends/pytorch/src/op/derive_index.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
// Translates aten::__derive_index(int index, int start, int step) -> int.
// TorchScript emits this op inside loops compiled from
// `for i in range(start, stop, step)`: it maps the trip counter `index`
// onto the value of that iteration, i.e. start + index * step.
OutputVector translate_derive_index(const NodeContext& context) {
    // aten::__derive_index(int index, int start, int step) -> int
    num_inputs_check(context, 3, 3);
    auto index = context.get_input(0);
    auto start = context.get_input(1);
    auto step = context.get_input(2);
    // index * step, then offset by start.
    auto index_step = context.mark_node(std::make_shared<v1::Multiply>(index, step));
    return {context.mark_node(std::make_shared<v1::Add>(start, index_step))};
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
38
src/frontends/pytorch/src/op/range_length.cpp
Normal file
38
src/frontends/pytorch/src/op/range_length.cpp
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/ceiling.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "openvino/op/relu.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
// Translates aten::__range_length(int lo, int hi, int step) -> int.
// Computes the trip count of `range(lo, hi, step)` as
//   max(ceil((hi - lo) / step), 0)
// TorchScript emits this op as the loop bound when compiling a
// `for ... in range(...)` statement.
OutputVector translate_range_length(const NodeContext& context) {
    // aten::__range_length(int lo, int hi, int step) -> int
    num_inputs_check(context, 3, 3);
    auto lo = context.get_input(0);
    auto hi = context.get_input(1);
    // The division is performed in f32 so that Ceiling can round a partial
    // last step up to one extra iteration.
    auto step = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), ov::element::f32));
    auto length = context.mark_node(std::make_shared<v1::Subtract>(hi, lo));
    auto length_f32 = context.mark_node(std::make_shared<v0::Convert>(length, ov::element::f32));
    auto num_steps = context.mark_node(std::make_shared<v1::Divide>(length_f32, step, false, AutoBroadcastType::NUMPY));
    auto ceil = context.mark_node(std::make_shared<v0::Ceiling>(num_steps));
    // NOTE(review): the count is narrowed to i32 — assumes the range bounds
    // fit in 32 bits; confirm against the prim::Loop translation consuming it.
    auto ceil_int = context.mark_node(std::make_shared<v0::Convert>(ceil, ov::element::i32));
    // Relu clamps a negative step count (an empty range, e.g. lo > hi with
    // a positive step) to 0 iterations.
    return {context.mark_node(std::make_shared<v0::Relu>(ceil_int))};
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -38,6 +38,7 @@ OP_CONVERTER(translate_convolution);
|
||||
OP_CONVERTER(translate_convolution_mode);
|
||||
OP_CONVERTER(translate_cumsum);
|
||||
OP_CONVERTER(translate_deform_conv);
|
||||
OP_CONVERTER(translate_derive_index);
|
||||
OP_CONVERTER(translate_dim);
|
||||
OP_CONVERTER(translate_div);
|
||||
OP_CONVERTER(translate_elu);
|
||||
@ -98,6 +99,7 @@ OP_CONVERTER(translate_ones_like);
|
||||
OP_CONVERTER(translate_pad);
|
||||
OP_CONVERTER(translate_pow);
|
||||
OP_CONVERTER(translate_pythonop);
|
||||
OP_CONVERTER(translate_range_length);
|
||||
OP_CONVERTER(translate_reciprocal);
|
||||
OP_CONVERTER(translate_relu6);
|
||||
OP_CONVERTER(translate_remainder);
|
||||
@ -147,9 +149,11 @@ OP_CONVERTER(translate_zeros_like);
|
||||
const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
return {
|
||||
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
|
||||
{"aten::__derive_index", op::translate_derive_index},
|
||||
{"aten::__getitem__", op::translate_getitem},
|
||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||
{"aten::__or__", op::translate_1to1_match_2_inputs<opset10::LogicalOr>},
|
||||
{"aten::__range_length", op::translate_range_length},
|
||||
{"aten::_convolution", op::translate_convolution},
|
||||
{"aten::_convolution_mode", op::translate_convolution_mode},
|
||||
{"aten::_set_item", op::translate_set_item},
|
||||
|
@ -0,0 +1,58 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "start, stop, step",
    [
        [1, 32, 1],
        [1, 32, 2],
        [1, 32, 10],
        # Positive-to-negative direction mismatches produce empty ranges and
        # are expected to xfail (see the check in the test method below).
        [1, 32, -1],
        [1, 32, -2],
        [1, 32, -10],
        [32, 1, -1],
        [32, 1, -2],
        [32, 1, -10],
        [32, -31, -1],
        [32, -31, -2],
        [32, -31, -10],
    ],
)
class TestDeriveIndexRangeLength(PytorchLayerTest):
    # Covers aten::__range_length and aten::__derive_index, which TorchScript
    # emits when compiling a `for ... in range(start, stop, step)` loop.

    def _prepare_input(self):
        # Pack (start, stop, step) into a single array so the bounds reach the
        # model as runtime data rather than as baked-in constants.
        input_data = np.array([self.start, self.stop, self.step])
        return (input_data,)

    def create_model(self):
        class prim_derive_index_range_length(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                start = int(x[0])
                stop = int(x[1])
                step = int(x[2])
                accumulator = 0
                # The scripted range() loop lowers to aten::__range_length
                # (trip count) and aten::__derive_index (per-iteration value);
                # returning the last index exercises both.
                for idx in range(start, stop, step):
                    accumulator = idx
                return accumulator

        ref_net = None

        return prim_derive_index_range_length(), ref_net, ["aten::__range_length", "aten::__derive_index"]

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_derive_index_range_length(self, ie_device, precision, ir_version, start, stop, step):
        self.start = start
        self.stop = stop
        self.step = step
        # (stop - start) / step < 0 means the range is empty (step points away
        # from stop) — a 0-iteration prim::Loop is not yet supported.
        if ((stop - start) / step) < 0:
            pytest.xfail("Failed due to prim::Loop translation not supporting 0 iterations. Ticket: 110808")
        # freeze_model=False / trace_model=False: the model must be scripted
        # (not traced) for these range ops to appear in the graph.
        self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False, trace_model=False)
|
Loading…
Reference in New Issue
Block a user