add aten::topk (#15221)
* add aten::topk * remove commented lines * remove white space * move include to invidual ops * swithc include statements * fix style * trim test cases
This commit is contained in:
parent
73c9a3dcf2
commit
ce4c082cb2
43
src/frontends/pytorch/src/op/topk.cpp
Normal file
43
src/frontends/pytorch/src/op/topk.cpp
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/topk.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_topk(NodeContext& context) {
|
||||
const auto input_tensor = context.get_input(0);
|
||||
const auto largest = context.const_input<bool>(3);
|
||||
const auto sorted = context.const_input<bool>(4);
|
||||
auto k = context.get_input(1);
|
||||
int64_t axis{-1};
|
||||
auto mode = ov::op::TopKMode::MIN;
|
||||
auto sort = ov::op::TopKSortType::NONE;
|
||||
|
||||
if (!context.input_is_none(2)) {
|
||||
axis = context.const_input<int64_t>(2);
|
||||
}
|
||||
if (largest) {
|
||||
mode = ov::op::TopKMode::MAX;
|
||||
}
|
||||
if (sorted) {
|
||||
sort = ov::op::TopKSortType::SORT_VALUES;
|
||||
}
|
||||
|
||||
auto topk = context.mark_node(std::make_shared<ov::op::v3::TopK>(input_tensor, k, axis, mode, sort));
|
||||
auto indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(topk->output(1), element::i64));
|
||||
|
||||
return {topk->output(0), indices};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -92,6 +92,7 @@ OP_CONVERTER(translate_squeeze);
|
||||
OP_CONVERTER(translate_sub);
|
||||
OP_CONVERTER(translate_sum);
|
||||
OP_CONVERTER(translate_to);
|
||||
OP_CONVERTER(translate_topk);
|
||||
OP_CONVERTER(translate_transpose);
|
||||
OP_CONVERTER(translate_tril);
|
||||
OP_CONVERTER(translate_triu);
|
||||
@ -269,6 +270,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tanh>>},
|
||||
{"aten::tensor", op::translate_as_tensor},
|
||||
{"aten::to", op::translate_to},
|
||||
{"aten::topk", op::translate_topk},
|
||||
{"aten::transpose", op::translate_transpose},
|
||||
{"aten::tril", op::translate_tril},
|
||||
{"aten::triu", op::translate_triu},
|
||||
|
65
tests/layer_tests/pytorch_tests/test_topk.py
Normal file
65
tests/layer_tests/pytorch_tests/test_topk.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestTopK(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (self.input_tensor,)
|
||||
|
||||
def create_model(self, k, dim, largest, sort):
|
||||
import torch
|
||||
|
||||
class aten_topk(torch.nn.Module):
|
||||
def __init__(self, k, dim, largest, sort):
|
||||
super(aten_topk, self).__init__()
|
||||
self.k = k
|
||||
self.dim = dim
|
||||
self.largest = largest
|
||||
self.sort = sort
|
||||
|
||||
def forward(self, input_tensor):
|
||||
if self.dim is None:
|
||||
return torch.topk(input_tensor, k=self.k, largest=self.largest, sorted=self.sort)
|
||||
else:
|
||||
return torch.topk(input_tensor, k=self.k, dim=self.dim, largest=self.largest, sorted=self.sort)
|
||||
ref_net = None
|
||||
|
||||
return aten_topk(k, dim, largest, sort), ref_net, "aten::topk"
|
||||
|
||||
@pytest.mark.parametrize(("input_tensor"), [
|
||||
np.random.rand(7, 5, 5, 4),
|
||||
np.random.rand(5, 6, 6, 7, 8),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(("k"), [
|
||||
3,
|
||||
1,
|
||||
2,
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(("dim"), [
|
||||
0,
|
||||
2,
|
||||
-1,
|
||||
None,
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(("largest"), [
|
||||
True,
|
||||
False,
|
||||
])
|
||||
# For False it is hard to test because in Pytorch implementation
|
||||
# there is not promise on the order of output values
|
||||
@pytest.mark.parametrize(("sort"), [
|
||||
True,
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_topK(self, input_tensor, k, dim, largest, sort, ie_device, precision, ir_version):
|
||||
self.input_tensor = input_tensor
|
||||
self._test(*self.create_model(k, dim, largest, sort), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user