add aten::topk (#15221)

* add aten::topk

* remove commented lines

* remove white space

* move include to individual ops

* switch include statements

* fix style

* trim test cases
This commit is contained in:
Bartek Szmelczynski 2023-01-27 09:34:55 +01:00 committed by GitHub
parent 73c9a3dcf2
commit ce4c082cb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 0 deletions

View File

@ -0,0 +1,43 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/topk.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
// Translate aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True)
// into an OpenVINO v3::TopK node. Returns {values, indices}.
OutputVector translate_topk(NodeContext& context) {
    const auto data = context.get_input(0);
    auto k_node = context.get_input(1);
    const auto want_largest = context.const_input<bool>(3);
    const auto want_sorted = context.const_input<bool>(4);

    // dim defaults to the last axis when it is not supplied by the caller.
    const int64_t dim = context.input_is_none(2) ? int64_t{-1} : context.const_input<int64_t>(2);
    const auto mode = want_largest ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN;
    const auto sort_type = want_sorted ? ov::op::TopKSortType::SORT_VALUES : ov::op::TopKSortType::NONE;

    auto topk = context.mark_node(std::make_shared<ov::op::v3::TopK>(data, k_node, dim, mode, sort_type));
    // PyTorch's topk returns int64 indices, so convert the second output to i64.
    auto indices_i64 = context.mark_node(std::make_shared<ov::op::v0::Convert>(topk->output(1), element::i64));
    return {topk->output(0), indices_i64};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -92,6 +92,7 @@ OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_sub);
OP_CONVERTER(translate_sum);
OP_CONVERTER(translate_to);
OP_CONVERTER(translate_topk);
OP_CONVERTER(translate_transpose);
OP_CONVERTER(translate_tril);
OP_CONVERTER(translate_triu);
@ -269,6 +270,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tanh>>},
{"aten::tensor", op::translate_as_tensor},
{"aten::to", op::translate_to},
{"aten::topk", op::translate_topk},
{"aten::transpose", op::translate_transpose},
{"aten::tril", op::translate_tril},
{"aten::triu", op::translate_triu},

View File

@ -0,0 +1,65 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestTopK(PytorchLayerTest):
    """Layer test verifying conversion of ``aten::topk``."""

    def _prepare_input(self):
        # The harness feeds this tuple to the traced model as its inputs.
        return (self.input_tensor,)

    def create_model(self, k, dim, largest, sort):
        import torch

        class aten_topk(torch.nn.Module):
            def __init__(self, k, dim, largest, sort):
                super(aten_topk, self).__init__()
                self.k = k
                self.dim = dim
                self.largest = largest
                self.sort = sort

            def forward(self, input_tensor):
                # Pass dim only when it was requested so the no-dim
                # overload of torch.topk gets exercised as well.
                kwargs = {"k": self.k, "largest": self.largest, "sorted": self.sort}
                if self.dim is not None:
                    kwargs["dim"] = self.dim
                return torch.topk(input_tensor, **kwargs)

        ref_net = None
        return aten_topk(k, dim, largest, sort), ref_net, "aten::topk"

    @pytest.mark.parametrize(("input_tensor"), [
        np.random.rand(7, 5, 5, 4),
        np.random.rand(5, 6, 6, 7, 8),
    ])
    @pytest.mark.parametrize(("k"), [3, 1, 2])
    @pytest.mark.parametrize(("dim"), [0, 2, -1, None])
    @pytest.mark.parametrize(("largest"), [True, False])
    # sorted=False is hard to test because PyTorch makes no promise
    # about the order of the output values in that case.
    @pytest.mark.parametrize(("sort"), [True])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_topK(self, input_tensor, k, dim, largest, sort, ie_device, precision, ir_version):
        self.input_tensor = input_tensor
        self._test(*self.create_model(k, dim, largest, sort), ie_device, precision, ir_version)