Add aten::narrow operator with layer test (#15788)

This commit is contained in:
Leonard Sikorski 2023-02-20 15:47:25 +01:00 committed by GitHub
parent c8c4503672
commit 5d3cd81fd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 0 deletions

View File

@ -0,0 +1,40 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_narrow(NodeContext& context) {
num_inputs_check(context, 4, 4);
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto input_tensor = context.get_input(0);
auto axis_input = context.get_input(1);
auto start_input = context.get_input(2);
auto length = context.get_input(3);
auto start = context.mark_node(std::make_shared<v0::Unsqueeze>(start_input, const_0));
auto stop = context.mark_node(std::make_shared<v1::Add>(start, length));
auto axis = context.mark_node(std::make_shared<v0::Unsqueeze>(axis_input, const_0));
auto narrow = context.mark_node(std::make_shared<v8::Slice>(input_tensor, start, stop, const_1, axis));
return {narrow};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -69,6 +69,7 @@ OP_CONVERTER(translate_max_poolnd);
OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_min);
OP_CONVERTER(translate_narrow);
OP_CONVERTER(translate_neg);
OP_CONVERTER(translate_new_full);
OP_CONVERTER(translate_new_ones);
@ -239,6 +240,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
{"aten::narrow", op::translate_narrow},
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
{"aten::neg", op::translate_neg},
{"aten::new_full", op::translate_new_full},

View File

@ -0,0 +1,45 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class TestNarrow(PytorchLayerTest):
def _prepare_input(self):
return (self.input_tensor, self.dim, self.start, self.length)
def create_model(self):
class aten_narrow(torch.nn.Module):
def forward(self, input_tensor, dim: int, start, length: int):
return torch.narrow(input_tensor, dim=dim, start=start, length=length)
ref_net = None
return aten_narrow(), ref_net, "aten::narrow"
@pytest.mark.parametrize("input_tensor", [
np.random.randn(3, 3), np.random.randn(3, 4, 5)
])
@pytest.mark.parametrize("dim", [
np.array(0).astype(np.int32), np.array(1).astype(np.int32), np.array(-1).astype(np.int32)
])
@pytest.mark.parametrize("start", [
np.array(0).astype(np.int32), np.array(1).astype(np.int32)
])
@pytest.mark.parametrize("length", [
np.array(1).astype(np.int32), np.array(2).astype(np.int32)
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_narrow(self, input_tensor, dim, start, length, ie_device, precision, ir_version):
self.input_tensor = input_tensor
self.dim = dim
self.start = start
self.length = length
self._test(*self.create_model(), ie_device, precision, ir_version)