From 0860db0dc3e366e6cb54c03943de1ef3710a9afa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Krzemi=C5=84ski?= Date: Sun, 5 Mar 2023 20:12:32 +0100 Subject: [PATCH] [PT FE] Add aten::ArgSort & aten::Sort (#15769) * [PT FE] Add aten::argsort implementation & tests * [PT FE] Fix formatting * [PT FE] Fix incorrect node type for Gather * [PT FE] Fix Reshape missing argument * [PT FE] Simplify syntax, fix int/int64 conversion error * [PT FE] Fix argsort incorrectly sorting negative dimension, fix tests * [PT FE] Revert modify test class * [PT FE] Fix formatting of argsort * [PT FE] Fix define macro style * [PT FE] Add missing EOF * [PT FE] Add stable==false check, add support for different constructor calls * [PT FE] Add aten::sort implementation & tests * [PT FE] Apply style changes, add XFail test for stable sorting * Update sort.cpp * Update sort.cpp * [PT FE] Apply style changes from aten::sort t PR * Update test_argsort.py * [PT FE] Apply suggested modifications * Update test_argsort.py * [PT FE] Apply review suggestions, add tests and extract sort method to utils * [PT FE] Use utils sort function to implement argsort * [PT FE] Fix input size check 4->5 * [PT FE] Implement improved tests * [PT FE] Implement improved tests * [PT FE] Add xfail to not yet supported tests * [PT FE] Merge 2 implementations of sort and argsort into a single file * [PT FE] Remove redundant sort_elements from utils * [PT FE] Add num_inputs_check --------- Co-authored-by: Maxim Vafin --- .../openvino/frontend/pytorch/decoder.hpp | 4 +- src/frontends/pytorch/src/op/sort.cpp | 51 ++++++++++ src/frontends/pytorch/src/op_table.cpp | 4 + .../layer_tests/pytorch_tests/test_argsort.py | 91 ++++++++++++++++++ tests/layer_tests/pytorch_tests/test_sort.py | 92 +++++++++++++++++++ 5 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 src/frontends/pytorch/src/op/sort.cpp create mode 100644 tests/layer_tests/pytorch_tests/test_argsort.py create mode 100644 tests/layer_tests/pytorch_tests/test_sort.py diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp index 290fc159721..b48c1d1065e 100644 --- a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp +++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp @@ -37,7 +37,7 @@ public: // Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar virtual PartialShape get_input_shape(size_t index) const = 0; - // Return element::Type when it the original type can be represented, otherwise returns PT-sepcific data type object + // Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object // (see custom_type.hpp) virtual Any get_input_type(size_t index) const = 0; @@ -50,7 +50,7 @@ public: // Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar virtual PartialShape get_output_shape(size_t index) const = 0; - // Return element::Type when it the original type can be represented, otherwise returns PT-sepcific data type object + // Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object // (see custom_type.hpp) virtual Any get_output_type(size_t index) const = 0; diff --git a/src/frontends/pytorch/src/op/sort.cpp b/src/frontends/pytorch/src/op/sort.cpp new file mode 100644 index 00000000000..c0e54d54d9b --- /dev/null +++ b/src/frontends/pytorch/src/op/sort.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/opsets/opset10.hpp" +#include "utils.hpp" +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_sort(NodeContext& context) { + num_inputs_check(context, 3, 4); + const auto input_tensor = context.get_input(0); + bool stable, descending; + int64_t dim; + + if (context.get_input_size() == 4) { + stable = context.const_input(1); + dim = context.const_input(2); + descending = context.const_input(3); + FRONT_END_OP_CONVERSION_CHECK(stable == false, "Stable sorting in aten::sort is not yet supported."); + } else { + dim = context.const_input(1); + descending = context.const_input(2); + } + + auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN; + auto zero_axis = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {0})); + auto dim_axis = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {dim})); + auto shape = context.mark_node(std::make_shared(input_tensor)); + auto k_values_node = context.mark_node(std::make_shared(shape, dim_axis, zero_axis)); + auto k_values = context.mark_node(std::make_shared(k_values_node)); + auto topk = context.mark_node(std::make_shared(input_tensor, + k_values, + dim, + mode, + ov::op::TopKSortType::SORT_VALUES, + element::i64)); + return topk->outputs(); +}; + +OutputVector translate_argsort(NodeContext& context) { + auto sort = translate_sort(context); + return {sort[1]}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index edd40cb25a9..b053704678b 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -20,6 +20,7 @@ OP_CONVERTER(translate_add); OP_CONVERTER(translate_addcmul); OP_CONVERTER(translate_addmm); OP_CONVERTER(translate_arange); +OP_CONVERTER(translate_argsort); OP_CONVERTER(translate_as_tensor); OP_CONVERTER(translate_avg_poolnd); OP_CONVERTER(translate_bool); @@ -100,6 +101,7 @@ OP_CONVERTER(translate_selu); OP_CONVERTER(translate_size); OP_CONVERTER(translate_slice); OP_CONVERTER(translate_softmax); +OP_CONVERTER(translate_sort); OP_CONVERTER(translate_square); OP_CONVERTER(translate_squeeze); OP_CONVERTER(translate_sub); @@ -145,6 +147,7 @@ const std::map get_supported_ops() { {"aten::add_", op::inplace_op}, {"aten::addcmul", op::translate_addcmul}, {"aten::addmm", op::translate_addmm}, + {"aten::argsort", op::translate_argsort}, {"aten::arange", op::translate_arange}, {"aten::as_tensor", op::translate_as_tensor}, {"aten::asin", op::translate_1to1_match_1_inputs}, @@ -287,6 +290,7 @@ const std::map get_supported_ops() { {"aten::size", op::translate_size}, {"aten::slice", op::translate_slice}, {"aten::softmax", op::translate_softmax}, + {"aten::sort", op::translate_sort}, {"aten::sqrt", op::translate_1to1_match_1_inputs}, {"aten::square", op::translate_square}, {"aten::squeeze", op::translate_squeeze}, diff --git a/tests/layer_tests/pytorch_tests/test_argsort.py b/tests/layer_tests/pytorch_tests/test_argsort.py new file mode 100644 index 00000000000..c29a5e91ae9 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_argsort.py @@ -0,0 +1,91 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + +def not_yet_supported(value): + return pytest.param( + value, + marks = pytest.mark.xfail( + reason="Failed due to aten::sargsort not yet supporting stable sorting. Ticket 105242" + ), + ) + +class TestArgSort(PytorchLayerTest): + + def _prepare_input(self): + return (self.input_tensor,) + + def create_model(self, dim, descending, stable): + class aten_argsort(torch.nn.Module): + def __init__(self, dim, descending, stable) -> None: + torch.nn.Module.__init__(self) + self.dim = dim + self.descending = descending + self.stable = stable + + def forward(self, input_tensor): + if self.stable is not None: + return torch.argsort(input_tensor, + dim = self.dim, + descending = self.descending, + stable = self.stable + ) + else: + return torch.argsort(input_tensor, + dim = self.dim, + descending = self.descending + ) + ref_net = None + + return aten_argsort(dim, descending, stable), ref_net, "aten::argsort" + + @pytest.mark.parametrize("input_tensor", [ + np.random.rand(1, 4), + np.random.rand(4, 4), + np.random.rand(4, 4, 4), + np.array([1, 2, 4, 6, 5, 8, 7]), + np.array([6, 5, 4, 2, 3, 0, 1]), + not_yet_supported(np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0])), + not_yet_supported(np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3], + [1, 1, 1], [1, 2, 1], [1, 2, 3], + [1, 2, 3], [1, 1, 1], [1, 2, 1]])), + not_yet_supported(np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6], + [8, 8, 9], [7, 7, 8], [6, 5, 7], + [8, 9, 8], [7, 8, 7], [5, 6, 7]])), + not_yet_supported(np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[5, 2, 4], [4, 9, 0], [7, 7, 9]], + [[5, 2, 4], [4, 9, 0], [7, 7, 9]]])), + not_yet_supported(np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]], + [[1, 2, 1], [4, 3, 4], [3, 2, 2]], + [[3, 2, 2], [1, 2, 1], [7, 9, 9]]])), + not_yet_supported(np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]], + [[2, 0, 2], [1, 2, 1], [3, 2, 8]], + [[3, 2, 2], [3, 2, 1], [1, 2, 3]], + [[2, 1, 3], [3, 2, 1], [1, 2, 3]], + [[2, 0, 2], [1, 2, 1], [3, 2, 8]], + [[3, 2, 2], [3, 2, 1], [1, 2, 3]], + [[2, 1, 3], [3, 2, 1], [1, 2, 3]], + [[2, 0, 2], [1, 2, 1], [3, 2, 8]], + [[3, 2, 2], [3, 2, 1], [1, 2, 3]]])) + ]) + @pytest.mark.parametrize("descending", [ + True, + False + ]) + @pytest.mark.parametrize("stable", [ + False, + None, + not_yet_supported(True) + ]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_argsort(self, input_tensor, descending, stable, ie_device, precision, ir_version): + self.input_tensor = input_tensor + dims = len(input_tensor.shape) + for dim in range(-dims, dims): + self._test(*self.create_model(dim, descending, stable), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_sort.py b/tests/layer_tests/pytorch_tests/test_sort.py new file mode 100644 index 00000000000..f68ce270743 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_sort.py @@ -0,0 +1,92 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + +def not_yet_supported(value): + return pytest.param( + value, + marks = pytest.mark.xfail( + reason="Failed due to aten::sort not yet supporting stable sorting. Ticket 105242" + ), + ) + +class TestSortConstants(PytorchLayerTest): + def _prepare_input(self): + return (self.input_tensor,) + + def create_model(self, dim, descending, stable): + class aten_sort(torch.nn.Module): + def __init__(self, dim, descending, stable) -> None: + torch.nn.Module.__init__(self) + self.stable = stable + self.dim = dim + self.descending = descending + + def forward(self, input_tensor): + if self.stable is not None: + return torch.sort(input_tensor, + stable = self.stable, + dim = self.dim, + descending = self.descending + )[0] + else: + return torch.sort(input_tensor, + dim = self.dim, + descending = self.descending + )[0] + + ref_net = None + return aten_sort(dim, descending, stable), ref_net, "aten::sort" + + @pytest.mark.parametrize("input_tensor", [ + np.random.rand(16), + np.random.rand(1, 4), + np.random.rand(4, 4), + np.random.rand(4, 4, 4), + np.array([1, 2, 4, 6, 5, 8, 7]), + np.array([6, 5, 4, 2, 3, 0, 1]), + np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0]), + np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3], + [1, 1, 1], [1, 2, 1], [1, 2, 3], + [1, 2, 3], [1, 1, 1], [1, 2, 1]]), + np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6], + [8, 8, 9], [7, 7, 8], [6, 5, 7], + [8, 9, 8], [7, 8, 7], [5, 6, 7]]), + np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[5, 2, 4], [4, 9, 0], [7, 7, 9]], + [[5, 2, 4], [4, 9, 0], [7, 7, 9]]]), + np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]], + [[1, 2, 1], [4, 3, 4], [3, 2, 2]], + [[3, 2, 2], [1, 2, 1], [7, 9, 9]]]), + np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]], + [[2, 0, 2], [1, 2, 1], [3, 2, 8]], + [[3, 2, 2], [3, 2, 1], [1, 2, 3]], + [[2, 1, 3], [3, 2, 1], [1, 2, 3]], + [[2, 0, 2], [1, 2, 1], [3, 2, 8]], + [[3, 2, 2], [3, 2, 1], [1, 2, 3]], + [[2, 1, 3], [3, 2, 1], [1, 2, 3]], + [[2, 0, 2], [1, 2, 1], [3, 2, 8]], + [[3, 2, 2], [3, 2, 1], [1, 2, 3]]]) + + ]) + @pytest.mark.parametrize("descending", [ + True, + False + ]) + @pytest.mark.parametrize("stable", [ + False, + None, + not_yet_supported(True) + ]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_sort(self, input_tensor, descending, stable, ie_device, precision, ir_version): + self.input_tensor = input_tensor + dims = len(input_tensor.shape) + for dim in range(-dims, dims): + self._test(*self.create_model(dim, descending, stable), ie_device, precision, ir_version)