[GPU] Fix proposal sort condition (#16981)
This commit is contained in:
@@ -64,7 +64,7 @@ inline void float_write_helper(half_t* mem, float f) { *mem = static_cast<half_t
|
||||
void sort_and_keep_n_items(std::vector<proposal_t>& proposals, size_t n) {
|
||||
auto cmp_fn = [](const proposal_t& a, const proposal_t& b) { return (a.confidence > b.confidence); };
|
||||
|
||||
if (proposals.size() > n) {
|
||||
if (proposals.size() >= n) {
|
||||
std::partial_sort(proposals.begin(), proposals.begin() + n, proposals.end(), cmp_fn);
|
||||
proposals.resize(n);
|
||||
} else {
|
||||
|
||||
@@ -18,6 +18,7 @@ extern size_t cls_scores_data_size;
|
||||
extern float bbox_pred_data[];
|
||||
extern size_t bbox_pred_data_size;
|
||||
extern float proposal_ref[];
|
||||
extern float proposal_zero_score_ref[];
|
||||
extern size_t proposal_ref_size;
|
||||
|
||||
const float epsilon_fp16 = 0.125f;
|
||||
@@ -149,6 +150,24 @@ void test_proposal_basic(cldnn::tensor image_info_size, bool is_caching_test) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
void test_proposal_zero_score(cldnn::tensor image_info_size) {
|
||||
std::vector<Dtype> cls_scores(cls_scores_data_size, 0);
|
||||
std::vector<Dtype> bbox_pred(&bbox_pred_data[0], &bbox_pred_data[bbox_pred_data_size]);
|
||||
|
||||
TestRunnerProposal<Dtype> t(image_info_size, false);
|
||||
|
||||
memory::ptr output = t.Run(cls_scores, bbox_pred);
|
||||
ASSERT_EQ(output->get_layout().count(), proposal_ref_size);
|
||||
|
||||
cldnn::mem_lock<Dtype> f(output, get_test_stream());
|
||||
|
||||
for (size_t i = 0; i < proposal_ref_size/5; i++) {
|
||||
Dtype ref(proposal_zero_score_ref[i]);
|
||||
ASSERT_NEAR((float)f[i], (float)ref, epsilon_fp16);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(proposal, basic) {
|
||||
test_proposal_basic<float>({ 1, 3, 1, 1 }, false);
|
||||
}
|
||||
@@ -165,6 +184,14 @@ TEST(proposal, img_info_batch_only) {
|
||||
test_proposal_basic<float>({ 3, 1, 1, 1 }, false);
|
||||
}
|
||||
|
||||
TEST(proposal, zero_score_fp32) {
|
||||
test_proposal_zero_score<float>({ 1, 3, 1, 1 });
|
||||
}
|
||||
|
||||
TEST(proposal, zero_score_fp16) {
|
||||
test_proposal_zero_score<FLOAT16>({ 1, 3, 1, 1 });
|
||||
}
|
||||
|
||||
template <typename Dtype, typename ImInfoType>
|
||||
void test_proposal_basic_two_types(cldnn::tensor image_info_size, bool is_caching_test) {
|
||||
std::vector<Dtype> cls_scores(&cls_scores_data[0], &cls_scores_data[cls_scores_data_size]);
|
||||
|
||||
@@ -2534,4 +2534,32 @@ float proposal_ref[] = {
|
||||
0.0f, 183.979691f, 0.000000f, 341.391418f, 129.322388f, // 24
|
||||
};
|
||||
|
||||
float proposal_zero_score_ref[] = {
|
||||
0.0000, 0.0000, 10.7419, 349.0000, 209.0000,
|
||||
0.0000, 82.6684, 0.0000, 349.0000, 158.4882,
|
||||
0.0000, 0.0000, 0.0000, 168.8084, 209.0000,
|
||||
0.0000, 201.2879, 164.9072, 253.0708, 206.6294,
|
||||
0.0000, 171.4072, 0.0000, 349.0000, 165.9408,
|
||||
0.0000, 24.2378, 67.4444, 72.8695, 125.2722,
|
||||
0.0000, 48.6425, 18.6561, 189.1384, 129.6040,
|
||||
0.0000, 89.8496, 0.0000, 310.3151, 209.0000,
|
||||
0.0000, 173.5109, 134.3725, 349.0000, 209.0000,
|
||||
0.0000, 222.8794, 48.1469, 349.0000, 209.0000,
|
||||
0.0000, 0.0000, 68.6857, 103.2972, 147.2614,
|
||||
0.0000, 128.2033, 0.0000, 349.0000, 118.0934,
|
||||
0.0000, 250.3240, 89.0468, 349.0000, 197.7571,
|
||||
0.0000, 66.1722, 58.5458, 99.4360, 84.9097,
|
||||
0.0000, 235.6473, 74.9285, 334.7577, 184.4795,
|
||||
0.0000, 0.0000, 100.1960, 287.8029, 209.0000,
|
||||
0.0000, 139.2078, 0.0000, 251.0686, 74.5911,
|
||||
0.0000, 188.9240, 160.3442, 309.3414, 209.0000,
|
||||
0.0000, 36.0059, 0.0000, 246.2041, 135.8893,
|
||||
0.0000, 119.4575, 49.2755, 329.5901, 187.4402,
|
||||
0.0000, 294.3554, 26.5605, 322.4006, 49.2127,
|
||||
0.0000, 0.0000, 0.0000, 135.3708, 89.2054,
|
||||
0.0000, 113.0944, 67.6953, 256.4803, 170.4842,
|
||||
0.0000, 13.2844, 0.0000, 41.5031, 23.5193,
|
||||
0.0000, 0.0000, 0.0000, 322.6976, 152.9498,
|
||||
};
|
||||
|
||||
size_t proposal_ref_size = sizeof(proposal_ref) / sizeof(proposal_ref[0]);
|
||||
|
||||
Reference in New Issue
Block a user