TopK v11 reference implementation (#16137)
* Stabilize ascending comparison of ref impl * Use reference to gtest param * Create ref impl tests * Fix descending by index sorting * Sort by index both ways * Make sort by index always ascending (revert)
This commit is contained in:
parent
497b7885da
commit
0145e538f5
@ -121,7 +121,7 @@ public:
|
||||
/// \param axis The axis along which the TopK operation should be executed
|
||||
/// \param mode Specifies whether TopK selects the largest or the smallest elements from each slice
|
||||
/// \param sort Specifies the order of corresponding elements of the output tensor
|
||||
/// \param index_element_type Specifies the data type type of of the elements in the 'indices' output tensor.
|
||||
/// \param index_element_type Specifies the data type of the elements in the 'indices' output tensor.
|
||||
/// \param stable Specifies whether the equivalent elements should maintain their relative order
|
||||
/// from the input tensor during sorting.
|
||||
TopK(const Output<Node>& data,
|
||||
@ -139,7 +139,7 @@ public:
|
||||
/// \param axis The axis along which the TopK operation should be executed
|
||||
/// \param mode Specifies whether TopK selects the largest or the smallest elements from each slice
|
||||
/// \param sort Specifies the order of corresponding elements of the output tensor
|
||||
/// \param index_element_type Specifies the data type type of of the elements in the 'indices' output tensor.
|
||||
/// \param index_element_type Specifies the data type of the elements in the 'indices' output tensor.
|
||||
/// \param stable Specifies whether the equivalent elements should maintain their relative order
|
||||
/// from the input tensor during sorting.
|
||||
TopK(const Output<Node>& data,
|
||||
@ -153,6 +153,11 @@ public:
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool has_evaluate() const override;
|
||||
|
||||
bool get_stable() const {
|
||||
return m_stable;
|
||||
}
|
||||
|
@ -14,9 +14,8 @@
|
||||
namespace ngraph {
|
||||
namespace runtime {
|
||||
namespace reference {
|
||||
// Had to split out these two functions. They used to be lambda expressions but
|
||||
// MSVC had difficulty compiling. This way is more explicit.
|
||||
template <typename T, typename U>
|
||||
// This used to be lambda expressions but MSVC had difficulty compiling it. This way is more explicit.
|
||||
template <bool D, typename T, typename U>
|
||||
inline bool compare_max(const std::tuple<T, U>& a, const std::tuple<T, U>& b) {
|
||||
// this is intentional to be able to compare floats directly
|
||||
// without using relative or absolute tolerance
|
||||
@ -30,19 +29,19 @@ inline bool compare_max(const std::tuple<T, U>& a, const std::tuple<T, U>& b) {
|
||||
#if defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
return a > b;
|
||||
|
||||
if (D)
|
||||
return std::get<0>(a) > std::get<0>(b);
|
||||
else
|
||||
return std::get<0>(a) < std::get<0>(b);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool compare_min(const std::tuple<T, U>& a, const std::tuple<T, U>& b) {
|
||||
return a < b;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool sort_indices_ascending(const std::tuple<T, U>& a, const std::tuple<T, U>& b) {
|
||||
inline bool compare_indices_ascending(const std::tuple<T, U>& a, const std::tuple<T, U>& b) {
|
||||
return std::get<1>(a) < std::get<1>(b);
|
||||
}
|
||||
|
||||
// TopK reference implementation provides stable indices output
|
||||
template <typename T, typename U>
|
||||
void topk(const T* arg,
|
||||
U* out_indices,
|
||||
@ -52,7 +51,7 @@ void topk(const T* arg,
|
||||
size_t axis,
|
||||
size_t k,
|
||||
bool compute_max,
|
||||
op::v1::TopK::SortType sort = op::v1::TopK::SortType::NONE) {
|
||||
op::TopKSortType sort = op::TopKSortType::NONE) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
using namespace std;
|
||||
// reorder source axis visit order and make "axis" inner most
|
||||
@ -87,25 +86,25 @@ void topk(const T* arg,
|
||||
}
|
||||
// Sort the temp vector
|
||||
if (compute_max) {
|
||||
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_max<T, U>);
|
||||
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_max<true, T, U>);
|
||||
} else {
|
||||
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_min<T, U>);
|
||||
nth_element(workspace.begin(), workspace.begin() + k, workspace.end(), compare_max<false, T, U>);
|
||||
}
|
||||
// Write temp vector to output
|
||||
switch (sort) {
|
||||
case op::v1::TopK::SortType::NONE:
|
||||
case op::TopKSortType::NONE:
|
||||
break;
|
||||
case op::v1::TopK::SortType::SORT_INDICES:
|
||||
std::sort(workspace.begin(), workspace.begin() + k, sort_indices_ascending<T, U>);
|
||||
case op::TopKSortType::SORT_INDICES:
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_indices_ascending<T, U>);
|
||||
break;
|
||||
case op::v1::TopK::SortType::SORT_VALUES:
|
||||
case op::TopKSortType::SORT_VALUES:
|
||||
if (compute_max)
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_max<T, U>);
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_max<true, T, U>);
|
||||
else
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_min<T, U>);
|
||||
std::sort(workspace.begin(), workspace.begin() + k, compare_max<false, T, U>);
|
||||
}
|
||||
for (size_t j = 0; j < k; j++) {
|
||||
tuple<T, U> entry = workspace[j];
|
||||
const auto& entry = workspace[j];
|
||||
out_values[out_index] = get<0>(entry);
|
||||
out_indices[out_index] = get<1>(entry);
|
||||
out_index += out_axis_stride;
|
||||
|
@ -103,6 +103,37 @@ bool evaluate_topk(const HostTensorPtr& arg,
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
bool TopK_evaluate(const ov::op::util::TopKBase* const node,
|
||||
const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) {
|
||||
const auto& arg_shape = inputs[0]->get_shape();
|
||||
const auto axis = normalize_axis(node, node->get_provided_axis(), arg_shape.size());
|
||||
const auto compute_max = node->get_mode() == ov::op::TopKMode::MAX;
|
||||
const auto sort_type = node->get_sort_type();
|
||||
|
||||
const auto input_shapes = vector<PartialShape>{inputs[0]->get_partial_shape(), inputs[1]->get_partial_shape()};
|
||||
const auto constant_data = map<size_t, HostTensorPtr>{{1, inputs[1]}};
|
||||
auto output_shape = shape_infer(node, input_shapes, constant_data).front().to_shape();
|
||||
|
||||
if (output_shape[axis] == 0) {
|
||||
// the kernel can't handle K (output_shape[axis]) equal 0, use arg_shape[axis] instead.
|
||||
output_shape[axis] = arg_shape[axis];
|
||||
}
|
||||
|
||||
const size_t k = output_shape[axis];
|
||||
OPENVINO_ASSERT(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
|
||||
|
||||
// TopK reference implementation provides stable indices output so this parameter is not passed on
|
||||
return evaluate_topk(inputs[0],
|
||||
outputs[1],
|
||||
outputs[0],
|
||||
output_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort_type,
|
||||
node->get_index_element_type());
|
||||
}
|
||||
} // namespace
|
||||
} // namespace topk
|
||||
|
||||
@ -145,34 +176,7 @@ shared_ptr<Node> op::v1::TopK::clone_with_new_inputs(const OutputVector& new_arg
|
||||
|
||||
bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v1_TopK_evaluate);
|
||||
const auto& arg_shape = inputs[0]->get_shape();
|
||||
// 1. get axis, mode (max/min), sort_type
|
||||
auto axis = ngraph::normalize_axis(this, m_axis, arg_shape.size());
|
||||
auto compute_max = get_mode() == TopKMode::MAX;
|
||||
auto sort_type = get_sort_type();
|
||||
|
||||
const auto input_shapes = std::vector<PartialShape>{inputs[0]->get_partial_shape(), inputs[1]->get_partial_shape()};
|
||||
const auto constant_data = std::map<size_t, HostTensorPtr>{{1, inputs[1]}};
|
||||
auto output_shape = shape_infer(this, input_shapes, constant_data).front().to_shape();
|
||||
|
||||
if (output_shape[axis] == 0) {
|
||||
// the kernel can't handle K (output_shape[axis]) equal 0, use arg_shape[axis] instead.
|
||||
output_shape[axis] = arg_shape[axis];
|
||||
}
|
||||
|
||||
// 2. get value of k
|
||||
size_t k = output_shape[axis];
|
||||
OPENVINO_ASSERT(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
|
||||
|
||||
return topk::evaluate_topk(inputs[0],
|
||||
outputs[1],
|
||||
outputs[0],
|
||||
output_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort_type,
|
||||
get_index_element_type());
|
||||
return topk::TopK_evaluate(this, outputs, inputs);
|
||||
}
|
||||
|
||||
bool op::v1::TopK::has_evaluate() const {
|
||||
@ -245,34 +249,7 @@ shared_ptr<Node> op::v3::TopK::clone_with_new_inputs(const OutputVector& new_arg
|
||||
|
||||
bool op::v3::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v3_TopK_evaluate);
|
||||
const auto& arg_shape = inputs[0]->get_shape();
|
||||
// 1. get axis, mode (max/min), sort_type
|
||||
auto axis = ngraph::normalize_axis(this, m_axis, arg_shape.size());
|
||||
auto compute_max = get_mode() == TopKMode::MAX;
|
||||
auto sort_type = get_sort_type();
|
||||
|
||||
const auto input_shapes = std::vector<PartialShape>{inputs[0]->get_partial_shape(), inputs[1]->get_partial_shape()};
|
||||
const auto constant_data = std::map<size_t, HostTensorPtr>{{1, inputs[1]}};
|
||||
auto output_shape = shape_infer(this, input_shapes, constant_data).front().to_shape();
|
||||
|
||||
if (output_shape[axis] == 0) {
|
||||
// the kernel can't handle K (output_shape[axis]) equal 0, use arg_shape[axis] instead.
|
||||
output_shape[axis] = arg_shape[axis];
|
||||
}
|
||||
|
||||
// 2. get value of k
|
||||
size_t k = output_shape[axis];
|
||||
OPENVINO_ASSERT(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
|
||||
|
||||
return topk::evaluate_topk(inputs[0],
|
||||
outputs[1],
|
||||
outputs[0],
|
||||
output_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort_type,
|
||||
get_index_element_type());
|
||||
return topk::TopK_evaluate(this, outputs, inputs);
|
||||
}
|
||||
|
||||
bool op::v3::TopK::has_evaluate() const {
|
||||
@ -372,3 +349,25 @@ std::shared_ptr<Node> ov::op::v11::TopK::clone_with_new_inputs(const OutputVecto
|
||||
m_index_element_type,
|
||||
m_stable);
|
||||
}
|
||||
|
||||
bool ov::op::v11::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v11_TopK_evaluate);
|
||||
return topk::TopK_evaluate(this, outputs, inputs);
|
||||
}
|
||||
|
||||
bool ov::op::v11::TopK::has_evaluate() const {
|
||||
OV_OP_SCOPE(v11_TopK_has_evaluate);
|
||||
|
||||
switch (get_input_element_type(0)) {
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
case ngraph::element::u32:
|
||||
case ngraph::element::u64:
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -4,9 +4,10 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "openvino/opsets/opset3.hpp"
|
||||
#include "openvino/opsets/opset1.hpp"
|
||||
#include "base_reference_test.hpp"
|
||||
#include "openvino/opsets/opset1.hpp"
|
||||
#include "openvino/opsets/opset11.hpp"
|
||||
#include "openvino/opsets/opset3.hpp"
|
||||
|
||||
using namespace reference_tests;
|
||||
using namespace ov;
|
||||
@ -36,7 +37,7 @@ struct TopKParams {
|
||||
class ReferenceTopKTest : public testing::TestWithParam<TopKParams>, public CommonReferenceTest {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<TopKParams>& obj) {
|
||||
auto param = obj.param;
|
||||
const auto& param = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "aType=" << param.A.type;
|
||||
result << "_aShape=" << param.A.shape;
|
||||
@ -74,7 +75,7 @@ struct TopKParamsResnet50 {
|
||||
class ReferenceTopKTestResnet50 : public testing::TestWithParam<TopKParamsResnet50>, public CommonReferenceTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto params = GetParam();
|
||||
const auto& params = GetParam();
|
||||
function = CreateFunction(params);
|
||||
inputData = {params.A.data};
|
||||
refOutData = {params.result5Value.data, params.result5Index.data,
|
||||
@ -82,7 +83,7 @@ public:
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<TopKParamsResnet50>& obj) {
|
||||
auto param = obj.param;
|
||||
const auto& param = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "aType=" << param.A.type;
|
||||
result << "_aShape=" << param.A.shape;
|
||||
@ -211,7 +212,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_TopK_With_Hardcoded_Refs, ReferenceTopKTestResnet
|
||||
class ReferenceTopKTestMaxMinSort : public ReferenceTopKTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto params = GetParam();
|
||||
const auto& params = GetParam();
|
||||
function = CreateFunction(params);
|
||||
inputData = {params.A.data};
|
||||
refOutData = {params.result0.data, params.result1.data};
|
||||
@ -538,7 +539,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_TopK_With_Hardcoded_Refs, ReferenceTopKTestMaxMin
|
||||
class ReferenceTopKTestBackend : public ReferenceTopKTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto params = GetParam();
|
||||
const auto& params = GetParam();
|
||||
function = CreateFunction(params);
|
||||
inputData = {params.A.data};
|
||||
refOutData = {params.result0.data, params.result1.data};
|
||||
@ -561,59 +562,6 @@ TEST_P(ReferenceTopKTestBackend, CompareWithRefs) {
|
||||
Exec();
|
||||
}
|
||||
|
||||
template <element::Type_t ET, element::Type_t ET2, element::Type_t ET_OUT>
|
||||
std::vector<TopKParams> generateParamsV3() {
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
using T2 = typename element_type_traits<ET2>::value_type;
|
||||
using T_OUT = typename element_type_traits<ET_OUT>::value_type;
|
||||
std::vector<TopKParams> params {
|
||||
TopKParams(
|
||||
reference_tests::Tensor(ET, {5}, std::vector<T>{3, 1, 2, 5, 4}),
|
||||
reference_tests::Tensor(ET2, {}, std::vector<T2>{3}),
|
||||
0,
|
||||
opset1::TopK::Mode::MAX,
|
||||
opset1::TopK::SortType::SORT_VALUES,
|
||||
reference_tests::Tensor(ET, {3}, std::vector<T>{5, 4, 3}),
|
||||
reference_tests::Tensor(ET_OUT, {3}, std::vector<T_OUT>{3, 4, 0}),
|
||||
0,
|
||||
"topk_mode_sort_order"),
|
||||
|
||||
TopKParams(
|
||||
reference_tests::Tensor(ET, {5}, std::vector<T>{3, 1, 2, 5, 4}),
|
||||
reference_tests::Tensor(ET2, {}, std::vector<T2>{3}),
|
||||
0,
|
||||
opset1::TopK::Mode::MAX,
|
||||
opset1::TopK::SortType::SORT_INDICES,
|
||||
reference_tests::Tensor(ET, {3}, std::vector<T>{3, 5, 4}),
|
||||
reference_tests::Tensor(ET_OUT, {3}, std::vector<T_OUT>{0, 3, 4}),
|
||||
0,
|
||||
"topk_mode_sort_order_1"),
|
||||
|
||||
TopKParams(
|
||||
reference_tests::Tensor(ET, {5}, std::vector<T>{3, 1, 2, 5, 4}),
|
||||
reference_tests::Tensor(ET2, {}, std::vector<T2>{3}),
|
||||
0,
|
||||
opset1::TopK::Mode::MIN,
|
||||
opset1::TopK::SortType::SORT_VALUES,
|
||||
reference_tests::Tensor(ET, {3}, std::vector<T>{1, 2, 3}),
|
||||
reference_tests::Tensor(ET_OUT, {3}, std::vector<T_OUT>{1, 2, 0}),
|
||||
0,
|
||||
"topk_mode_sort_order_2"),
|
||||
|
||||
TopKParams(
|
||||
reference_tests::Tensor(ET, {5}, std::vector<T>{3, 1, 2, 5, 4}),
|
||||
reference_tests::Tensor(ET2, {}, std::vector<T2>{3}),
|
||||
0,
|
||||
opset1::TopK::Mode::MIN,
|
||||
opset1::TopK::SortType::SORT_INDICES,
|
||||
reference_tests::Tensor(ET, {3}, std::vector<T>{3, 1, 2}),
|
||||
reference_tests::Tensor(ET_OUT, {3}, std::vector<T_OUT>{0, 1, 2}),
|
||||
0,
|
||||
"topk_mode_sort_order_3"),
|
||||
};
|
||||
return params;
|
||||
}
|
||||
|
||||
std::vector<TopKParams> generateCombinedParamsBackend() {
|
||||
const std::vector<std::vector<TopKParams>> generatedParams {
|
||||
generateParamsMaxMinSort<element::Type_t::i8, element::Type_t::i64, element::Type_t::i32>(),
|
||||
@ -643,7 +591,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_TopK_With_Hardcoded_Refs, ReferenceTopKTestBacken
|
||||
class ReferenceTopKTest1dMaxMin : public ReferenceTopKTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto params = GetParam();
|
||||
const auto& params = GetParam();
|
||||
function = CreateFunction(params, params.outIdx);
|
||||
inputData = {params.A.data};
|
||||
if (params.outIdx != 0) {
|
||||
@ -654,7 +602,7 @@ public:
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<TopKParams>& obj) {
|
||||
auto param = obj.param;
|
||||
const auto& param = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "aType=" << param.A.type;
|
||||
result << "_aShape=" << param.A.shape;
|
||||
@ -1459,7 +1407,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_TopK_With_Hardcoded_Refs, ReferenceTopKTestInt64,
|
||||
class ReferenceTopKTestSingleOutput : public ReferenceTopKTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto params = GetParam();
|
||||
const auto& params = GetParam();
|
||||
function = CreateFunction(params);
|
||||
inputData = {params.A.data};
|
||||
refOutData = {params.result1.data};
|
||||
@ -1706,4 +1654,103 @@ TEST(ReferenceTopKTestInvalidV3, topk_v3_invalid_k) {
|
||||
const auto k_negative = opset1::Constant::create(element::i8, Shape{}, {-1});
|
||||
EXPECT_THROW(opset3::TopK(data, k_negative, 0, "max", "index"), ngraph::NodeValidationFailure);
|
||||
}
|
||||
|
||||
class ReferenceTopKv11StableTest : public ReferenceTopKTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const auto& params = GetParam();
|
||||
function = CreateFunction(params);
|
||||
inputData = {params.A.data};
|
||||
refOutData = {
|
||||
params.result0.data, // stable output values
|
||||
params.result1.data, // stable output indices
|
||||
params.result0.data // unstable output values
|
||||
// unstable output indices need not be compared, by definition these might differ for
|
||||
// equal data values
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
static std::shared_ptr<Model> CreateFunction(const TopKParams& params) {
|
||||
const auto A = std::make_shared<opset11::Parameter>(params.A.type, params.A.shape);
|
||||
const auto k = opset11::Constant::create(params.k.type, params.k.shape, params.k.data.data());
|
||||
const auto topk_stable =
|
||||
std::make_shared<opset11::TopK>(A, k, params.axis, params.mode, params.sort, params.result1.type, true);
|
||||
const auto topk_unstable =
|
||||
std::make_shared<opset11::TopK>(A, k, params.axis, params.mode, params.sort, params.result1.type, false);
|
||||
|
||||
return std::make_shared<Model>(
|
||||
OutputVector{topk_stable->output(0), topk_stable->output(1), topk_unstable->output(0)},
|
||||
ParameterVector{A});
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ReferenceTopKv11StableTest, CompareWithRefs) {
|
||||
Exec();
|
||||
}
|
||||
|
||||
template <element::Type_t ET, element::Type_t ET2, element::Type_t ET_OUT>
|
||||
std::vector<TopKParams> generateParamsForStableTest() {
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
using T2 = typename element_type_traits<ET2>::value_type;
|
||||
using T_OUT = typename element_type_traits<ET_OUT>::value_type;
|
||||
std::vector<TopKParams> params{
|
||||
TopKParams(reference_tests::Tensor(ET, {2, 7}, std::vector<T>{5, 4, 3, 1, 7, 1, 3, 2, 1, 2, 5, 1, 7, 3}),
|
||||
reference_tests::Tensor(ET2, {}, std::vector<T2>{3}),
|
||||
1,
|
||||
opset1::TopK::Mode::MIN,
|
||||
opset1::TopK::SortType::SORT_VALUES,
|
||||
reference_tests::Tensor(ET, {2, 3}, std::vector<T>{1, 1, 3, 1, 1, 2}),
|
||||
reference_tests::Tensor(ET_OUT, {2, 3}, std::vector<T_OUT>{3, 5, 2, 1, 4, 0}),
|
||||
0,
|
||||
"repeated_values"),
|
||||
TopKParams(reference_tests::Tensor(ET,
|
||||
{7, 3},
|
||||
std::vector<T>{
|
||||
5, 7, 1, 7, 9, 1, 5, 7, 2, 2, 8, 2, 7, 7, 5, 8, 1, 4, 2, 2, 3,
|
||||
}),
|
||||
reference_tests::Tensor(ET2, {}, std::vector<T2>{4}),
|
||||
0,
|
||||
opset1::TopK::Mode::MAX,
|
||||
opset1::TopK::SortType::SORT_VALUES,
|
||||
reference_tests::Tensor(ET, {4, 3}, std::vector<T>{8, 9, 5, 7, 8, 4, 7, 7, 3, 5, 7, 2}),
|
||||
reference_tests::Tensor(ET_OUT, {4, 3}, std::vector<T_OUT>{5, 1, 4, 1, 3, 5, 4, 0, 6, 0, 2, 2}),
|
||||
0,
|
||||
"repeated_values"),
|
||||
TopKParams(reference_tests::Tensor(ET,
|
||||
{2, 3, 3},
|
||||
std::vector<T>{1, 3, 3, 1, 2, 4, 2, 2, 3, 7, 7, 1, 7, 9, 7, 5, 7, 7}),
|
||||
reference_tests::Tensor(ET2, {}, std::vector<T2>{2}),
|
||||
1,
|
||||
opset1::TopK::Mode::MIN,
|
||||
opset1::TopK::SortType::SORT_VALUES,
|
||||
reference_tests::Tensor(ET, {2, 2, 3}, std::vector<T>{1, 2, 3, 1, 2, 3, 5, 7, 1, 7, 7, 7}),
|
||||
reference_tests::Tensor(ET_OUT, {2, 2, 3}, std::vector<T_OUT>{0, 1, 0, 1, 2, 2, 2, 0, 0, 0, 2, 1}),
|
||||
0,
|
||||
"repeated_values"),
|
||||
};
|
||||
return params;
|
||||
}
|
||||
|
||||
std::vector<TopKParams> generateCombinedParamsForStableTest() {
|
||||
std::vector<std::vector<TopKParams>> generatedParams{
|
||||
generateParamsForStableTest<element::Type_t::i32, element::Type_t::i32, element::Type_t::i32>(),
|
||||
generateParamsForStableTest<element::Type_t::i64, element::Type_t::i64, element::Type_t::i64>(),
|
||||
generateParamsForStableTest<element::Type_t::u32, element::Type_t::i64, element::Type_t::i32>(),
|
||||
generateParamsForStableTest<element::Type_t::u64, element::Type_t::i32, element::Type_t::i64>(),
|
||||
generateParamsForStableTest<element::Type_t::f16, element::Type_t::i64, element::Type_t::i32>(),
|
||||
generateParamsForStableTest<element::Type_t::f32, element::Type_t::i32, element::Type_t::i32>(),
|
||||
};
|
||||
std::vector<TopKParams> combinedParams;
|
||||
for (auto& params : generatedParams) {
|
||||
std::move(params.begin(), params.end(), std::back_inserter(combinedParams));
|
||||
}
|
||||
return combinedParams;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_TopK_With_Hardcoded_Refs,
|
||||
ReferenceTopKv11StableTest,
|
||||
testing::ValuesIn(generateCombinedParamsForStableTest()),
|
||||
ReferenceTopKv11StableTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user