[PT FE] Enable stable tests for sort & argsort (#16415)
* [PT FE] Enable stable tests for sort & argsort * Update test_argsort.py * [PT FE] Update to opset11 * [PT FE] Remove redundant argument from argsort test --------- Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
parent
9fce01f8cc
commit
22a81e0e58
@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/opsets/opset11.hpp"
|
||||
#include "utils.hpp"
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
@ -19,24 +19,25 @@ OutputVector translate_sort(const NodeContext& context) {
|
||||
stable = context.const_input<bool>(1);
|
||||
dim = context.const_input<int64_t>(2);
|
||||
descending = context.const_input<bool>(3);
|
||||
FRONT_END_OP_CONVERSION_CHECK(stable == false, "Stable sorting in aten::sort is not yet supported.");
|
||||
} else {
|
||||
stable = false;
|
||||
dim = context.const_input<int64_t>(1);
|
||||
descending = context.const_input<bool>(2);
|
||||
}
|
||||
|
||||
auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN;
|
||||
auto zero_axis = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto dim_axis = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {dim}));
|
||||
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input_tensor));
|
||||
auto k_values_node = context.mark_node(std::make_shared<opset10::Gather>(shape, dim_axis, zero_axis));
|
||||
auto k_values = context.mark_node(std::make_shared<opset10::Squeeze>(k_values_node));
|
||||
auto topk = context.mark_node(std::make_shared<opset10::TopK>(input_tensor,
|
||||
auto zero_axis = context.mark_node(opset11::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto dim_axis = context.mark_node(opset11::Constant::create(element::i64, Shape{1}, {dim}));
|
||||
auto shape = context.mark_node(std::make_shared<opset11::ShapeOf>(input_tensor));
|
||||
auto k_values_node = context.mark_node(std::make_shared<opset11::Gather>(shape, dim_axis, zero_axis));
|
||||
auto k_values = context.mark_node(std::make_shared<opset11::Squeeze>(k_values_node));
|
||||
auto topk = context.mark_node(std::make_shared<opset11::TopK>(input_tensor,
|
||||
k_values,
|
||||
dim,
|
||||
mode,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
element::i64));
|
||||
element::i64,
|
||||
stable));
|
||||
return topk->outputs();
|
||||
};
|
||||
|
||||
|
@ -7,14 +7,6 @@ import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
def not_yet_supported(value):
|
||||
return pytest.param(
|
||||
value,
|
||||
marks = pytest.mark.xfail(
|
||||
reason="Failed due to aten::argsort not yet supporting stable sorting. Ticket 105242"
|
||||
),
|
||||
)
|
||||
|
||||
class TestArgSort(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
@ -44,26 +36,26 @@ class TestArgSort(PytorchLayerTest):
|
||||
|
||||
return aten_argsort(dim, descending, stable), ref_net, "aten::argsort"
|
||||
|
||||
@pytest.mark.parametrize("input_tensor", [
|
||||
np.random.rand(1, 4),
|
||||
np.random.rand(4, 4),
|
||||
np.random.rand(4, 4, 4),
|
||||
np.array([1, 2, 4, 6, 5, 8, 7]),
|
||||
np.array([6, 5, 4, 2, 3, 0, 1]),
|
||||
not_yet_supported(np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0])),
|
||||
not_yet_supported(np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3],
|
||||
@pytest.mark.parametrize("tensor_stable_pair", [
|
||||
(np.random.rand(1, 4), False),
|
||||
(np.random.rand(4, 4), False),
|
||||
(np.random.rand(4, 4, 4), False),
|
||||
(np.array([1, 2, 4, 6, 5, 8, 7]), False),
|
||||
(np.array([6, 5, 4, 2, 3, 0, 1]), False),
|
||||
(np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0]), True),
|
||||
(np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3],
|
||||
[1, 1, 1], [1, 2, 1], [1, 2, 3],
|
||||
[1, 2, 3], [1, 1, 1], [1, 2, 1]])),
|
||||
not_yet_supported(np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6],
|
||||
[1, 2, 3], [1, 1, 1], [1, 2, 1]]), True),
|
||||
(np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6],
|
||||
[8, 8, 9], [7, 7, 8], [6, 5, 7],
|
||||
[8, 9, 8], [7, 8, 7], [5, 6, 7]])),
|
||||
not_yet_supported(np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
[8, 9, 8], [7, 8, 7], [5, 6, 7]]), True),
|
||||
(np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
[[5, 2, 4], [4, 9, 0], [7, 7, 9]],
|
||||
[[5, 2, 4], [4, 9, 0], [7, 7, 9]]])),
|
||||
not_yet_supported(np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]],
|
||||
[[5, 2, 4], [4, 9, 0], [7, 7, 9]]]), True),
|
||||
(np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]],
|
||||
[[1, 2, 1], [4, 3, 4], [3, 2, 2]],
|
||||
[[3, 2, 2], [1, 2, 1], [7, 9, 9]]])),
|
||||
not_yet_supported(np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[3, 2, 2], [1, 2, 1], [7, 9, 9]]]), True),
|
||||
(np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
@ -71,21 +63,18 @@ class TestArgSort(PytorchLayerTest):
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
|
||||
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]]]))
|
||||
[[3, 2, 2], [3, 2, 1], [1, 2, 3]]]), True)
|
||||
])
|
||||
@pytest.mark.parametrize("descending", [
|
||||
True,
|
||||
False
|
||||
])
|
||||
@pytest.mark.parametrize("stable", [
|
||||
False,
|
||||
None,
|
||||
not_yet_supported(True)
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_argsort(self, input_tensor, descending, stable, ie_device, precision, ir_version):
|
||||
self.input_tensor = input_tensor
|
||||
dims = len(input_tensor.shape)
|
||||
def test_argsort(self, tensor_stable_pair, descending, ie_device, precision, ir_version):
|
||||
self.input_tensor, stable = tensor_stable_pair
|
||||
dims = len(self.input_tensor.shape)
|
||||
for dim in range(-dims, dims):
|
||||
self._test(*self.create_model(dim, descending, stable), ie_device, precision, ir_version)
|
||||
stable_values = [True] if stable else [True, False, None]
|
||||
for stable_value in stable_values:
|
||||
self._test(*self.create_model(dim, descending, stable_value), ie_device, precision, ir_version)
|
||||
|
Loading…
Reference in New Issue
Block a user