From 2edb0e05cc08bf33ea3a12f5d931f6f5f062496f Mon Sep 17 00:00:00 2001 From: Andrew Kwangwoong Park Date: Tue, 13 Jul 2021 20:52:45 +0900 Subject: [PATCH] [IE CLDNN] Add common comparator for ngraph reference and cldnn detection output primitive (#6530) - Add common comparator for sort function to be aligned with others(cldnn/mkldnn DO primitive) Signed-off-by: Andrew Kwangwoong Park Apply clang-format to do ngraph reference code Signed-off-by: Andrew Kwangwoong Park Apply code review Signed-off-by: Andrew Kwangwoong Park --- .../nodes/mkldnn_detection_output_node.cpp | 12 ++++++--- .../clDNN/src/gpu/detection_output_cpu.cpp | 26 +++++++++++++------ .../runtime/reference/detection_output.hpp | 17 +++++++++--- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_detection_output_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_detection_output_node.cpp index 4b8c695a987..0a2f4fc8140 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_detection_output_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_detection_output_node.cpp @@ -14,9 +14,15 @@ using namespace MKLDNNPlugin; using namespace InferenceEngine; template -static bool SortScorePairDescend(const std::pair& pair1, - const std::pair& pair2) { - return pair1.first > pair2.first; +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second < pair2.second); +} + +template <> +bool SortScorePairDescend>(const std::pair>& pair1, + const std::pair>& pair2) { + return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second.second < pair2.second.second); } bool MKLDNNDetectionOutputNode::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { diff --git a/inference-engine/thirdparty/clDNN/src/gpu/detection_output_cpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/detection_output_cpu.cpp index 020ac340d48..efe7033fb9c 100644 --- a/inference-engine/thirdparty/clDNN/src/gpu/detection_output_cpu.cpp +++ b/inference-engine/thirdparty/clDNN/src/gpu/detection_output_cpu.cpp @@ -30,6 +30,18 @@ namespace { using bounding_box = cldnn::cpu::bounding_box; } // namespace +template +bool comp_score_descend(const std::pair& pair1, + const std::pair& pair2) { + return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second < pair2.second); +} + +template <> +bool comp_score_descend>(const std::pair>& pair1, + const std::pair>& pair2) { + return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second.second < pair2.second.second); +} + /************************ Detection Output CPU ************************/ struct detection_output_cpu : typed_primitive_impl { enum NMSType {CAFFE, MXNET}; @@ -186,7 +198,9 @@ struct detection_output_cpu : typed_primitive_impl { const bool share_location, std::map>& indices, std::vector>>& scoreIndexPairs) { - std::sort(scoreIndexPairs.begin(), scoreIndexPairs.end(), comp_score_descend>); + std::sort(scoreIndexPairs.begin(), + scoreIndexPairs.end(), + comp_score_descend>); if (top_k != -1) if (scoreIndexPairs.size() > static_cast(top_k)) @@ -244,12 +258,6 @@ struct detection_output_cpu : typed_primitive_impl { } } - template - static bool comp_score_descend(const std::pair& pair1, - const std::pair& pair2) { - return pair1.first > pair2.first; - } - template void generate_detections(stream& stream, const detection_output_inst& instance, const int num_of_images, @@ -311,7 +319,9 @@ struct detection_output_cpu : typed_primitive_impl { } } - std::sort(score_index_pairs.begin(), score_index_pairs.end(), comp_score_descend>); + std::sort(score_index_pairs.begin(), + score_index_pairs.end(), + comp_score_descend>); score_index_pairs.resize(args.keep_top_k); std::vector>> new_indices(args.num_classes); diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/detection_output.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/detection_output.hpp index 42a2d59d07b..bfc5ff1c584 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/detection_output.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/detection_output.hpp @@ -381,7 +381,8 @@ namespace ngraph static bool SortScorePairDescend(const std::pair& pair1, const std::pair& pair2) { - return pair1.first > pair2.first; + return (pair1.first > pair2.first) || + (pair1.first == pair2.first && pair1.second < pair2.second); } void GetMaxScoreIndex(const std::vector& scores, @@ -505,7 +506,12 @@ namespace ngraph } std::sort(scoreIndexPairs.begin(), scoreIndexPairs.end(), - SortScorePairDescend>); + [](const std::pair>& p1, + const std::pair>& p2) { + return (p1.first > p2.first) || + (p1.first == p2.first && + p1.second.second < p2.second.second); + }); if (attrs.top_k != -1) if (scoreIndexPairs.size() > static_cast(attrs.top_k)) @@ -651,7 +657,12 @@ namespace ngraph } std::sort(scoreIndexPairs.begin(), scoreIndexPairs.end(), - SortScorePairDescend>); + [](const std::pair>& p1, + const std::pair>& p2) { + return (p1.first > p2.first) || + (p1.first == p2.first && + p1.second.second < p2.second.second); + }); scoreIndexPairs.resize(attrs.keep_top_k[0]); std::map> newIndices; for (size_t j = 0; j < scoreIndexPairs.size(); ++j)