Add aten::narrow operator with layer test (#15788)
This commit is contained in:
parent
c8c4503672
commit
5d3cd81fd1
40
src/frontends/pytorch/src/op/narrow.cpp
Normal file
40
src/frontends/pytorch/src/op/narrow.cpp
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_narrow(NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto axis_input = context.get_input(1);
|
||||
auto start_input = context.get_input(2);
|
||||
auto length = context.get_input(3);
|
||||
|
||||
auto start = context.mark_node(std::make_shared<v0::Unsqueeze>(start_input, const_0));
|
||||
auto stop = context.mark_node(std::make_shared<v1::Add>(start, length));
|
||||
auto axis = context.mark_node(std::make_shared<v0::Unsqueeze>(axis_input, const_0));
|
||||
|
||||
auto narrow = context.mark_node(std::make_shared<v8::Slice>(input_tensor, start, stop, const_1, axis));
|
||||
return {narrow};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -69,6 +69,7 @@ OP_CONVERTER(translate_max_poolnd);
|
||||
OP_CONVERTER(translate_mean);
|
||||
OP_CONVERTER(translate_meshgrid);
|
||||
OP_CONVERTER(translate_min);
|
||||
OP_CONVERTER(translate_narrow);
|
||||
OP_CONVERTER(translate_neg);
|
||||
OP_CONVERTER(translate_new_full);
|
||||
OP_CONVERTER(translate_new_ones);
|
||||
@ -239,6 +240,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
|
||||
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
|
||||
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
|
||||
{"aten::narrow", op::translate_narrow},
|
||||
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
|
||||
{"aten::neg", op::translate_neg},
|
||||
{"aten::new_full", op::translate_new_full},
|
||||
|
45
tests/layer_tests/pytorch_tests/test_narrow.py
Normal file
45
tests/layer_tests/pytorch_tests/test_narrow.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestNarrow(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (self.input_tensor, self.dim, self.start, self.length)
|
||||
|
||||
def create_model(self):
|
||||
|
||||
class aten_narrow(torch.nn.Module):
|
||||
|
||||
def forward(self, input_tensor, dim: int, start, length: int):
|
||||
return torch.narrow(input_tensor, dim=dim, start=start, length=length)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_narrow(), ref_net, "aten::narrow"
|
||||
|
||||
@pytest.mark.parametrize("input_tensor", [
|
||||
np.random.randn(3, 3), np.random.randn(3, 4, 5)
|
||||
])
|
||||
@pytest.mark.parametrize("dim", [
|
||||
np.array(0).astype(np.int32), np.array(1).astype(np.int32), np.array(-1).astype(np.int32)
|
||||
])
|
||||
@pytest.mark.parametrize("start", [
|
||||
np.array(0).astype(np.int32), np.array(1).astype(np.int32)
|
||||
])
|
||||
@pytest.mark.parametrize("length", [
|
||||
np.array(1).astype(np.int32), np.array(2).astype(np.int32)
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_narrow(self, input_tensor, dim, start, length, ie_device, precision, ir_version):
|
||||
self.input_tensor = input_tensor
|
||||
self.dim = dim
|
||||
self.start = start
|
||||
self.length = length
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user