diff --git a/src/frontends/pytorch/src/op/topk.cpp b/src/frontends/pytorch/src/op/topk.cpp
new file mode 100644
index 00000000000..7dfc7af5226
--- /dev/null
+++ b/src/frontends/pytorch/src/op/topk.cpp
@@ -0,0 +1,43 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/op/topk.hpp"
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/convert.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_topk(NodeContext& context) {
+    const auto input_tensor = context.get_input(0);
+    const auto largest = context.const_input<bool>(3);
+    const auto sorted = context.const_input<bool>(4);
+    auto k = context.get_input(1);
+    int64_t axis{-1};
+    auto mode = ov::op::TopKMode::MIN;
+    auto sort = ov::op::TopKSortType::NONE;
+
+    if (!context.input_is_none(2)) {
+        axis = context.const_input<int64_t>(2);
+    }
+    if (largest) {
+        mode = ov::op::TopKMode::MAX;
+    }
+    if (sorted) {
+        sort = ov::op::TopKSortType::SORT_VALUES;
+    }
+
+    auto topk = context.mark_node(std::make_shared<ov::op::v3::TopK>(input_tensor, k, axis, mode, sort));
+    auto indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(topk->output(1), element::i64));
+
+    return {topk->output(0), indices};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 2ea63abc6e6..a6acc1e8a4d 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -92,6 +92,7 @@ OP_CONVERTER(translate_squeeze);
 OP_CONVERTER(translate_sub);
 OP_CONVERTER(translate_sum);
 OP_CONVERTER(translate_to);
+OP_CONVERTER(translate_topk);
 OP_CONVERTER(translate_transpose);
 OP_CONVERTER(translate_tril);
 OP_CONVERTER(translate_triu);
@@ -269,6 +270,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
         {"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tanh>>},
         {"aten::tensor", op::translate_as_tensor},
         {"aten::to", op::translate_to},
+        {"aten::topk", op::translate_topk},
         {"aten::transpose", op::translate_transpose},
         {"aten::tril", op::translate_tril},
         {"aten::triu", op::translate_triu},
diff --git a/tests/layer_tests/pytorch_tests/test_topk.py b/tests/layer_tests/pytorch_tests/test_topk.py
new file mode 100644
index 00000000000..b2f27acdbb8
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_topk.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestTopK(PytorchLayerTest):
+    def _prepare_input(self):
+        return (self.input_tensor,)
+
+    def create_model(self, k, dim, largest, sort):
+        import torch
+
+        class aten_topk(torch.nn.Module):
+            def __init__(self, k, dim, largest, sort):
+                super(aten_topk, self).__init__()
+                self.k = k
+                self.dim = dim
+                self.largest = largest
+                self.sort = sort
+
+            def forward(self, input_tensor):
+                if self.dim is None:
+                    return torch.topk(input_tensor, k=self.k, largest=self.largest, sorted=self.sort)
+                else:
+                    return torch.topk(input_tensor, k=self.k, dim=self.dim, largest=self.largest, sorted=self.sort)
+        ref_net = None
+
+        return aten_topk(k, dim, largest, sort), ref_net, "aten::topk"
+
+    @pytest.mark.parametrize(("input_tensor"), [
+        np.random.rand(7, 5, 5, 4),
+        np.random.rand(5, 6, 6, 7, 8),
+    ])
+
+    @pytest.mark.parametrize(("k"), [
+        3,
+        1,
+        2,
+    ])
+
+    @pytest.mark.parametrize(("dim"), [
+        0,
+        2,
+        -1,
+        None,
+    ])
+
+    @pytest.mark.parametrize(("largest"), [
+        True,
+        False,
+    ])
+    # For False it is hard to test because the PyTorch implementation
+    # makes no promise about the order of the output values
+    @pytest.mark.parametrize(("sort"), [
+        True,
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_topK(self, input_tensor, k, dim, largest, sort, ie_device, precision, ir_version):
+        self.input_tensor = input_tensor
+        self._test(*self.create_model(k, dim, largest, sort), ie_device, precision, ir_version)
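
For reference, a rough end-to-end sketch of what this converter enables (not part of the patch): tracing a module that calls torch.topk and running it through the PyTorch frontend. The openvino.convert_model / compile_model calls below assume a recent OpenVINO Python API with PyTorch support; exact entry points may differ between releases.

    # Minimal usage sketch (assumption: openvino.convert_model with PyTorch
    # support is available in the installed OpenVINO build).
    import numpy as np
    import torch
    import openvino as ov


    class TopKModel(torch.nn.Module):
        def forward(self, x):
            # largest=True / sorted=True map to TopKMode::MAX and SORT_VALUES
            # in translate_topk above
            return torch.topk(x, k=3, dim=-1, largest=True, sorted=True)


    # convert_model traces the module; aten::topk is dispatched to the new converter
    ov_model = ov.convert_model(TopKModel(), example_input=torch.rand(2, 8))
    compiled = ov.Core().compile_model(ov_model, "CPU")

    values, indices = compiled(np.random.rand(2, 8).astype(np.float32)).values()
    print(values.shape, indices.shape)  # -> (2, 3) (2, 3)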