Fixed default value of score threshold (#17448)

This commit is contained in:
Mateusz Bencer
2023-05-10 13:49:35 +02:00
committed by GitHub
parent 014eafda00
commit 5eab00a116
3 changed files with 159 additions and 1 deletions

View File

@@ -44,7 +44,7 @@ OutputVector non_max_suppression(const Node& node) {
if (ng_inputs.size() > 4 && !is_null(ng_inputs.at(4))) {
score_threshold = ngraph::onnx_import::reshape::interpret_as_scalar(ng_inputs.at(4));
} else {
score_threshold = default_opset::Constant::create(element::f32, Shape{}, {.0f});
score_threshold = default_opset::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::max()});
}
const auto center_point_box = node.get_attribute_value<std::int64_t>("center_point_box", 0);

View File

@@ -0,0 +1,89 @@
ir_version: 6
producer_name: "ONNX Frontend"
graph {
node {
output: "max_output_boxes"
name: "Constant_1521"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\377\377\377\377\377\377\377\177"
}
type: TENSOR
}
}
node {
output: "iou_threshold"
name: "Constant_1522"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 1
raw_data: "333?"
}
type: TENSOR
}
}
node {
input: "boxes"
input: "scores"
input: "max_output_boxes"
input: "iou_threshold"
output: "selected_indices"
op_type: "NonMaxSuppression"
}
input {
name: "boxes"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 50
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "scores"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 50
}
}
}
}
}
output {
name: "selected_indices"
type {
tensor_type {
elem_type: 7
}
}
}
}
opset_import {
version: 11
}

View File

@@ -1080,6 +1080,75 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_nonmaxsuppression_v9_single_box) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_nonmaxsuppression_default_score_threshold) {
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/nms_default_score_threshold.onnx"));
auto test_case = test::TestCase(function, s_device);
test_case.add_input(
Shape{1, 50, 4},
std::vector<float>(
{278.862060546875f, 453.5412902832031f, 295.09234619140625f, 470.2095031738281f, 225.9730682373047f,
387.33990478515625f, 241.69297790527344f, 403.43377685546875f, 281.3062438964844f, 453.8412170410156f,
298.6865539550781f, 470.9977111816406f, 216.9517364501953f, 450.6717529296875f, 232.95777893066406f,
466.14276123046875f, 217.54473876953125f, 449.9130859375f, 233.97265625f, 466.1539306640625f,
279.0079650878906f, 453.865234375f, 294.8210144042969f, 470.123046875f, 226.5626983642578f,
388.5235290527344f, 242.2290496826172f, 404.2589416503906f, 216.49752807617188f, 450.7710876464844f,
233.07443237304688f, 466.7010192871094f, 281.3638000488281f, 454.33892822265625f, 298.5252990722656f,
471.1678466796875f, 217.3330841064453f, 451.484130859375f, 234.1898651123047f, 466.83148193359375f,
187.2439727783203f, 466.8524475097656f, 208.7089385986328f, 489.7967224121094f, 257.8833923339844f,
515.705322265625f, 280.8927917480469f, 539.775146484375f, 226.52525329589844f, 387.7011413574219f,
241.6272430419922f, 403.7854919433594f, 187.38221740722656f, 466.5717468261719f, 209.05845642089844f,
489.4494323730469f, 217.56448364257812f, 451.1393737792969f, 233.90216064453125f, 466.1475524902344f,
279.45611572265625f, 454.00299072265625f, 296.16424560546875f, 471.84521484375f, 279.04486083984375f,
453.9889221191406f, 295.2816162109375f, 470.4144592285156f, 187.18997192382812f, 466.4650573730469f,
209.26266479492188f, 488.8149719238281f, 189.04197692871094f, 469.8923034667969f, 208.8195037841797f,
491.5357971191406f, 216.47879028320312f, 450.1073303222656f, 233.21575927734375f, 466.9475402832031f,
278.86163330078125f, 454.966552734375f, 296.38958740234375f, 471.9764404296875f, 259.4800720214844f,
515.1390991210938f, 282.3655090332031f, 539.4806518554688f, 285.031494140625f, 389.0125427246094f,
302.09747314453125f, 406.9799499511719f, 285.1270446777344f, 389.06890869140625f, 301.2108459472656f,
405.7711181640625f, 188.17117309570312f, 467.71533203125f, 208.49929809570312f, 490.401611328125f,
278.93292236328125f, 453.8080139160156f, 295.4295654296875f, 469.9015808105469f, 279.0393371582031f,
454.2393798828125f, 296.3529357910156f, 471.6363525390625f, 187.29873657226562f, 467.9837951660156f,
208.29107666015625f, 489.8014221191406f, 187.79478454589844f, 466.6510314941406f, 208.3644561767578f,
490.2976989746094f, 188.4196014404297f, 468.3448486328125f, 209.06849670410156f, 491.94384765625f,
281.4726867675781f, 454.0541687011719f, 298.2876892089844f, 470.2845764160156f, 225.8560333251953f,
387.4819030761719f, 241.4767608642578f, 403.4317321777344f, 280.7021484375f, 455.43206787109375f,
297.9931640625f, 471.99749755859375f, 226.0373077392578f, 387.4749450683594f, 241.48097229003906f,
403.4716491699219f, 259.018310546875f, 515.3871459960938f, 281.7872314453125f, 540.0093383789062f,
217.71246337890625f, 450.4556884765625f, 234.254150390625f, 467.68182373046875f, 257.5479736328125f,
518.8912353515625f, 280.48260498046875f, 541.3863525390625f, 216.87359619140625f, 450.3395080566406f,
232.39752197265625f, 465.5039367675781f, 258.2445068359375f, 515.2009887695312f, 280.29803466796875f,
540.3602905273438f, 217.54478454589844f, 451.3944091796875f, 233.6602020263672f, 467.51971435546875f,
258.30133056640625f, 515.2357788085938f, 280.1400146484375f, 541.3275756835938f, 217.05136108398438f,
451.8975524902344f, 232.9573974609375f, 466.9907531738281f, 215.86386108398438f, 450.801025390625f,
232.117919921875f, 466.3701171875f, 279.01593017578125f, 453.6647644042969f, 296.13372802734375f,
471.4644470214844f, 280.1851806640625f, 454.41900634765625f, 296.481201171875f, 471.63104248046875f,
259.1214904785156f, 516.8644409179688f, 281.7276306152344f, 541.0162963867188f, 285.2935485839844f,
389.03515625f, 302.1134948730469f, 406.89373779296875f, 279.6715393066406f, 455.1846923828125f,
296.6995544433594f, 471.5782470703125f, 258.1405029296875f, 518.9312744140625f, 281.019287109375f,
541.5760498046875f, 187.80953979492188f, 466.8480224609375f, 208.54336547851562f, 489.9696044921875f}));
test_case.add_input(
Shape{1, 1, 50},
std::vector<float>(
{5.485373497009277f, 5.469169616699219f, 5.450349807739258f, 5.446445465087891f, 5.43833065032959f,
5.407294273376465f, 5.3790669441223145f, 5.3575520515441895f, 5.348986625671387f, 5.309826850891113f,
5.266261577606201f, 5.230800151824951f, 5.079848766326904f, 5.066829204559326f, 4.913329601287842f,
4.895563125610352f, 4.8786115646362305f, 4.872953414916992f, 4.825906753540039f, 4.812736511230469f,
4.761179447174072f, 4.657320022583008f, 4.640903949737549f, 4.63286828994751f, 4.600266933441162f,
4.599870204925537f, 4.5536088943481445f, 4.521742820739746f, 4.465426445007324f, 4.4556074142456055f,
4.451722621917725f, 4.416017055511475f, 4.410635471343994f, 4.403003215789795f, 4.387508392333984f,
4.3634934425354f, 4.362300872802734f, 4.348748683929443f, 4.345107555389404f, 4.32416296005249f,
4.3132781982421875f, 4.287333965301514f, 4.223401069641113f, 4.220005035400391f, 4.179988861083984f,
4.099865436553955f, 4.097578048706055f, 4.075544357299805f, 4.0459885597229f}));
test_case.add_expected_output<int64_t>(Shape{7, 3},
{0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 10, 0, 0, 11, 0, 0, 22});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_log_sum) {
auto function = onnx_import::import_onnx_model(
file_util::path_join(CommonTestUtils::getExecutableDirectory(), SERIALIZED_ZOO, "onnx/reduce_log_sum.onnx"));