[IE CLDNN] Add common comparator for ngraph reference and cldnn detection output primitive (#6530)

- Add common comparator for sort function to be aligned with others(cldnn/mkldnn DO primitive)

Signed-off-by: Andrew Kwangwoong Park <andrew.kwangwoong.park@intel.com>

Apply clang-format to do ngraph reference code

Signed-off-by: Andrew Kwangwoong Park <andrew.kwangwoong.park@intel.com>

Apply code review

Signed-off-by: Andrew Kwangwoong Park <andrew.kwangwoong.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park 2021-07-13 20:52:45 +09:00 committed by GitHub
parent 01eebba54b
commit 2edb0e05cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 14 deletions

View File

@ -14,9 +14,15 @@ using namespace MKLDNNPlugin;
using namespace InferenceEngine;
template <typename T>
static bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second < pair2.second);
}
template <>
bool SortScorePairDescend<std::pair<int, int>>(const std::pair<float, std::pair<int, int>>& pair1,
const std::pair<float, std::pair<int, int>>& pair2) {
return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second.second < pair2.second.second);
}
bool MKLDNNDetectionOutputNode::isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept {

View File

@ -30,6 +30,18 @@ namespace {
using bounding_box = cldnn::cpu::bounding_box;
} // namespace
template <typename T>
bool comp_score_descend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second < pair2.second);
}
template <>
bool comp_score_descend<std::pair<int, int>>(const std::pair<float, std::pair<int, int>>& pair1,
const std::pair<float, std::pair<int, int>>& pair2) {
return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second.second < pair2.second.second);
}
/************************ Detection Output CPU ************************/
struct detection_output_cpu : typed_primitive_impl<detection_output> {
enum NMSType {CAFFE, MXNET};
@ -186,7 +198,9 @@ struct detection_output_cpu : typed_primitive_impl<detection_output> {
const bool share_location,
std::map<int, std::vector<int>>& indices,
std::vector<std::pair<float, std::pair<int, int>>>& scoreIndexPairs) {
std::sort(scoreIndexPairs.begin(), scoreIndexPairs.end(), comp_score_descend<std::pair<int, int>>);
std::sort(scoreIndexPairs.begin(),
scoreIndexPairs.end(),
comp_score_descend<std::pair<int, int>>);
if (top_k != -1)
if (scoreIndexPairs.size() > static_cast<size_t>(top_k))
@ -244,12 +258,6 @@ struct detection_output_cpu : typed_primitive_impl<detection_output> {
}
}
template <typename T>
static bool comp_score_descend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <typename dtype>
void generate_detections(stream& stream, const detection_output_inst& instance,
const int num_of_images,
@ -311,7 +319,9 @@ struct detection_output_cpu : typed_primitive_impl<detection_output> {
}
}
std::sort(score_index_pairs.begin(), score_index_pairs.end(), comp_score_descend<std::pair<int, int>>);
std::sort(score_index_pairs.begin(),
score_index_pairs.end(),
comp_score_descend<std::pair<int, int>>);
score_index_pairs.resize(args.keep_top_k);
std::vector<std::vector<std::pair<float, int>>> new_indices(args.num_classes);

View File

@ -381,7 +381,8 @@ namespace ngraph
static bool SortScorePairDescend(const std::pair<dataType, T>& pair1,
const std::pair<dataType, T>& pair2)
{
return pair1.first > pair2.first;
return (pair1.first > pair2.first) ||
(pair1.first == pair2.first && pair1.second < pair2.second);
}
void GetMaxScoreIndex(const std::vector<dataType>& scores,
@ -505,7 +506,12 @@ namespace ngraph
}
std::sort(scoreIndexPairs.begin(),
scoreIndexPairs.end(),
SortScorePairDescend<std::pair<int, int>>);
[](const std::pair<dataType, std::pair<int, int>>& p1,
const std::pair<dataType, std::pair<int, int>>& p2) {
return (p1.first > p2.first) ||
(p1.first == p2.first &&
p1.second.second < p2.second.second);
});
if (attrs.top_k != -1)
if (scoreIndexPairs.size() > static_cast<size_t>(attrs.top_k))
@ -651,7 +657,12 @@ namespace ngraph
}
std::sort(scoreIndexPairs.begin(),
scoreIndexPairs.end(),
SortScorePairDescend<std::pair<int, int>>);
[](const std::pair<dataType, std::pair<int, int>>& p1,
const std::pair<dataType, std::pair<int, int>>& p2) {
return (p1.first > p2.first) ||
(p1.first == p2.first &&
p1.second.second < p2.second.second);
});
scoreIndexPairs.resize(attrs.keep_top_k[0]);
std::map<int, std::vector<int>> newIndices;
for (size_t j = 0; j < scoreIndexPairs.size(); ++j)