[PT FE] Add aten::ArgSort & aten::Sort (#15769)

* [PT FE] Add aten::argsort implementation & tests

* [PT FE] Fix formatting

* [PT FE] Fix incorrect node type for Gather

* [PT FE] Fix Reshape missing argument

* [PT FE] Simplify syntax, fix int/int64 conversion error

* [PT FE] Fix argsort incorrectly sorting negative dimension, fix tests

* [PT FE] Revert modify test class

* [PT FE] Fix formatting of argsort

* [PT FE] Fix define macro style

* [PT FE] Add missing EOF

* [PT FE] Add stable==false check, add support for different constructor calls

* [PT FE] Add aten::sort implementation & tests

* [PT FE] Apply style changes, add XFail test for stable sorting

* Update sort.cpp

* Update sort.cpp

* [PT FE] Apply style changes from aten::sort t PR

* Update test_argsort.py

* [PT FE] Apply suggested modifications

* Update test_argsort.py

* [PT FE] Apply review suggestions, add tests and extract sort method to utils

* [PT FE] Use utils sort function to implement argsort

* [PT FE] Fix input size check 4->5

* [PT FE] Implement improved tests

* [PT FE] Implement improved tests

* [PT FE] Add xfail to not yet supported tests

* [PT FE] Merge 2 implementations of sort and argsort into a single file

* [PT FE] Remove redundant sort_elements from utils

* [PT FE] Add num_inputs_check

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Piotr Krzemiński 2023-03-05 20:12:32 +01:00 committed by GitHub
parent e1fbb7d768
commit 0860db0dc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 240 additions and 2 deletions

View File

@ -37,7 +37,7 @@ public:
// Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar
virtual PartialShape get_input_shape(size_t index) const = 0;
// Return element::Type when it the original type can be represented, otherwise returns PT-sepcific data type object
// Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
// (see custom_type.hpp)
virtual Any get_input_type(size_t index) const = 0;
@ -50,7 +50,7 @@ public:
// Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar
virtual PartialShape get_output_shape(size_t index) const = 0;
// Return element::Type when it the original type can be represented, otherwise returns PT-sepcific data type object
// Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
// (see custom_type.hpp)
virtual Any get_output_type(size_t index) const = 0;

View File

@ -0,0 +1,51 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_sort(NodeContext& context) {
num_inputs_check(context, 3, 4);
const auto input_tensor = context.get_input(0);
bool stable, descending;
int64_t dim;
if (context.get_input_size() == 4) {
stable = context.const_input<bool>(1);
dim = context.const_input<int64_t>(2);
descending = context.const_input<bool>(3);
FRONT_END_OP_CONVERSION_CHECK(stable == false, "Stable sorting in aten::sort is not yet supported.");
} else {
dim = context.const_input<int64_t>(1);
descending = context.const_input<bool>(2);
}
auto mode = descending ? ov::op::TopKMode::MAX : ov::op::TopKMode::MIN;
auto zero_axis = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {0}));
auto dim_axis = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {dim}));
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input_tensor));
auto k_values_node = context.mark_node(std::make_shared<opset10::Gather>(shape, dim_axis, zero_axis));
auto k_values = context.mark_node(std::make_shared<opset10::Squeeze>(k_values_node));
auto topk = context.mark_node(std::make_shared<opset10::TopK>(input_tensor,
k_values,
dim,
mode,
ov::op::TopKSortType::SORT_VALUES,
element::i64));
return topk->outputs();
};
OutputVector translate_argsort(NodeContext& context) {
auto sort = translate_sort(context);
return {sort[1]};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -20,6 +20,7 @@ OP_CONVERTER(translate_add);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_arange);
OP_CONVERTER(translate_argsort);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_bool);
@ -100,6 +101,7 @@ OP_CONVERTER(translate_selu);
OP_CONVERTER(translate_size);
OP_CONVERTER(translate_slice);
OP_CONVERTER(translate_softmax);
OP_CONVERTER(translate_sort);
OP_CONVERTER(translate_square);
OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_sub);
@ -145,6 +147,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::add_", op::inplace_op<op::translate_add>},
{"aten::addcmul", op::translate_addcmul},
{"aten::addmm", op::translate_addmm},
{"aten::argsort", op::translate_argsort},
{"aten::arange", op::translate_arange},
{"aten::as_tensor", op::translate_as_tensor},
{"aten::asin", op::translate_1to1_match_1_inputs<opset10::Asin>},
@ -287,6 +290,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::size", op::translate_size},
{"aten::slice", op::translate_slice},
{"aten::softmax", op::translate_softmax},
{"aten::sort", op::translate_sort},
{"aten::sqrt", op::translate_1to1_match_1_inputs<opset10::Sqrt>},
{"aten::square", op::translate_square},
{"aten::squeeze", op::translate_squeeze},

View File

@ -0,0 +1,91 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
def not_yet_supported(value):
return pytest.param(
value,
marks = pytest.mark.xfail(
reason="Failed due to aten::sargsort not yet supporting stable sorting. Ticket 105242"
),
)
class TestArgSort(PytorchLayerTest):
def _prepare_input(self):
return (self.input_tensor,)
def create_model(self, dim, descending, stable):
class aten_argsort(torch.nn.Module):
def __init__(self, dim, descending, stable) -> None:
torch.nn.Module.__init__(self)
self.dim = dim
self.descending = descending
self.stable = stable
def forward(self, input_tensor):
if self.stable is not None:
return torch.argsort(input_tensor,
dim = self.dim,
descending = self.descending,
stable = self.stable
)
else:
return torch.argsort(input_tensor,
dim = self.dim,
descending = self.descending
)
ref_net = None
return aten_argsort(dim, descending, stable), ref_net, "aten::argsort"
@pytest.mark.parametrize("input_tensor", [
np.random.rand(1, 4),
np.random.rand(4, 4),
np.random.rand(4, 4, 4),
np.array([1, 2, 4, 6, 5, 8, 7]),
np.array([6, 5, 4, 2, 3, 0, 1]),
not_yet_supported(np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0])),
not_yet_supported(np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3],
[1, 1, 1], [1, 2, 1], [1, 2, 3],
[1, 2, 3], [1, 1, 1], [1, 2, 1]])),
not_yet_supported(np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6],
[8, 8, 9], [7, 7, 8], [6, 5, 7],
[8, 9, 8], [7, 8, 7], [5, 6, 7]])),
not_yet_supported(np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[5, 2, 4], [4, 9, 0], [7, 7, 9]],
[[5, 2, 4], [4, 9, 0], [7, 7, 9]]])),
not_yet_supported(np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]],
[[1, 2, 1], [4, 3, 4], [3, 2, 2]],
[[3, 2, 2], [1, 2, 1], [7, 9, 9]]])),
not_yet_supported(np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
[[3, 2, 2], [3, 2, 1], [1, 2, 3]]]))
])
@pytest.mark.parametrize("descending", [
True,
False
])
@pytest.mark.parametrize("stable", [
False,
None,
not_yet_supported(True)
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_argsort(self, input_tensor, descending, stable, ie_device, precision, ir_version):
self.input_tensor = input_tensor
dims = len(input_tensor.shape)
for dim in range(-dims, dims):
self._test(*self.create_model(dim, descending, stable), ie_device, precision, ir_version)

View File

@ -0,0 +1,92 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
import pytest
from pytorch_layer_test_class import PytorchLayerTest
def not_yet_supported(value):
return pytest.param(
value,
marks = pytest.mark.xfail(
reason="Failed due to aten::sort not yet supporting stable sorting. Ticket 105242"
),
)
class TestSortConstants(PytorchLayerTest):
def _prepare_input(self):
return (self.input_tensor,)
def create_model(self, dim, descending, stable):
class aten_sort(torch.nn.Module):
def __init__(self, dim, descending, stable) -> None:
torch.nn.Module.__init__(self)
self.stable = stable
self.dim = dim
self.descending = descending
def forward(self, input_tensor):
if self.stable is not None:
return torch.sort(input_tensor,
stable = self.stable,
dim = self.dim,
descending = self.descending
)[0]
else:
return torch.sort(input_tensor,
dim = self.dim,
descending = self.descending
)[0]
ref_net = None
return aten_sort(dim, descending, stable), ref_net, "aten::sort"
@pytest.mark.parametrize("input_tensor", [
np.random.rand(16),
np.random.rand(1, 4),
np.random.rand(4, 4),
np.random.rand(4, 4, 4),
np.array([1, 2, 4, 6, 5, 8, 7]),
np.array([6, 5, 4, 2, 3, 0, 1]),
np.array([1, 1, 1, 2, 1, 3, 1, 4, 2, 5, 1, 2, 4, 4, 0]),
np.array([[1, 1, 1], [1, 2, 1], [1, 2, 3],
[1, 1, 1], [1, 2, 1], [1, 2, 3],
[1, 2, 3], [1, 1, 1], [1, 2, 1]]),
np.array([[9, 8, 8], [8, 7, 7], [7, 5, 6],
[8, 8, 9], [7, 7, 8], [6, 5, 7],
[8, 9, 8], [7, 8, 7], [5, 6, 7]]),
np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[5, 2, 4], [4, 9, 0], [7, 7, 9]],
[[5, 2, 4], [4, 9, 0], [7, 7, 9]]]),
np.array([[[3, 2, 2], [1, 2, 1], [3, 2, 2]],
[[1, 2, 1], [4, 3, 4], [3, 2, 2]],
[[3, 2, 2], [1, 2, 1], [7, 9, 9]]]),
np.array([[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
[[3, 2, 2], [3, 2, 1], [1, 2, 3]],
[[2, 1, 3], [3, 2, 1], [1, 2, 3]],
[[2, 0, 2], [1, 2, 1], [3, 2, 8]],
[[3, 2, 2], [3, 2, 1], [1, 2, 3]]])
])
@pytest.mark.parametrize("descending", [
True,
False
])
@pytest.mark.parametrize("stable", [
False,
None,
not_yet_supported(True)
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_sort(self, input_tensor, descending, stable, ie_device, precision, ir_version):
self.input_tensor = input_tensor
dims = len(input_tensor.shape)
for dim in range(-dims, dims):
self._test(*self.create_model(dim, descending, stable), ie_device, precision, ir_version)