[PT FE] Add aten::ArgSort & aten::Sort (#15769)
* [PT FE] Add aten::argsort implementation & tests * [PT FE] Fix formatting * [PT FE] Fix incorrect node type for Gather * [PT FE] Fix Reshape missing argument * [PT FE] Simplify syntax, fix int/int64 conversion error * [PT FE] Fix argsort incorrectly sorting negative dimension, fix tests * [PT FE] Revert modify test class * [PT FE] Fix formatting of argsort * [PT FE] Fix define macro style * [PT FE] Add missing EOF * [PT FE] Add stable==false check, add support for different constructor calls * [PT FE] Add aten::sort implementation & tests * [PT FE] Apply style changes, add XFail test for stable sorting * Update sort.cpp * Update sort.cpp * [PT FE] Apply style changes from aten::sort t PR * Update test_argsort.py * [PT FE] Apply suggested modifications * Update test_argsort.py * [PT FE] Apply review suggestions, add tests and extract sort method to utils * [PT FE] Use utils sort function to implement argsort * [PT FE] Fix input size check 4->5 * [PT FE] Implement improved tests * [PT FE] Implement improved tests * [PT FE] Add xfail to not yet supported tests * [PT FE] Merge 2 implementations of sort and argsort into a single file * [PT FE] Remove redundant sort_elements from utils * [PT FE] Add num_inputs_check --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
e1fbb7d768
commit
0860db0dc3
@ -37,7 +37,7 @@ public:
|
||||
// Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar
|
||||
virtual PartialShape get_input_shape(size_t index) const = 0;
|
||||
|
||||
// Return element::Type when it the original type can be represented, otherwise returns PT-sepcific data type object
|
||||
// Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
|
||||
// (see custom_type.hpp)
|
||||
virtual Any get_input_type(size_t index) const = 0;
|
||||
|
||||
@ -50,7 +50,7 @@ public:
|
||||
// Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar
|
||||
virtual PartialShape get_output_shape(size_t index) const = 0;
|
||||
|
||||
// Return element::Type when it the original type can be represented, otherwise returns PT-sepcific data type object
|
||||
// Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
|
||||
// (see custom_type.hpp)
|
||||
virtual Any get_output_type(size_t index) const = 0;
|
||||
|
||||
|
51
src/frontends/pytorch/src/op/sort.cpp
Normal file
51
src/frontends/pytorch/src/op/sort.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_sort(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 4);
|
||||
const auto input_tensor = context.get_input(0);
|
||||
bool stable, descending;
|
||||
int64_t dim;
|
||||
|
||||
if (context.get_input_size() == 4) {
|
||||
stable = context.const_input<bool>(1);
|
||||
dim = context.const_input<int64_t>(2);
|
||||
descending = context.const_input<bool>(3);
|
||||
FRONT_END_OP_CONVERSION_CHECK(stable == false, "Stable sorting in aten::sort is not yet supported.");
|
||||
} else {
|
||||
dim = context.const_input<int64_t>(1);
|
||||
descending = context.const_input<bool>(2);
|
||||
}
|
||||
|
||||
auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN;
|
||||
auto zero_axis = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto dim_axis = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {dim}));
|
||||
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input_tensor));
|
||||
auto k_values_node = context.mark_node(std::make_shared<opset10::Gather>(shape, dim_axis, zero_axis));
|
||||
auto k_values = context.mark_node(std::make_shared<opset10::Squeeze>(k_values_node));
|
||||
auto topk = context.mark_node(std::make_shared<opset10::TopK>(input_tensor,
|
||||
k_values,
|
||||
dim,
|
||||
mode,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
element::i64));
|
||||
return topk->outputs();
|
||||
};
|
||||
|
||||
OutputVector translate_argsort(NodeContext& context) {
|
||||
auto sort = translate_sort(context);
|
||||
return {sort[1]};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -20,6 +20,7 @@ OP_CONVERTER(translate_add);
|
||||
OP_CONVERTER(translate_addcmul);
|
||||
OP_CONVERTER(translate_addmm);
|
||||
OP_CONVERTER(translate_arange);
|
||||
OP_CONVERTER(translate_argsort);
|
||||
OP_CONVERTER(translate_as_tensor);
|
||||
OP_CONVERTER(translate_avg_poolnd);
|
||||
OP_CONVERTER(translate_bool);
|
||||
@ -100,6 +101,7 @@ OP_CONVERTER(translate_selu);
|
||||
OP_CONVERTER(translate_size);
|
||||
OP_CONVERTER(translate_slice);
|
||||
OP_CONVERTER(translate_softmax);
|
||||
OP_CONVERTER(translate_sort);
|
||||
OP_CONVERTER(translate_square);
|
||||
OP_CONVERTER(translate_squeeze);
|
||||
OP_CONVERTER(translate_sub);
|
||||
@ -145,6 +147,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
|
||||
{"aten::add_", op::inplace_op<op::translate_add>},
|
||||
{"aten::addcmul", op::translate_addcmul},
|
||||
{"aten::addmm", op::translate_addmm},
|
||||
{"aten::argsort", op::translate_argsort},
|
||||
{"aten::arange", op::translate_arange},
|
||||
{"aten::as_tensor", op::translate_as_tensor},
|
||||
{"aten::asin", op::translate_1to1_match_1_inputs<opset10::Asin>},
|
||||
@ -287,6 +290,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
|
||||
{"aten::size", op::translate_size},
|
||||
{"aten::slice", op::translate_slice},
|
||||
{"aten::softmax", op::translate_softmax},
|
||||
{"aten::sort", op::translate_sort},
|
||||
{"aten::sqrt", op::translate_1to1_match_1_inputs<opset10::Sqrt>},
|
||||
{"aten::square", op::translate_square},
|
||||
{"aten::squeeze", op::translate_squeeze},
|
||||
|
91
tests/layer_tests/pytorch_tests/test_argsort.py
Normal file
91
tests/layer_tests/pytorch_tests/test_argsort.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
def not_yet_supported(value):
|
||||
return pytest.param(
|
||||
value,
|
||||
marks = pytest.mark.xfail(
|
||||
reason="Failed due to aten::sargsort not yet supporting stable sorting. Ticket 105242"
|
||||
),
|
||||
)
|
||||
|
||||
class TestArgSort(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (self.input_tensor,)
|
||||
|
||||
def create_model(self, dim, descending, stable):
|
||||
class aten_argsort(torch.nn.Module):
|
||||
def __init__(self, dim, descending, stable) -> None:
|
||||
torch.nn.Module.__init__(self)
|
||||
self.dim = dim
|
||||
self.descending = descending
|
||||
self.stable = stable
|
||||
|
||||
def forward(self, input_tensor):
|
||||
if self.stable is not None:
|
||||
return torch.argsort(input_tensor,
|
||||
dim = self.dim,
|
||||
descending = self.descending,
|
||||
stable = self.stable
|
||||
)
|
||||
else:
|
||||
return torch.argsort(input_tensor,
|
||||
dim = self.dim,
|
||||
descending = self.descending
|
||||
)
|
||||
ref_net = None
|
||||
|
||||
return aten_argsort(dim, descending, stable), ref_net, "aten::argsort"
|
||||
|
||||
@pytest.mark.parametrize("input_tensor", [
|
||||
np.random.rand(1, 4),
|
||||
np.random.rand(4, 4),
|
||||
np.random.rand(4, 4, 4),
|
||||
np.array([1, 2, 4, 6, 5, 8, 7]),
|
||||
np.array([6, 5, 4, 2, 3, 0, 1]),
|
||||
not_yet_supported(np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0])),
|
||||
not_yet_supported(np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3],
|
||||
[1, 1, 1], [1, 2, 1], [1, 2, 3],
|
||||
[1, 2, 3], [1, 1, 1], [1, 2, 1]])),
|
||||
not_yet_supported(np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6],
|
||||
[8, 8, 9], [7, 7, 8], [6, 5, 7],
|
||||
[8, 9, 8], [7, 8, 7], [5, 6, 7]])),
|
||||
not_yet_supported(np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
[[5, 2, 4], [4, 9, 0], [7, 7, 9]],
|
||||
[[5, 2, 4], [4, 9, 0], [7, 7, 9]]])),
|
||||
not_yet_supported(np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]],
|
||||
[[1, 2, 1], [4, 3, 4], [3, 2, 2]],
|
||||
[[3, 2, 2], [1, 2, 1], [7, 9, 9]]])),
|
||||
not_yet_supported(np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]]]))
|
||||
])
|
||||
@pytest.mark.parametrize("descending", [
|
||||
True,
|
||||
False
|
||||
])
|
||||
@pytest.mark.parametrize("stable", [
|
||||
False,
|
||||
None,
|
||||
not_yet_supported(True)
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_argsort(self, input_tensor, descending, stable, ie_device, precision, ir_version):
|
||||
self.input_tensor = input_tensor
|
||||
dims = len(input_tensor.shape)
|
||||
for dim in range(-dims, dims):
|
||||
self._test(*self.create_model(dim, descending, stable), ie_device, precision, ir_version)
|
92
tests/layer_tests/pytorch_tests/test_sort.py
Normal file
92
tests/layer_tests/pytorch_tests/test_sort.py
Normal file
@ -0,0 +1,92 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
def not_yet_supported(value):
|
||||
return pytest.param(
|
||||
value,
|
||||
marks = pytest.mark.xfail(
|
||||
reason="Failed due to aten::sort not yet supporting stable sorting. Ticket 105242"
|
||||
),
|
||||
)
|
||||
|
||||
class TestSortConstants(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (self.input_tensor,)
|
||||
|
||||
def create_model(self, dim, descending, stable):
|
||||
class aten_sort(torch.nn.Module):
|
||||
def __init__(self, dim, descending, stable) -> None:
|
||||
torch.nn.Module.__init__(self)
|
||||
self.stable = stable
|
||||
self.dim = dim
|
||||
self.descending = descending
|
||||
|
||||
def forward(self, input_tensor):
|
||||
if self.stable is not None:
|
||||
return torch.sort(input_tensor,
|
||||
stable = self.stable,
|
||||
dim = self.dim,
|
||||
descending = self.descending
|
||||
)[0]
|
||||
else:
|
||||
return torch.sort(input_tensor,
|
||||
dim = self.dim,
|
||||
descending = self.descending
|
||||
)[0]
|
||||
|
||||
ref_net = None
|
||||
return aten_sort(dim, descending, stable), ref_net, "aten::sort"
|
||||
|
||||
@pytest.mark.parametrize("input_tensor", [
|
||||
np.random.rand(16),
|
||||
np.random.rand(1, 4),
|
||||
np.random.rand(4, 4),
|
||||
np.random.rand(4, 4, 4),
|
||||
np.array([1, 2, 4, 6, 5, 8, 7]),
|
||||
np.array([6, 5, 4, 2, 3, 0, 1]),
|
||||
np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0]),
|
||||
np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3],
|
||||
[1, 1, 1], [1, 2, 1], [1, 2, 3],
|
||||
[1, 2, 3], [1, 1, 1], [1, 2, 1]]),
|
||||
np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6],
|
||||
[8, 8, 9], [7, 7, 8], [6, 5, 7],
|
||||
[8, 9, 8], [7, 8, 7], [5, 6, 7]]),
|
||||
np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
[[5, 2, 4], [4, 9, 0], [7, 7, 9]],
|
||||
[[5, 2, 4], [4, 9, 0], [7, 7, 9]]]),
|
||||
np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]],
|
||||
[[1, 2, 1], [4, 3, 4], [3, 2, 2]],
|
||||
[[3, 2, 2], [1, 2, 1], [7, 9, 9]]]),
|
||||
np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]]])
|
||||
|
||||
])
|
||||
@pytest.mark.parametrize("descending", [
|
||||
True,
|
||||
False
|
||||
])
|
||||
@pytest.mark.parametrize("stable", [
|
||||
False,
|
||||
None,
|
||||
not_yet_supported(True)
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_sort(self, input_tensor, descending, stable, ie_device, precision, ir_version):
|
||||
self.input_tensor = input_tensor
|
||||
dims = len(input_tensor.shape)
|
||||
for dim in range(-dims, dims):
|
||||
self._test(*self.create_model(dim, descending, stable), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user