ONNX TopK missing tests (#15556)
This commit is contained in:
parent
6582ad7e4d
commit
44f5238e3a
62
src/frontends/onnx/tests/models/top_k_repeating.prototxt
Normal file
62
src/frontends/onnx/tests/models/top_k_repeating.prototxt
Normal file
@ -0,0 +1,62 @@
|
||||
ir_version: 7
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "k"
|
||||
output: "values"
|
||||
output: "indices"
|
||||
op_type: "TopK"
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "k"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 15
|
||||
}
|
59
src/frontends/onnx/tests/models/top_k_repeating_1D.prototxt
Normal file
59
src/frontends/onnx/tests/models/top_k_repeating_1D.prototxt
Normal file
@ -0,0 +1,59 @@
|
||||
ir_version: 7
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "k"
|
||||
output: "values"
|
||||
output: "indices"
|
||||
op_type: "TopK"
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "k"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 15
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
ir_version: 7
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "k"
|
||||
output: "values"
|
||||
output: "indices"
|
||||
op_type: "TopK"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "k"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 15
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
ir_version: 7
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "k"
|
||||
output: "values"
|
||||
output: "indices"
|
||||
op_type: "TopK"
|
||||
attribute {
|
||||
name: "sorted"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "largest"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "k"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 15
|
||||
}
|
@ -2597,6 +2597,61 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_top_k_opset_11_const_k_smallest_negative_axis)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k_repeating_1D) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/top_k_repeating_1D.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<int32_t>({1, 1, 2, 0, 2, 100});
|
||||
test_case.add_input<int64_t>({5});
|
||||
|
||||
test_case.add_expected_output<int32_t>(Shape{5}, {100, 2, 2, 1, 1});
|
||||
test_case.add_expected_output<int64_t>(Shape{5}, {5, 2, 4, 0, 1});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k_repeating) {
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(CommonTestUtils::getExecutableDirectory(), SERIALIZED_ZOO, "onnx/top_k_repeating.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<int32_t>(Shape{3, 6}, {100, 1, 1, 2, 0, 2, 1, 2, 3, 4, 5, 6, 100, 1, 1, 2, 0, 2});
|
||||
test_case.add_input<int64_t>({3});
|
||||
|
||||
test_case.add_expected_output<int32_t>(Shape{3, 3}, {100, 2, 2, 6, 5, 4, 7, 2, 2});
|
||||
test_case.add_expected_output<int64_t>(Shape{3, 3}, {0, 3, 5, 5, 4, 3, 0, 2, 4});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k_repeating_axis_0) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/top_k_repeating_axis_0.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<int32_t>(Shape{3, 6}, {100, 1, 1, 2, 0, 2, 1, 2, 3, 4, 5, 6, 7, 1, 2, 0, 2, 1});
|
||||
test_case.add_input<int64_t>({2});
|
||||
|
||||
test_case.add_expected_output<int32_t>(Shape{2, 6}, {100, 2, 3, 4, 5, 6, 7, 1, 2, 2, 2, 2});
|
||||
test_case.add_expected_output<int64_t>(Shape{2, 6}, {0, 1, 1, 1, 1, 1, 2, 0, 2, 0, 2, 0});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k_repeating_unsorted) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/top_k_repeating_unsorted.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<int32_t>(Shape{3, 6}, {100, 1, 1, 2, 0, 2, 1, 2, 3, 4, 5, 6, 7, 1, 2, 0, 2, 1});
|
||||
test_case.add_input<int64_t>({3});
|
||||
|
||||
test_case.add_expected_output<int32_t>(Shape{3, 3}, {1, 1, 0, 3, 2, 1, 1, 1, 0});
|
||||
test_case.add_expected_output<int64_t>(Shape{3, 3}, {2, 1, 4, 2, 1, 0, 5, 1, 3});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_acosh) {
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(CommonTestUtils::getExecutableDirectory(), SERIALIZED_ZOO, "onnx/acosh.onnx"));
|
||||
|
@ -425,3 +425,8 @@ IE_CPU.onnx_softmax_crossentropy_loss_mean
|
||||
|
||||
# Cannot find blob with name: Y
|
||||
IE_CPU.onnx_bool_init_and
|
||||
|
||||
# Incorrect order of elements returned by the TopK implementation
|
||||
IE_CPU.onnx_model_top_k_repeating_1D
|
||||
IE_CPU.onnx_model_top_k_repeating
|
||||
IE_CPU.onnx_model_top_k_repeating_unsorted
|
||||
|
@ -83,3 +83,7 @@ onnx_clip_no_min_no_max_int64
|
||||
# z node not found in graph cache - ticket: 81976
|
||||
INTERPRETER.onnx_expand_context_dependent_function
|
||||
INTERPRETER.onnx_softmax_crossentropy_loss_mean
|
||||
|
||||
# Incorrect order of elements returned by the TopK implementation
|
||||
INTERPRETER.onnx_model_top_k_repeating
|
||||
INTERPRETER.onnx_model_top_k_repeating_unsorted
|
||||
|
@ -73,71 +73,71 @@ std::shared_ptr<Function> function_from_ir(const std::string& xml_path, const st
|
||||
return c.read_model(xml_path, bin_path);
|
||||
}
|
||||
|
||||
testing::AssertionResult TestCase::compare_results(size_t tolerance_bits) {
|
||||
auto compare_results = testing::AssertionSuccess();
|
||||
for (size_t i = 0; i < m_expected_outputs.size(); i++) {
|
||||
const auto& result_tensor = m_request.get_output_tensor(i);
|
||||
const auto& exp_result = m_expected_outputs.at(i);
|
||||
std::pair<testing::AssertionResult, size_t> TestCase::compare_results(size_t tolerance_bits) {
|
||||
auto res = testing::AssertionSuccess();
|
||||
size_t output_idx = 0;
|
||||
for (; output_idx < m_expected_outputs.size(); ++output_idx) {
|
||||
const auto& result_tensor = m_request.get_output_tensor(output_idx);
|
||||
const auto& exp_result = m_expected_outputs.at(output_idx);
|
||||
|
||||
const auto& element_type = result_tensor.get_element_type();
|
||||
const auto& res_shape = result_tensor.get_shape();
|
||||
const auto& exp_shape = exp_result.get_shape();
|
||||
|
||||
if (exp_shape != res_shape) {
|
||||
compare_results = testing::AssertionFailure();
|
||||
compare_results << "Computed data shape(" << res_shape << ") does not match the expected shape("
|
||||
<< exp_shape << ") for output " << i << std::endl;
|
||||
res = testing::AssertionFailure();
|
||||
res << "Computed data shape(" << res_shape << ") does not match the expected shape(" << exp_shape
|
||||
<< ") for output " << output_idx << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
switch (element_type) {
|
||||
case ov::element::Type_t::f16:
|
||||
compare_results = compare_values<ov::float16>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<ov::float16>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case ov::element::Type_t::bf16:
|
||||
compare_results = compare_values<ov::bfloat16>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<ov::bfloat16>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::f32:
|
||||
compare_results = compare_values<float>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<float>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::f64:
|
||||
compare_results = compare_values<double>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<double>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::i8:
|
||||
compare_results = compare_values<int8_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<int8_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::i16:
|
||||
compare_results = compare_values<int16_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<int16_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::i32:
|
||||
compare_results = compare_values<int32_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<int32_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::i64:
|
||||
compare_results = compare_values<int64_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<int64_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::u8:
|
||||
compare_results = compare_values<uint8_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<uint8_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::u16:
|
||||
compare_results = compare_values<uint16_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<uint16_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::u32:
|
||||
compare_results = compare_values<uint32_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<uint32_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::u64:
|
||||
compare_results = compare_values<uint64_t>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<uint64_t>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
case element::Type_t::boolean:
|
||||
compare_results = compare_values<char>(exp_result, result_tensor, tolerance_bits);
|
||||
res = compare_values<char>(exp_result, result_tensor, tolerance_bits);
|
||||
break;
|
||||
default:
|
||||
compare_results = testing::AssertionFailure()
|
||||
<< "Unsupported data type encountered in 'compare_results' method";
|
||||
res = testing::AssertionFailure() << "Unsupported data type encountered in 'res' method";
|
||||
}
|
||||
if (compare_results == testing::AssertionFailure())
|
||||
if (res == testing::AssertionFailure())
|
||||
break;
|
||||
}
|
||||
return compare_results;
|
||||
return std::make_pair(res, output_idx);
|
||||
}
|
||||
|
||||
testing::AssertionResult TestCase::compare_results_with_tolerance_as_fp(float tolerance) {
|
||||
|
@ -178,8 +178,9 @@ public:
|
||||
m_request.infer();
|
||||
const auto res = compare_results(tolerance_bits);
|
||||
|
||||
if (res != testing::AssertionSuccess()) {
|
||||
std::cout << res.message() << std::endl;
|
||||
if (res.first != testing::AssertionSuccess()) {
|
||||
std::cout << "Results comparison failed for output: " << res.second << std::endl;
|
||||
std::cout << res.first.message() << std::endl;
|
||||
}
|
||||
|
||||
m_input_index = 0;
|
||||
@ -187,7 +188,7 @@ public:
|
||||
|
||||
m_expected_outputs.clear();
|
||||
|
||||
EXPECT_TRUE(res);
|
||||
EXPECT_TRUE(res.first);
|
||||
}
|
||||
|
||||
void run_with_tolerance_as_fp(const float tolerance = 1.0e-5f) {
|
||||
@ -213,7 +214,7 @@ private:
|
||||
std::vector<ov::Tensor> m_expected_outputs;
|
||||
size_t m_input_index = 0;
|
||||
size_t m_output_index = 0;
|
||||
testing::AssertionResult compare_results(size_t tolerance_bits);
|
||||
std::pair<testing::AssertionResult, size_t> compare_results(size_t tolerance_bits);
|
||||
testing::AssertionResult compare_results_with_tolerance_as_fp(float tolerance_bits);
|
||||
};
|
||||
} // namespace test
|
||||
|
Loading…
Reference in New Issue
Block a user