From 8a21e4e062bdc0818be67c6fe09ad2fe49d22b13 Mon Sep 17 00:00:00 2001 From: opoluektov-lohika Date: Mon, 27 Jun 2022 17:11:03 +0300 Subject: [PATCH] [GPU] Implement ExperimentalDetectronDetectionOutput operation (#11772) * ExperimentalDetectronDetectionOutput: refine sorting criteria for NMS stage This is to ensure the operation produces stable, predictable results across the possible sorting algorithm implementations. This property is useful for testing the operation. * [GPU] Implement ExperimentalDetectronDetectionOutput operation * [GPU] ExperimentalDetectronDetectionOutput: use vector types and operations in kernel * Reformat changed files to make clang format checker happy * [GPU] ExperimentalDetectronDetectionOutput: add another test case to the unit test * [GPU] ExperimentalDetectronDetectionOutput: Add f16 test * ExperimentalDetectronDetectionOutput: single-layer test: use all three outputs * [GPU] ExperimentalDetectronDetectionOutput: increase single layer test coverage More attribute permutations were added. 
--- ...xperimental_detectron_detection_output.cpp | 2 +- ...xperimental_detectron_detection_output.cpp | 8 +- .../intel_gpu/plugin/primitives_list.hpp | 5 +- ...xperimental_detectron_detection_output.hpp | 98 ++++++ ...xperimental_detectron_detection_output.cpp | 49 +++ ...xperimental_detectron_detection_output.cpp | 80 +++++ .../src/graph/impls/ocl/register.cpp | 4 +- .../src/graph/impls/ocl/register.hpp | 28 +- ...mental_detectron_detection_output_inst.hpp | 66 ++++ .../src/kernel_selector/common/common_types.h | 1 + ..._detectron_detection_output_kernel_ref.cpp | 199 +++++++++++ ...al_detectron_detection_output_kernel_ref.h | 63 ++++ ...ctron_detection_output_kernel_selector.cpp | 25 ++ ...tectron_detection_output_kernel_selector.h | 19 ++ ...rimental_detectron_detection_output_ref.cl | 319 +++++++++++++++++ ...xperimental_detectron_detection_output.cpp | 101 ++++++ ...al_detectron_detection_output_gpu_test.cpp | 322 ++++++++++++++++++ ...xperimental_detectron_detection_output.cpp | 66 ++++ ...xperimental_detectron_detection_output.cpp | 91 +++-- 19 files changed, 1491 insertions(+), 55 deletions(-) create mode 100644 src/plugins/intel_gpu/include/intel_gpu/primitives/experimental_detectron_detection_output.hpp create mode 100644 src/plugins/intel_gpu/src/graph/experimental_detectron_detection_output.cpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl/experimental_detectron_detection_output.cpp create mode 100644 src/plugins/intel_gpu/src/graph/include/experimental_detectron_detection_output_inst.hpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.cpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.h create mode 100644 src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.cpp create mode 100644 
src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.h create mode 100644 src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/experimental_detectron_detection_output_ref.cl create mode 100644 src/plugins/intel_gpu/src/plugin/ops/experimental_detectron_detection_output.cpp create mode 100644 src/plugins/intel_gpu/tests/test_cases/experimental_detectron_detection_output_gpu_test.cpp create mode 100644 src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/experimental_detectron_detection_output.cpp diff --git a/src/core/reference/src/runtime/reference/experimental_detectron_detection_output.cpp b/src/core/reference/src/runtime/reference/experimental_detectron_detection_output.cpp index 53e08ddb1b5..1edd6005082 100644 --- a/src/core/reference/src/runtime/reference/experimental_detectron_detection_output.cpp +++ b/src/core/reference/src/runtime/reference/experimental_detectron_detection_output.cpp @@ -203,7 +203,7 @@ void nms_cf(const float* conf_data, template bool SortScorePairDescend(const std::pair& pair1, const std::pair& pair2) { - return pair1.first > pair2.first; + return (pair1.first > pair2.first) || ((pair1.first == pair2.first) && (pair1.second.second < pair2.second.second)); } } // namespace diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp b/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp index bdf51692cd7..1d1505b0f63 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp @@ -118,9 +118,9 @@ static void refine_boxes(const float* boxes, } } -template -static bool SortScorePairDescend(const std::pair& pair1, const std::pair& pair2) { - return pair1.first > pair2.first; +static bool SortScorePairDescend(const std::pair>& pair1, + const std::pair>& pair2) { + return (pair1.first > 
pair2.first) || ((pair1.first == pair2.first) && (pair1.second.second < pair2.second.second)); } struct ConfidenceComparator { @@ -362,7 +362,7 @@ void ExperimentalDetectronDetectionOutput::execute(dnnl::stream strm) { std::partial_sort(conf_index_class_map.begin(), conf_index_class_map.begin() + max_detections_per_image_, conf_index_class_map.end(), - SortScorePairDescend>); + SortScorePairDescend); conf_index_class_map.resize(max_detections_per_image_); total_detections_num = max_detections_per_image_; } diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index be1e16ded34..f0510dabf89 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -3,7 +3,7 @@ // #ifndef REGISTER_FACTORY -#error "REGISTER_FACTORY is not defined" +# error "REGISTER_FACTORY is not defined" #endif // ------------------------------ Supported v0 ops ------------------------------ // @@ -190,7 +190,7 @@ REGISTER_FACTORY(v4, Swish); REGISTER_FACTORY(v5, HSigmoid); REGISTER_FACTORY(v5, LogSoftmax); REGISTER_FACTORY(v5, LSTMSequence); -//REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion +// REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion REGISTER_FACTORY(v5, Round); REGISTER_FACTORY(v5, GatherND); REGISTER_FACTORY(v5, Loop); @@ -208,6 +208,7 @@ REGISTER_FACTORY(v6, ExperimentalDetectronPriorGridGenerator); REGISTER_FACTORY(v6, ExperimentalDetectronROIFeatureExtractor); REGISTER_FACTORY(v6, ExperimentalDetectronTopKROIs) REGISTER_FACTORY(v6, ExperimentalDetectronGenerateProposalsSingleImage); +REGISTER_FACTORY(v6, ExperimentalDetectronDetectionOutput); // ------------------------------ Supported v7 ops ------------------------------ // REGISTER_FACTORY(v7, DFT); diff --git 
a/src/plugins/intel_gpu/include/intel_gpu/primitives/experimental_detectron_detection_output.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/experimental_detectron_detection_output.hpp new file mode 100644 index 00000000000..a70962be5e4 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/experimental_detectron_detection_output.hpp @@ -0,0 +1,98 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once +#include +#include + +#include "primitive.hpp" + +namespace cldnn { +/// @addtogroup cpp_api C++ API +/// @{ +/// @addtogroup cpp_topology Network Topology +/// @{ +/// @addtogroup cpp_primitives Primitives +/// @{ + +/// @brief Primitive describing the ExperimentalDetectronDetectionOutput operation +struct experimental_detectron_detection_output : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(experimental_detectron_detection_output) + + /// @brief Constructs experimental_detectron_detection_output primitive + /// @param id This primitive id + /// @param input_rois input rois + /// @param input_deltas input deltas + /// @param input_scores input scores + /// @param input_im_info image info + /// @param output_classes id of the primitive providing the buffer for the output classes + /// @param output_scores id of the primitive providing the buffer for the output scores + /// @param score_threshold a threshold to consider only detections whose scores are larger than the threshold + /// @param nms_threshold a threshold to be used in the NMS stage + /// @param num_classes the number of detected classes + /// @param post_nms_count the maximum number of detections per class + /// @param max_detections_per_image the maximum number of detections per image + /// @param class_agnostic_box_regression specifies whether to delete background classes or not + /// @param max_delta_log_wh the maximum delta of logarithms for width and height + /// @param deltas_weights the weights for the box deltas, in the order x, y, log(w), log(h) + 
experimental_detectron_detection_output(const primitive_id& id, + const primitive_id& input_rois, + const primitive_id& input_deltas, + const primitive_id& input_scores, + const primitive_id& input_im_info, + const primitive_id& output_classes, + const primitive_id& output_scores, + float score_threshold, + float nms_threshold, + int num_classes, + int post_nms_count, + int max_detections_per_image, + bool class_agnostic_box_regression, + float max_delta_log_wh, + std::vector deltas_weights, + const primitive_id& ext_prim_id = "", + const padding& output_padding = {}) + : primitive_base{id, + {input_rois, input_deltas, input_scores, input_im_info, output_classes, output_scores}, + ext_prim_id, + output_padding}, + output_classes{output_classes}, + output_scores{output_scores}, + score_threshold{score_threshold}, + nms_threshold{nms_threshold}, + num_classes{num_classes}, + post_nms_count{post_nms_count}, + max_detections_per_image{max_detections_per_image}, + class_agnostic_box_regression{class_agnostic_box_regression}, + max_delta_log_wh{max_delta_log_wh}, + deltas_weights{std::move(deltas_weights)} {} + + primitive_id output_classes; + primitive_id output_scores; + float score_threshold; + float nms_threshold; + int num_classes; + int post_nms_count; + int max_detections_per_image; + bool class_agnostic_box_regression; + float max_delta_log_wh; + std::vector deltas_weights; + +protected: + std::vector> get_dependencies() const override { + std::vector> ret; + if (!output_classes.empty()) + ret.emplace_back(output_classes); + + if (!output_scores.empty()) + ret.emplace_back(output_scores); + + return ret; + } +}; +/// @} +/// @} +/// @} +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/experimental_detectron_detection_output.cpp b/src/plugins/intel_gpu/src/graph/experimental_detectron_detection_output.cpp new file mode 100644 index 00000000000..26fa53be9cd --- /dev/null +++ 
b/src/plugins/intel_gpu/src/graph/experimental_detectron_detection_output.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "experimental_detectron_detection_output_inst.hpp" +#include "intel_gpu/runtime/error_handler.hpp" +#include "json_object.h" +#include "primitive_type_base.h" + +namespace cldnn { +primitive_type_id experimental_detectron_detection_output::type_id() { + static primitive_type_base instance; + return &instance; +} + +layout experimental_detectron_detection_output_inst::calc_output_layout( + const experimental_detectron_detection_output_node& node) { + const layout data_layout = node.input().get_output_layout(); + auto desc = node.get_primitive(); + + return layout(data_layout.data_type, format::bfyx, {static_cast(desc->max_detections_per_image), 4, 1, 1}); +} + +std::string experimental_detectron_detection_output_inst::to_string( + const experimental_detectron_detection_output_node& node) { + auto desc = node.get_primitive(); + + std::stringstream primitive_description; + + // Each primitive attribute is recorded exactly once under its own key. + json_composite ed_info; + ed_info.add("score_threshold", desc->score_threshold); + ed_info.add("nms_threshold", desc->nms_threshold); + ed_info.add("max_delta_log_wh", desc->max_delta_log_wh); + ed_info.add("num_classes", desc->num_classes); + ed_info.add("post_nms_count", desc->post_nms_count); + ed_info.add("max_detections_per_image", desc->max_detections_per_image); + ed_info.add("class_agnostic_box_regression", desc->class_agnostic_box_regression); + ed_info.add("deltas_weights", desc->deltas_weights); + + auto node_info = node.desc_to_json(); + node_info->add("experimental_detectron_detection_output_info", ed_info); + node_info->dump(primitive_description); + + return primitive_description.str(); +} +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/experimental_detectron_detection_output.cpp 
b/src/plugins/intel_gpu/src/graph/impls/ocl/experimental_detectron_detection_output.cpp new file mode 100644 index 00000000000..57d57b47ef3 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/experimental_detectron_detection_output.cpp @@ -0,0 +1,80 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "eddo/experimental_detectron_detection_output_kernel_ref.h" +#include "eddo/experimental_detectron_detection_output_kernel_selector.h" +#include "experimental_detectron_detection_output_inst.hpp" +#include "impls/implementation_map.hpp" +#include "kernel_selector_helper.h" +#include "primitive_base.hpp" + +namespace cldnn { +namespace ocl { +struct experimental_detectron_detection_output_impl + : public typed_primitive_impl_ocl { + using parent = typed_primitive_impl_ocl; + using parent::parent; + + std::unique_ptr clone() const override { + return make_unique(*this); + } + +protected: + kernel_arguments_data get_arguments(typed_primitive_inst& instance, + int32_t unused) const override { + kernel_arguments_data args = parent::get_arguments(instance, unused); + args.inputs.push_back(instance.output_classes_memory()); + args.inputs.push_back(instance.output_scores_memory()); + + return args; + } + +public: + static primitive_impl* create(const experimental_detectron_detection_output_node& arg) { + auto params = get_default_params(arg); + auto optional_params = + get_default_optional_params( + arg.get_program()); + + const auto& primitive = arg.get_primitive(); + + params.score_threshold = primitive->score_threshold; + params.nms_threshold = primitive->nms_threshold; + params.max_delta_log_wh = primitive->max_delta_log_wh; + params.num_classes = primitive->num_classes; + params.post_nms_count = primitive->post_nms_count; + params.max_detections_per_image = primitive->max_detections_per_image; + params.class_agnostic_box_regression = primitive->class_agnostic_box_regression; + params.deltas_weights = 
primitive->deltas_weights; + + params.inputs.push_back(convert_data_tensor(arg.deltas().get_output_layout())); + params.inputs.push_back(convert_data_tensor(arg.scores().get_output_layout())); + params.inputs.push_back(convert_data_tensor(arg.image_size_info().get_output_layout())); + + params.inputs.push_back(convert_data_tensor(arg.output_classes_node().get_output_layout())); + params.inputs.push_back(convert_data_tensor(arg.output_scores_node().get_output_layout())); + + const auto& kernel_selector = + kernel_selector::experimental_detectron_detection_output_kernel_selector::Instance(); + const auto best_kernels = kernel_selector.GetBestKernels(params, optional_params); + + CLDNN_ERROR_BOOL(arg.id(), + "best_kernels.empty()", + best_kernels.empty(), + "Cannot find a proper kernel with this arguments"); + + return new experimental_detectron_detection_output_impl(arg, best_kernels[0]); + } +}; + +namespace detail { +attach_experimental_detectron_detection_output_impl::attach_experimental_detectron_detection_output_impl() { + implementation_map::add( + impl_types::ocl, + experimental_detectron_detection_output_impl::create, + {std::make_tuple(data_types::f16, format::bfyx), std::make_tuple(data_types::f32, format::bfyx)}); +} +} // namespace detail +} // namespace ocl +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp index 6b38d312339..9dd25effb85 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp @@ -8,8 +8,7 @@ namespace cldnn { namespace ocl { -#define REGISTER_OCL(prim) \ - static detail::attach_##prim##_impl attach_##prim +#define REGISTER_OCL(prim) static detail::attach_##prim##_impl attach_##prim void register_implementations() { REGISTER_OCL(activation); @@ -31,6 +30,7 @@ void register_implementations() { REGISTER_OCL(detection_output); REGISTER_OCL(dft); REGISTER_OCL(batch_to_space); 
+ REGISTER_OCL(experimental_detectron_detection_output); REGISTER_OCL(experimental_detectron_generate_proposals_single_image); REGISTER_OCL(experimental_detectron_prior_grid_generator); REGISTER_OCL(experimental_detectron_roi_feature_extractor); diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp index 712f83f028c..cad3e500f13 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp @@ -5,6 +5,7 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// #pragma once +#include "generic_layer.hpp" #include "intel_gpu/primitives/activation.hpp" #include "intel_gpu/primitives/arg_max_min.hpp" #include "intel_gpu/primitives/average_unpooling.hpp" @@ -14,8 +15,10 @@ #include "intel_gpu/primitives/broadcast.hpp" #include "intel_gpu/primitives/bucketize.hpp" #include "intel_gpu/primitives/concatenation.hpp" +#include "intel_gpu/primitives/convert_color.hpp" #include "intel_gpu/primitives/convolution.hpp" #include "intel_gpu/primitives/crop.hpp" +#include "intel_gpu/primitives/ctc_greedy_decoder.hpp" #include "intel_gpu/primitives/custom_gpu_primitive.hpp" #include "intel_gpu/primitives/deconvolution.hpp" #include "intel_gpu/primitives/depth_to_space.hpp" @@ -26,12 +29,16 @@ #include "intel_gpu/primitives/experimental_detectron_topk_rois.hpp" #include "intel_gpu/primitives/fully_connected.hpp" #include "intel_gpu/primitives/gather.hpp" -#include "intel_gpu/primitives/gather_nd.hpp" #include "intel_gpu/primitives/gather_elements.hpp" +#include "intel_gpu/primitives/gather_nd.hpp" +#include "intel_gpu/primitives/gather_tree.hpp" #include "intel_gpu/primitives/gemm.hpp" +#include "intel_gpu/primitives/grn.hpp" #include "intel_gpu/primitives/lrn.hpp" #include "intel_gpu/primitives/lstm.hpp" #include "intel_gpu/primitives/lstm_dynamic.hpp" +#include 
"intel_gpu/primitives/lstm_dynamic_input.hpp" +#include "intel_gpu/primitives/lstm_dynamic_timeloop.hpp" #include "intel_gpu/primitives/max_unpooling.hpp" #include "intel_gpu/primitives/mutable_data.hpp" #include "intel_gpu/primitives/mvn.hpp" @@ -48,15 +55,16 @@ #include "intel_gpu/primitives/region_yolo.hpp" #include "intel_gpu/primitives/reorder.hpp" #include "intel_gpu/primitives/reorg_yolo.hpp" +#include "intel_gpu/primitives/resample.hpp" #include "intel_gpu/primitives/reshape.hpp" #include "intel_gpu/primitives/reverse_sequence.hpp" #include "intel_gpu/primitives/roi_align.hpp" #include "intel_gpu/primitives/roi_pooling.hpp" #include "intel_gpu/primitives/roll.hpp" #include "intel_gpu/primitives/scale.hpp" -#include "intel_gpu/primitives/scatter_update.hpp" #include "intel_gpu/primitives/scatter_elements_update.hpp" #include "intel_gpu/primitives/scatter_nd_update.hpp" +#include "intel_gpu/primitives/scatter_update.hpp" #include "intel_gpu/primitives/select.hpp" #include "intel_gpu/primitives/shape_of.hpp" #include "intel_gpu/primitives/shuffle_channels.hpp" @@ -65,15 +73,6 @@ #include "intel_gpu/primitives/space_to_batch.hpp" #include "intel_gpu/primitives/strided_slice.hpp" #include "intel_gpu/primitives/tile.hpp" -#include "intel_gpu/primitives/resample.hpp" -#include "intel_gpu/primitives/gather_tree.hpp" -#include "intel_gpu/primitives/lstm_dynamic_input.hpp" -#include "intel_gpu/primitives/lstm_dynamic_timeloop.hpp" -#include "intel_gpu/primitives/grn.hpp" -#include "intel_gpu/primitives/ctc_greedy_decoder.hpp" -#include "intel_gpu/primitives/convert_color.hpp" -#include "generic_layer.hpp" - namespace cldnn { namespace ocl { @@ -81,9 +80,9 @@ void register_implementations(); namespace detail { -#define REGISTER_OCL(prim) \ - struct attach_##prim##_impl { \ - attach_##prim##_impl(); \ +#define REGISTER_OCL(prim) \ + struct attach_##prim##_impl { \ + attach_##prim##_impl(); \ } REGISTER_OCL(activation); @@ -106,6 +105,7 @@ 
REGISTER_OCL(deformable_interp); REGISTER_OCL(depth_to_space); REGISTER_OCL(detection_output); REGISTER_OCL(dft); +REGISTER_OCL(experimental_detectron_detection_output); REGISTER_OCL(experimental_detectron_generate_proposals_single_image); REGISTER_OCL(experimental_detectron_prior_grid_generator); REGISTER_OCL(experimental_detectron_roi_feature_extractor); diff --git a/src/plugins/intel_gpu/src/graph/include/experimental_detectron_detection_output_inst.hpp b/src/plugins/intel_gpu/src/graph/include/experimental_detectron_detection_output_inst.hpp new file mode 100644 index 00000000000..2a51e04d35d --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/experimental_detectron_detection_output_inst.hpp @@ -0,0 +1,66 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#pragma once +#include "intel_gpu/primitives/experimental_detectron_detection_output.hpp" +#include "primitive_inst.h" + +namespace cldnn { +template <> +struct typed_program_node + : public typed_program_node_base { + using parent = typed_program_node_base; + +public: + using parent::parent; + + program_node& input() const { + return get_dependency(0); + } + + program_node& deltas() const { + return get_dependency(1); + } + program_node& scores() const { + return get_dependency(2); + } + program_node& image_size_info() const { + return get_dependency(3); + } + + program_node& output_classes_node() const { + return get_dependency(4); + } + program_node& output_scores_node() const { + return get_dependency(5); + } +}; + +using experimental_detectron_detection_output_node = typed_program_node; + +template <> +class typed_primitive_inst + : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + +public: + static layout calc_output_layout(const experimental_detectron_detection_output_node& node); + static std::string to_string(const 
experimental_detectron_detection_output_node& node); + + typed_primitive_inst(network& network, const experimental_detectron_detection_output_node& node) + : parent(network, node) {} + + memory::ptr output_classes_memory() const { + return dep_memory_ptr(4); + } + memory::ptr output_scores_memory() const { + return dep_memory_ptr(5); + } +}; + +using experimental_detectron_detection_output_inst = typed_primitive_inst; + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/kernel_selector/common/common_types.h b/src/plugins/intel_gpu/src/kernel_selector/common/common_types.h index 663f18d6c61..3063788e08e 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/common/common_types.h +++ b/src/plugins/intel_gpu/src/kernel_selector/common/common_types.h @@ -79,6 +79,7 @@ enum class KernelType { LOOP, NON_MAX_SUPPRESSION, DETECTION_OUTPUT, + EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT, EXPERIMENTAL_DETECTRON_GENERATE_PROPOSALS_SINGLE_IMAGE, EXPERIMENTAL_DETECTRON_PRIOR_GRID_GENERATOR, EXPERIMENTAL_DETECTRON_ROI_FEATURE_EXTRACTOR, diff --git a/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.cpp new file mode 100644 index 00000000000..9acd536e202 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.cpp @@ -0,0 +1,199 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "experimental_detectron_detection_output_kernel_ref.h" + +#include +#include + +#include "kernel_selector_utils.h" + +namespace kernel_selector { + +ParamsKey ExperimentalDetectronDetectionOutputKernelRef::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT32); + 
k.EnableInputDataType(Datatype::INT64); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::INT64); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableBatching(); + k.EnableDifferentTypes(); + return k; +} + +KernelsPriority ExperimentalDetectronDetectionOutputKernelRef::GetKernelsPriority(const Params&, + const optional_params&) const { + return DONT_USE_IF_HAVE_SOMETHING_ELSE; +} + +bool ExperimentalDetectronDetectionOutputKernelRef::Validate(const Params& p, const optional_params& o) const { + if (p.GetType() != KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT || + o.GetType() != KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) { + return false; + } + return true; +} + +constexpr int kBoxesInputIdx = 0; +constexpr int kDeltasInputIdx = 1; +constexpr int kScoresInputIdx = 2; +constexpr int kImInfoInputIdx = 3; +constexpr int kOutputClassesInputIdx = 4; +constexpr int kOutputScoresInputIdx = 5; + +constexpr int kRefinedBoxesBufferIdx = 0; +constexpr int kRefinedBoxAreasBufferIdx = 1; +constexpr int kRefinedScoresBufferIdx = 2; +constexpr int kScoreClassIndexBufferIdx = 3; +constexpr int kDetectionCountBufferIdx = 4; + +constexpr int kBufferCount = 5; + +constexpr int kOutputIdx = 0; + +JitConstants ExperimentalDetectronDetectionOutputKernelRef::GetJitConstants( + const experimental_detectron_detection_output_params& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + jit.AddConstants({ + MakeJitConstant("SCORE_THRESHOLD", params.score_threshold), + MakeJitConstant("NMS_THRESHOLD", params.nms_threshold), + MakeJitConstant("NUM_CLASSES", params.num_classes), + MakeJitConstant("POST_NMS_COUNT", params.post_nms_count), + MakeJitConstant("MAX_DETECTIONS_PER_IMAGE", params.max_detections_per_image), + MakeJitConstant("MAX_DELTA_LOG_WH", params.max_delta_log_wh), + 
MakeJitConstant("DELTA_WEIGHT_X", params.deltas_weights[0]), + MakeJitConstant("DELTA_WEIGHT_Y", params.deltas_weights[1]), + MakeJitConstant("DELTA_WEIGHT_LOG_W", params.deltas_weights[2]), + MakeJitConstant("DELTA_WEIGHT_LOG_H", params.deltas_weights[3]), + + MakeJitConstant("ROI_COUNT", params.inputs[kScoresInputIdx].Batch().v), + + MakeJitConstant("OUTPUT_INDICES_TYPE", "INPUT4_TYPE"), + }); + + return jit; +} + +using DispatchData = CommonDispatchData; + +void ExperimentalDetectronDetectionOutputKernelRef::PrepareKernelCommon( + const experimental_detectron_detection_output_params& params, + const optional_params& options, + std::vector gws, + const std::string& stage_name, + size_t stage_index, + clKernelData& kernel) const { + DispatchData dispatch_data; + dispatch_data.gws = std::move(gws); + dispatch_data.lws = GetOptimalLocalWorkGroupSizes(dispatch_data.gws, params.engineInfo); + + const auto entry_point = GetEntryPoint(kernelName, params.layerID, params, options, stage_index); + auto cldnn_jit = GetJitConstants(params); + cldnn_jit.AddConstant(MakeJitConstant(stage_name, "true")); + + const auto jit = CreateJit(kernelName, cldnn_jit, entry_point); + KernelBase::CheckDispatchData(kernelName, dispatch_data, params.engineInfo.maxWorkGroupSize); + kernel.params.workGroups.global = dispatch_data.gws; + kernel.params.workGroups.local = dispatch_data.lws; + kernel.code.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo); +} + +void ExperimentalDetectronDetectionOutputKernelRef::PrepareRefineBoxesKernel( + const experimental_detectron_detection_output_params& params, + const optional_params& options, + clKernelData& kernel) const { + const size_t roi_count = params.inputs[kScoresInputIdx].Batch().v; + const size_t class_count = params.num_classes; + + PrepareKernelCommon(params, options, {roi_count, class_count, 1}, "EDDO_STAGE_0_REFINE_BOXES", 0, kernel); + + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 
kBoxesInputIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kDeltasInputIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kScoresInputIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kImInfoInputIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxAreasBufferIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedScoresBufferIdx}); +} + +void ExperimentalDetectronDetectionOutputKernelRef::PrepareNmsClassWiseKernel( + const experimental_detectron_detection_output_params& params, + const optional_params& options, + clKernelData& kernel) const { + PrepareKernelCommon(params, options, {1, 1, 1}, "EDDO_STAGE_1_NMS", 1, kernel); + + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedScoresBufferIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxAreasBufferIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx}); +} + +void ExperimentalDetectronDetectionOutputKernelRef::PrepareTopKDetectionsKernel( + const experimental_detectron_detection_output_params& params, + const optional_params& options, + clKernelData& kernel) const { + PrepareKernelCommon(params, options, {1, 1, 1}, "EDDO_STAGE_2_TOPK", 2, kernel); + + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx}); +} + +void 
ExperimentalDetectronDetectionOutputKernelRef::PrepareCopyOutputKernel(
+    const experimental_detectron_detection_output_params& params,
+    const optional_params& options,
+    clKernelData& kernel) const {
+    PrepareKernelCommon(params,
+                        options,
+                        {static_cast<size_t>(params.max_detections_per_image), 1, 1},  // one work item per output row
+                        "EDDO_STAGE_3_COPY_OUTPUT",
+                        3,
+                        kernel);
+
+    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx});
+    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx});
+    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx});
+    kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, kOutputIdx});
+    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kOutputClassesInputIdx});
+    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kOutputScoresInputIdx});
+}
+
+KernelsData ExperimentalDetectronDetectionOutputKernelRef::GetKernelsData(const Params& params,
+                                                                          const optional_params& options) const {
+    if (!Validate(params, options)) {
+        return {};
+    }
+
+    constexpr size_t kKernelCount = 4;  // refine boxes, class-wise NMS, top-K, copy output
+    KernelData kd = KernelData::Default<experimental_detectron_detection_output_params>(params, kKernelCount);
+    const auto& eddo_params = static_cast<const experimental_detectron_detection_output_params&>(params);
+
+    const auto roi_count = eddo_params.inputs[kScoresInputIdx].Batch().v;
+    const auto class_count = static_cast<size_t>(eddo_params.num_classes);
+
+    kd.internalBufferDataType = Datatype::F32;
+
+    kd.internalBufferSizes.resize(kBufferCount);
+    kd.internalBufferSizes[kRefinedBoxesBufferIdx] = class_count * roi_count * 4 * sizeof(float);  // 4 coordinates per box
+    kd.internalBufferSizes[kRefinedBoxAreasBufferIdx] = class_count * roi_count * sizeof(float);
+    kd.internalBufferSizes[kRefinedScoresBufferIdx] = class_count * roi_count * sizeof(float);
+    kd.internalBufferSizes[kScoreClassIndexBufferIdx] = class_count * roi_count * 12;  // sizeof ScoreClassIndex
+    kd.internalBufferSizes[kDetectionCountBufferIdx] = sizeof(uint32_t);  // single counter shared by the stages
+
+    PrepareRefineBoxesKernel(eddo_params, options, kd.kernels[0]);
+    PrepareNmsClassWiseKernel(eddo_params, options, kd.kernels[1]);
+    PrepareTopKDetectionsKernel(eddo_params, options, kd.kernels[2]);
+    PrepareCopyOutputKernel(eddo_params, options, kd.kernels[3]);
+
+    return {kd};
+}
+
+}  // namespace kernel_selector
diff --git a/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.h b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.h
new file mode 100644
index 00000000000..2ce6f0b0449
--- /dev/null
+++ b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_ref.h
@@ -0,0 +1,68 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "kernel_base_opencl.h"
+
+namespace kernel_selector {
+// Attributes of the ExperimentalDetectronDetectionOutput operation handed to the kernel.
+// Scalars carry default initializers so a default-constructed object never holds garbage.
+struct experimental_detectron_detection_output_params : public base_params {
+    experimental_detectron_detection_output_params()
+        : base_params(KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {}
+
+    float score_threshold = 0.0f;
+    float nms_threshold = 0.0f;
+    float max_delta_log_wh = 0.0f;
+    int num_classes = 0;
+    int post_nms_count = 0;
+    int max_detections_per_image = 0;
+    bool class_agnostic_box_regression = false;
+    std::vector<float> deltas_weights;
+};
+
+struct experimental_detectron_detection_output_optional_params : public optional_params {
+    experimental_detectron_detection_output_optional_params()
+        : optional_params(KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {}
+};
+
+// Reference implementation: GetKernelsData() describes four kernels, one per stage
+// (refine boxes, class-wise NMS, top-K selection, copy to outputs).
+class ExperimentalDetectronDetectionOutputKernelRef : public KernelBaseOpenCL {
+public:
+    ExperimentalDetectronDetectionOutputKernelRef() : KernelBaseOpenCL("experimental_detectron_detection_output_ref") {}
+
+    ~ExperimentalDetectronDetectionOutputKernelRef() = default;
+
+protected:
+    bool Validate(const Params& p, const optional_params& o) const override;
+
+    KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
+    KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
+    ParamsKey GetSupportedKey() const override;
+
+private:
+    JitConstants GetJitConstants(const experimental_detectron_detection_output_params& params) const;
+    // Set-up shared by the stage kernels; `stage_name` and `stage_index` identify the stage.
+    void PrepareKernelCommon(const experimental_detectron_detection_output_params& params,
+                             const optional_params& options,
+                             std::vector<size_t> gws,
+                             const std::string& stage_name,
+                             size_t stage_index,
+                             clKernelData& kernel) const;
+    void PrepareRefineBoxesKernel(const experimental_detectron_detection_output_params&,
+                                  const optional_params&,
+                                  clKernelData&) const;
+    void PrepareNmsClassWiseKernel(const experimental_detectron_detection_output_params&,
+                                   const optional_params&,
+                                   clKernelData&) const;
+    void PrepareTopKDetectionsKernel(const experimental_detectron_detection_output_params&,
+                                     const optional_params&,
+                                     clKernelData&) const;
+    void PrepareCopyOutputKernel(const experimental_detectron_detection_output_params&,
+                                 const optional_params&,
+                                 clKernelData&) const;
+};
+}  // namespace kernel_selector
diff --git a/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.cpp b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.cpp
new file mode 100644
index 00000000000..01cd2fa498b
--- /dev/null
+++ b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.cpp
@@ -0,0 +1,28 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "experimental_detectron_detection_output_kernel_selector.h"
+
+#include "experimental_detectron_detection_output_kernel_ref.h"
+
+namespace kernel_selector {
+// Registers the reference implementation with the selector.
+experimental_detectron_detection_output_kernel_selector::experimental_detectron_detection_output_kernel_selector() {
+    Attach<ExperimentalDetectronDetectionOutputKernelRef>();
+}
+
+// Returns the selector instance, a function-local static created on first use.
+experimental_detectron_detection_output_kernel_selector&
+experimental_detectron_detection_output_kernel_selector::Instance() {
+    static experimental_detectron_detection_output_kernel_selector instance_;
+    return instance_;
+}
+
+// Forwards to GetNaiveBestKernel() for the EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT kernel type.
+KernelsData experimental_detectron_detection_output_kernel_selector::GetBestKernels(
+    const Params& params,
+    const optional_params& options) const {
+    return GetNaiveBestKernel(params, options, KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT);
+}
+}  // namespace kernel_selector
diff --git a/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.h b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.h
new file mode 100644
index 00000000000..811b7efb202
--- /dev/null
+++ b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/eddo/experimental_detectron_detection_output_kernel_selector.h
@@ -0,0 +1,20 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "kernel_selector.h"
+
+namespace kernel_selector {
+// Kernel selector for ExperimentalDetectronDetectionOutput; obtained through Instance().
+class experimental_detectron_detection_output_kernel_selector : public kernel_selector_base {
+public:
+    static experimental_detectron_detection_output_kernel_selector& Instance();
+
+    experimental_detectron_detection_output_kernel_selector();
+
+    KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
+};
+
+}  // namespace kernel_selector
diff --git a/src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/experimental_detectron_detection_output_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/experimental_detectron_detection_output_ref.cl
new file mode 100644
index 00000000000..a16a4c3d68f
--- /dev/null
+++ 
b/src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/experimental_detectron_detection_output_ref.cl
@@ -0,0 +1,359 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "include/batch_headers/data_types.cl"
+
+#define INPUT_TYPE INPUT0_TYPE
+#define INPUT_TYPE2 MAKE_VECTOR_TYPE(INPUT0_TYPE, 2)
+#define INPUT_TYPE4 MAKE_VECTOR_TYPE(INPUT0_TYPE, 4)
+
+#if INPUT0_TYPE_SIZE == 2 // f16
+#    define HALF_ONE (INPUT_TYPE2)(0.5h)
+#    define ZERO 0.0h
+#    define ONE 1.0h
+#else
+#    define HALF_ONE (INPUT_TYPE2)(0.5f)
+#    define ZERO 0.0f
+#    define ONE 1.0f
+#endif
+
+#define ZERO2 (INPUT_TYPE2)(ZERO)
+#define ZERO4 (INPUT_TYPE4)(ZERO)
+
+#define COORDINATE_OFFSET (INPUT_TYPE2)(ONE)
+
+#define DELTA_WEIGHTS (INPUT_TYPE4)(DELTA_WEIGHT_X, DELTA_WEIGHT_Y, DELTA_WEIGHT_LOG_W, DELTA_WEIGHT_LOG_H)
+
+#define MAX_DELTA_LOG_SIZE (INPUT_TYPE2)(TO_INPUT0_TYPE(MAX_DELTA_LOG_WH))
+
+// One detection candidate: its score, the class it belongs to and the index of
+// its box within that class.
+typedef struct __attribute__((packed)) {
+    INPUT_TYPE score __attribute__((aligned(4)));
+    uint class_idx;
+    uint box_idx;
+} FUNC(SCI);
+
+#define ScoreClassIndex FUNC(SCI)
+
+// Exchanges two candidates in place.
+inline void FUNC(swap_info)(__global ScoreClassIndex* a, __global ScoreClassIndex* b) {
+    const ScoreClassIndex temp = *a;
+    *a = *b;
+    *b = temp;
+}
+
+// Sort order shared by partition() and bubbleSortIterative(): a candidate goes first
+// when its score is higher, or when the scores are equal and its box index is lower.
+inline bool FUNC(goes_before)(INPUT_TYPE score_a, uint box_idx_a, INPUT_TYPE score_b, uint box_idx_b) {
+    return (score_a > score_b) || (score_a == score_b && box_idx_a < box_idx_b);
+}
+
+// Lomuto partition of arr[l..h] (inclusive) around arr[h]: candidates that go before
+// the pivot end up on its left. Returns the final index of the pivot.
+inline int FUNC(partition)(__global ScoreClassIndex* arr, int l, int h) {
+    const INPUT_TYPE pivot_score = arr[h].score;
+    const uint pivot_box_idx = arr[h].box_idx;
+    int i = (l - 1);
+    for (int j = l; j <= h - 1; j++) {
+        if (FUNC_CALL(goes_before)(arr[j].score, arr[j].box_idx, pivot_score, pivot_box_idx)) {
+            i++;
+            FUNC_CALL(swap_info)(&arr[i], &arr[j]);
+        }
+    }
+    FUNC_CALL(swap_info)(&arr[i + 1], &arr[h]);
+    return (i + 1);
+}
+
+// Sorts arr[l..h] (inclusive) with bubble sort. quickSortIterative() falls back to it
+// when its explicit stack is full, so both must produce the same order.
+inline void FUNC(bubbleSortIterative)(__global ScoreClassIndex* arr, int l, int h) {
+    for (int i = 0; i < h - l; i++) {
+        bool swapped = false;
+        for (int j = l; j < h - i; j++) {
+            // Swap only when the right neighbour has to precede the left one.
+            if (FUNC_CALL(goes_before)(arr[j + 1].score, arr[j + 1].box_idx, arr[j].score, arr[j].box_idx)) {
+                FUNC_CALL(swap_info)(&arr[j], &arr[j + 1]);
+                swapped = true;
+            }
+        }
+
+        if (!swapped)
+            break;
+    }
+}
+
+// Sorts arr[l..h] (inclusive) with a non-recursive quicksort.
+// Ranges holding fewer than two elements (including h < l) are left untouched.
+inline void FUNC(quickSortIterative)(__global ScoreClassIndex* arr, int l, int h) {
+    // Nothing to sort; partition() would otherwise read and write arr[h] with h < l.
+    if (l >= h) {
+        return;
+    }
+
+    // Create an auxiliary stack
+    const int kStackSize = 100;
+    int stack[kStackSize];
+
+    // initialize top of stack
+    int top = -1;
+
+    // push initial values of l and h to stack
+    stack[++top] = l;
+    stack[++top] = h;
+
+    // Keep popping from stack while is not empty
+    while (top >= 0) {
+        // Pop h and l
+        h = stack[top--];
+        l = stack[top--];
+
+        // Set pivot element at its correct position
+        // in sorted array
+        int p = FUNC_CALL(partition)(arr, l, h);
+
+        // If there are elements on left side of pivot,
+        // then push left side to stack
+        if (p - 1 > l) {
+            if (top >= (kStackSize - 1)) {
+                FUNC_CALL(bubbleSortIterative)(arr, l, p - 1);
+            } else {
+                stack[++top] = l;
+                stack[++top] = p - 1;
+            }
+        }
+
+        // If there are elements on right side of pivot,
+        // then push right side to stack
+        if (p + 1 < h) {
+            if (top >= (kStackSize - 1)) {
+                FUNC_CALL(bubbleSortIterative)(arr, p + 1, h);
+            } else {
+                stack[++top] = p + 1;
+                stack[++top] = h;
+            }
+        }
+    }
+}
+
+// FIXME: rename stages accordingly
+#ifdef EDDO_STAGE_0_REFINE_BOXES
+
+// 0. Refine boxes
+// Work item (roi, class) applies the deltas stored for that pair to the ROI box, clamps
+// the result to the image and stores box, area and score at [class][roi].
+// ROIs with non-positive width or height only get their score zeroed.
+KERNEL(eddo_ref_stage_0)
+(const __global INPUT_TYPE* boxes,
+ const __global INPUT_TYPE* deltas,
+ const __global INPUT_TYPE* scores,
+ const __global INPUT_TYPE* im_info,
+ __global INPUT_TYPE* refined_boxes,
+ __global INPUT_TYPE* refined_box_areas,
+ __global INPUT_TYPE* refined_scores) {
+    const size_t roi_count = get_global_size(0);
+    size_t roi_idx = get_global_id(0);
+    size_t class_idx = get_global_id(1);
+
+    INPUT_TYPE4 box = vload4(roi_idx, boxes);
+
+    if (any(islessequal(box.hi - box.lo, ZERO2))) {
+        const int refined_offset = roi_count * class_idx + roi_idx;
+        refined_scores[refined_offset] = ZERO;
+    } else {
+        const int offset = NUM_CLASSES * roi_idx + class_idx;
+
+        // width & height of box
+        INPUT_TYPE2 box_size = (box.hi - box.lo + COORDINATE_OFFSET);
+
+        // center location of box
+        const INPUT_TYPE2 center = box.lo + HALF_ONE * box_size;
+
+        const INPUT_TYPE4 delta = vload4(offset, deltas) / DELTA_WEIGHTS;
+
+        // new center location according to deltas (dx, dy)
+        const INPUT_TYPE2 new_center = delta.lo * box_size + center;
+        // new width & height according to deltas d(log w), d(log h)
+        const INPUT_TYPE2 new_size = exp(min(delta.hi, MAX_DELTA_LOG_SIZE)) * box_size;
+
+        // update upper-left corner and lower-right corners respectively
+        INPUT_TYPE4 new_box =
+            (INPUT_TYPE4)(new_center - HALF_ONE * new_size, new_center + HALF_ONE * new_size - COORDINATE_OFFSET);
+
+        // adjust new corner locations to be within the image region
+        const INPUT_TYPE2 img_size = vload2(0, im_info).s10;
+        new_box = clamp(new_box, ZERO4, img_size.xyxy);
+
+        // recompute new width & height
+        const INPUT_TYPE2 new_box_size = new_box.hi - new_box.lo + COORDINATE_OFFSET;
+
+        const int refined_offset = roi_count * class_idx + roi_idx;
+        vstore4(new_box, refined_offset, refined_boxes);
+        refined_box_areas[refined_offset] = new_box_size.x * new_box_size.y;
+        refined_scores[refined_offset] = scores[offset];
+    }
+}
+
+#endif /* EDDO_STAGE_0_REFINE_BOXES */
+
+#ifdef EDDO_STAGE_1_NMS
+
+// Intersection-over-union of two refined boxes, using the areas precomputed by
+// stage 0. Returns zero when the boxes do not overlap.
+inline INPUT_TYPE FUNC(jaccard_overlap)(const __global INPUT_TYPE* refined_boxes,
+                                        const __global INPUT_TYPE* refined_box_areas,
+                                        size_t idx1,
+                                        size_t idx2) {
+    INPUT_TYPE4 box1 = vload4(idx1, refined_boxes);
+    INPUT_TYPE4 box2 = vload4(idx2, refined_boxes);
+
+    const bool bbox_not_covered = any(isgreater((INPUT_TYPE4)(box1.lo, box2.lo), (INPUT_TYPE4)(box2.hi, box1.hi)));
+    if (bbox_not_covered) {
+        return ZERO;
+    }
+
+    INPUT_TYPE2 intersect_min = max(box1.lo, box2.lo);
+    INPUT_TYPE2 intersect_max = min(box1.hi, box2.hi);
+
+    INPUT_TYPE2 intersect_size = intersect_max - intersect_min + COORDINATE_OFFSET;
+
+    if (any(islessequal(intersect_size, ZERO2))) {
+        return ZERO;
+    }
+
+    INPUT_TYPE intersect_area = intersect_size.x * intersect_size.y;
+    INPUT_TYPE bbox1_area = refined_box_areas[idx1];
+    INPUT_TYPE bbox2_area = refined_box_areas[idx2];
+
+    return intersect_area / (bbox1_area + bbox2_area - intersect_area);
+}
+
+// Class-wise NMS. Collects the boxes whose score exceeds SCORE_THRESHOLD, sorts them
+// (best first) and greedily keeps those that do not overlap an already kept box by
+// more than NMS_THRESHOLD. Kept entries are compacted to the front of
+// score_class_index_map; *detection_count receives min(POST_NMS_COUNT, kept).
+inline void FUNC(nms_cf)(const __global INPUT_TYPE* refined_scores,
+                         const __global INPUT_TYPE* refined_boxes,
+                         const __global INPUT_TYPE* refined_box_areas,
+                         size_t class_idx,
+                         size_t roi_count,
+                         __global ScoreClassIndex* score_class_index_map,
+                         __global uint* detection_count) {
+    size_t count = 0;
+    for (size_t i = 0; i < roi_count; ++i) {
+        if (refined_scores[i] > SCORE_THRESHOLD) {
+            score_class_index_map[count] = (ScoreClassIndex){refined_scores[i], class_idx, i};
+            count++;
+        }
+    }
+
+    // count - 1 wraps around when nothing passed the threshold; a single candidate needs no sorting.
+    if (count > 1) {
+        FUNC_CALL(quickSortIterative)(score_class_index_map, 0, count - 1);
+    }
+
+    int detections = 0;
+    // Stop once POST_NMS_COUNT boxes are kept: any further ones would be dropped below.
+    for (size_t i = 0; i < count && detections < POST_NMS_COUNT; ++i) {
+        const size_t idx = score_class_index_map[i].box_idx;
+
+        bool keep = true;
+        for (int k = 0; k < detections; ++k) {
+            const size_t kept_idx = score_class_index_map[k].box_idx;
+            INPUT_TYPE overlap = FUNC_CALL(jaccard_overlap)(refined_boxes, refined_box_areas, idx, kept_idx);
+            if (overlap > NMS_THRESHOLD) {
+                keep = false;
+                break;
+            }
+        }
+        if (keep) {
+            score_class_index_map[detections] = score_class_index_map[i];
+            detections++;
+        }
+    }
+
+    *detection_count = min(POST_NMS_COUNT, detections);
+}
+
+// 1. Class-wise NMS. Runs nms_cf() for every class in turn and appends the kept
+// candidates of each class right after those of the previous ones;
+// *detection_count ends up holding the total number of kept candidates.
+KERNEL(eddo_ref_stage_1)
+(const __global INPUT_TYPE* refined_scores,
+ const __global INPUT_TYPE* refined_boxes,
+ const __global INPUT_TYPE* refined_box_areas,
+ __global ScoreClassIndex* score_class_index_map,
+ __global uint* detection_count) {
+    size_t total_detections_num = 0;
+    // FIXME: figure out how to parallelize this!!!
+    for (int class_idx = 0; class_idx < NUM_CLASSES; ++class_idx) {
+        FUNC_CALL(nms_cf)
+        (&refined_scores[ROI_COUNT * class_idx],
+         &refined_boxes[ROI_COUNT * 4 * class_idx],
+         &refined_box_areas[ROI_COUNT * class_idx],
+         class_idx,
+         ROI_COUNT,
+         &score_class_index_map[total_detections_num],
+         detection_count);
+        total_detections_num += *detection_count;
+    }
+
+    *detection_count = total_detections_num;
+}
+
+#endif /* EDDO_STAGE_1_NMS */
+
+#ifdef EDDO_STAGE_2_TOPK
+
+// 2. Top-K. When more candidates were kept than can be reported, sorts all of them so
+// that the best MAX_DETECTIONS_PER_IMAGE come first; otherwise leaves the order as is.
+KERNEL(eddo_ref_stage_2)
+(__global ScoreClassIndex* score_class_index_map, const __global uint* detection_count) {
+    if (*detection_count > MAX_DETECTIONS_PER_IMAGE) {
+        FUNC_CALL(quickSortIterative)(score_class_index_map, 0, *detection_count - 1);
+    }
+}
+
+#endif /* EDDO_STAGE_2_TOPK */
+
+#ifdef EDDO_STAGE_3_COPY_OUTPUT
+
+// 3. Copy output. Work item i writes the i-th candidate (box, class, score) to the
+// outputs, or zeros when fewer than i + 1 candidates were kept.
+KERNEL(eddo_ref_stage_3)
+(const __global ScoreClassIndex* score_class_index_map,
+ const __global uint* detection_count,
+ const __global INPUT_TYPE* refined_boxes,
+ __global OUTPUT_TYPE* output_boxes,
+ __global OUTPUT_INDICES_TYPE* output_classes,
+ __global OUTPUT_TYPE* output_scores) {
+    size_t i = get_global_id(0);
+
+    if (i < *detection_count) {
+        OUTPUT_TYPE score = score_class_index_map[i].score;
+        OUTPUT_INDICES_TYPE cls = score_class_index_map[i].class_idx;
+        OUTPUT_INDICES_TYPE idx = score_class_index_map[i].box_idx;
+        vstore4(vload4(ROI_COUNT * cls + idx, refined_boxes), i, output_boxes);
+        output_scores[i] = score;
+        output_classes[i] = cls;
+    } else {
+        vstore4(ZERO4, i, output_boxes);
+        output_scores[i] = ZERO;
+        output_classes[i] = 0;
+    }
+}
+
+#endif /* EDDO_STAGE_3_COPY_OUTPUT */
+
+#undef INPUT_TYPE
+#undef INPUT_TYPE2
+#undef INPUT_TYPE4
+#undef HALF_ONE
+#undef ZERO
+#undef ONE
+#undef ZERO2
+#undef ZERO4
+#undef COORDINATE_OFFSET
+#undef DELTA_WEIGHTS
+#undef MAX_DELTA_LOG_SIZE
diff --git a/src/plugins/intel_gpu/src/plugin/ops/experimental_detectron_detection_output.cpp b/src/plugins/intel_gpu/src/plugin/ops/experimental_detectron_detection_output.cpp
new file mode 100644
index 00000000000..c3636564580
--- /dev/null
+++ b/src/plugins/intel_gpu/src/plugin/ops/experimental_detectron_detection_output.cpp
@@ -0,0 +1,101 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph/op/experimental_detectron_detection_output.hpp"
+
+#include "intel_gpu/plugin/common_utils.hpp"
+#include "intel_gpu/plugin/program.hpp"
+#include "intel_gpu/primitives/experimental_detectron_detection_output.hpp"
+#include "intel_gpu/primitives/mutable_data.hpp"
+
+namespace ov {
+namespace runtime {
+namespace intel_gpu {
+
+static void CreateExperimentalDetectronDetectionOutputOp(
+    Program& p,
+    const std::shared_ptr& op) {
+    p.ValidateInputs(op, {4});
+
+    if (op->get_output_size() != 3) {
+        IE_THROW() << "ExperimentalDetectronDetectionOutput requires 3 outputs";
+    }
+
+    auto inputs = p.GetInputPrimitiveIDs(op);
+
+    const auto& attrs = op->get_attrs();
+
+    const auto op_friendly_name = op->get_friendly_name();
+
+    const auto layer_type_name = layer_type_name_ID(op);
+    const auto layer_name = layer_type_name + ".0";
+
+    const auto mutable_precision1 = op->get_output_element_type(1);
+    const auto output_shape1 = op->get_output_shape(1);
+    const cldnn::layout mutable_layout1{DataTypeFromPrecision(mutable_precision1),
+                                        DefaultFormatForDims(output_shape1.size()),
+                                        tensor_from_dims(output_shape1)};
+    cldnn::memory::ptr shared_memory1{p.GetEngine().allocate_memory(mutable_layout1)};
+
+    const auto mutable_id_w1 = layer_type_name + "_md_write.1";
+
const cldnn::mutable_data mutable_prim_w{mutable_id_w1, shared_memory1, op_friendly_name}; + p.primitiveIDs[mutable_id_w1] = mutable_id_w1; + p.AddPrimitive(mutable_prim_w); + inputs.push_back(mutable_id_w1); + + const auto mutable_precision2 = op->get_output_element_type(2); + const auto output_shape2 = op->get_output_shape(2); + const cldnn::layout mutable_layout2{DataTypeFromPrecision(mutable_precision2), + DefaultFormatForDims(output_shape2.size()), + tensor_from_dims(output_shape2)}; + cldnn::memory::ptr shared_memory2{p.GetEngine().allocate_memory(mutable_layout2)}; + + const auto mutable_id_w2 = layer_type_name + "_md_write.2"; + const cldnn::mutable_data mutable_prim_w2{mutable_id_w2, shared_memory2, op_friendly_name}; + p.primitiveIDs[mutable_id_w2] = mutable_id_w2; + p.AddPrimitive(mutable_prim_w2); + inputs.push_back(mutable_id_w2); + + const auto expectedPrimInputCount = 4 + 2; // 4 operation inputs plus 2 input-outputs + if (inputs.size() != expectedPrimInputCount) { + IE_THROW() << "experimental_detectron_detection_output primitive requires 6 inputs"; + } + + const cldnn::experimental_detectron_detection_output prim{layer_name, + inputs[0], + inputs[1], + inputs[2], + inputs[3], + inputs[4], // output classes + inputs[5], // output scores + attrs.score_threshold, + attrs.nms_threshold, + static_cast(attrs.num_classes), + static_cast(attrs.post_nms_count), + static_cast(attrs.max_detections_per_image), + attrs.class_agnostic_box_regression, + attrs.max_delta_log_wh, + attrs.deltas_weights, + op_friendly_name}; + + p.AddPrimitive(prim); + + const auto mutable_id_r1 = layer_type_name + ".1"; + const cldnn::mutable_data mutable_prim_r1{mutable_id_r1, {layer_name}, shared_memory1, op_friendly_name}; + p.primitiveIDs[mutable_id_r1] = mutable_id_r1; + p.AddPrimitive(mutable_prim_r1); + + const auto mutable_id_r2 = layer_type_name + ".2"; + const cldnn::mutable_data mutable_prim_r2{mutable_id_r2, {layer_name}, shared_memory2, op_friendly_name}; + 
p.primitiveIDs[mutable_id_r2] = mutable_id_r2; + p.AddPrimitive(mutable_prim_r2); + + p.AddPrimitiveToProfiler(prim, op); +} + +REGISTER_FACTORY_IMPL(v6, ExperimentalDetectronDetectionOutput); + +} // namespace intel_gpu +} // namespace runtime +} // namespace ov diff --git a/src/plugins/intel_gpu/tests/test_cases/experimental_detectron_detection_output_gpu_test.cpp b/src/plugins/intel_gpu/tests/test_cases/experimental_detectron_detection_output_gpu_test.cpp new file mode 100644 index 00000000000..d0787f62bc8 --- /dev/null +++ b/src/plugins/intel_gpu/tests/test_cases/experimental_detectron_detection_output_gpu_test.cpp @@ -0,0 +1,322 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include "test_utils.h" + +using namespace cldnn; +using namespace ::tests; + +namespace { + +template +std::vector getValues(const std::vector& values) { + std::vector result(values.begin(), values.end()); + return result; +} + +template +float getError(); + +template <> +float getError() { + return 0.001; +} + +template <> +float getError() { + return 0.2; +} + +}; // namespace + +template +struct ExperimentalDetectronDetectionOutputParams { + float score_threshold; + float nms_threshold; + float max_delta_log_wh; + int num_classes; + int post_nms_count; + int max_detections_per_image; + bool class_agnostic_box_regression; + std::vector deltas_weights; + + size_t roi_count; + + std::vector boxes; + std::vector deltas; + std::vector scores; + std::vector im_info; + + std::vector expected_boxes; + std::vector expected_classes; + std::vector expected_scores; +}; + +template +struct experimental_detectron_detection_output_test + : public ::testing::TestWithParam> { +public: + void test() { + const ExperimentalDetectronDetectionOutputParams param = + testing::TestWithParam>::GetParam(); + auto data_type = type_to_data_type::value; + + auto& engine = get_test_engine(); + + const primitive_id input_boxes_id = 
"InputBoxes"; + const auto input_boxes = + engine.allocate_memory({data_type, format::bfyx, tensor{batch(param.roi_count), feature(4)}}); + set_values(input_boxes, param.boxes); + + const primitive_id input_deltas_id = "InputDeltas"; + auto input_deltas = engine.allocate_memory( + {data_type, format::bfyx, tensor{batch(param.roi_count), feature(param.num_classes * 4)}}); + set_values(input_deltas, param.deltas); + + const primitive_id input_scores_id = "InputScores"; + auto input_scores = engine.allocate_memory( + {data_type, format::bfyx, tensor{batch(param.roi_count), feature(param.num_classes)}}); + set_values(input_scores, param.scores); + + const primitive_id input_im_info_id = "InputImInfo"; + const auto input_im_info = engine.allocate_memory({data_type, format::bfyx, tensor{batch(1), feature(3)}}); + set_values(input_im_info, param.im_info); + + const primitive_id output_scores_id = "OutputScores"; + auto output_scores = + engine.allocate_memory({data_type, format::bfyx, tensor{batch(param.max_detections_per_image)}}); + + const primitive_id output_classes_id = "OutputClasses"; + auto output_classes = + engine.allocate_memory({data_types::i32, format::bfyx, tensor{batch(param.max_detections_per_image)}}); + + topology topology; + + topology.add(input_layout(input_boxes_id, input_boxes->get_layout())); + topology.add(input_layout(input_deltas_id, input_deltas->get_layout())); + topology.add(input_layout(input_scores_id, input_scores->get_layout())); + topology.add(input_layout(input_im_info_id, input_im_info->get_layout())); + topology.add(mutable_data(output_classes_id, output_classes)); + topology.add(mutable_data(output_scores_id, output_scores)); + + const primitive_id eddo_id = "experimental_detectron_detection_output"; + const auto eddo_primitive = experimental_detectron_detection_output{ + eddo_id, + input_boxes_id, + input_deltas_id, + input_scores_id, + input_im_info_id, + output_classes_id, + output_scores_id, + param.score_threshold, + 
param.nms_threshold, + param.num_classes, + param.post_nms_count, + param.max_detections_per_image, + param.class_agnostic_box_regression, + param.max_delta_log_wh, + param.deltas_weights, + }; + + topology.add(eddo_primitive); + + network network(engine, topology); + + network.set_input_data(input_boxes_id, input_boxes); + network.set_input_data(input_deltas_id, input_deltas); + network.set_input_data(input_scores_id, input_scores); + network.set_input_data(input_im_info_id, input_im_info); + + const auto outputs = network.execute(); + + const auto output_boxes = outputs.at(eddo_id).get_memory(); + + const cldnn::mem_lock output_boxes_ptr(output_boxes, get_test_stream()); + ASSERT_EQ(output_boxes_ptr.size(), param.max_detections_per_image * 4); + + const cldnn::mem_lock output_classes_ptr(output_classes, get_test_stream()); + ASSERT_EQ(output_classes_ptr.size(), param.max_detections_per_image); + + const cldnn::mem_lock output_scores_ptr(output_scores, get_test_stream()); + ASSERT_EQ(output_scores_ptr.size(), param.max_detections_per_image); + + const auto& expected_boxes = param.expected_boxes; + const auto& expected_classes = param.expected_classes; + const auto& expected_scores = param.expected_scores; + for (size_t i = 0; i < param.max_detections_per_image; ++i) { + EXPECT_NEAR(expected_scores[i], output_scores_ptr[i], 0.001) << "i=" << i; + for (size_t coord = 0; coord < 4; ++coord) { + const auto roi_idx = i * 4 + coord; + EXPECT_NEAR(expected_boxes[roi_idx], output_boxes_ptr[roi_idx], getError()) + << "i=" << i << ", coord=" << coord; + } + + EXPECT_EQ(expected_classes[i], output_classes_ptr[i]) << "i=" << i; + } + } +}; + +using experimental_detectron_detection_output_test_f32 = experimental_detectron_detection_output_test; +using experimental_detectron_detection_output_test_f16 = experimental_detectron_detection_output_test; + +TEST_P(experimental_detectron_detection_output_test_f32, basic) { + ASSERT_NO_FATAL_FAILURE(test()); +} + 
+TEST_P(experimental_detectron_detection_output_test_f16, basic) { + ASSERT_NO_FATAL_FAILURE(test()); +} + +template +std::vector> getExperimentalDetectronDetectionOutputParams() { + std::vector> params = { + { + 0.01000000074505806f, // score_threshold + 0.2f, // nms_threshold + 2.0f, // max_delta_log_wh + 2, // num_classes + 500, // post_nms_count + 5, // max_detections_per_image + true, // class_agnostic_box_regression + {10.0f, 10.0f, 5.0f, 5.0f}, // deltas_weights + 16, // roi count + + // boxes + getValues({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, + 1.0f, 8.0f, 5.0f, 1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}), + + getValues({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}), + + getValues({0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}), + + getValues({16.0f, 12.0f, 1.0f}), + getValues({4.8929863f, 
0.892986298f, 12.0f, 12.1070137f, 0.0f, 0.892986298f, 10.1070137f, + 12.1070137f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}), + std::vector{0, 1, 0, 0, 0}, + getValues({0.8f, 0.9f, 0.0f, 0.0f, 0.0f}), + }, + { + 0.0500000007, // score_threshold + 0.5, // nms_threshold + 4.13516665, // max_delta_log_wh + 5, // num_classes + 10, // post_nms_count + 10, // max_detections_per_image + false, // class_agnostic_box_regression + {10.0f, 10.0f, 5.0f, 5.0f}, // deltas_weights + 10, // roi count + + // boxes + getValues({ + 4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188, + 6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281, + 0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125, + 6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719, + }), + + // deltas + getValues({ + 4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188, + 6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281, + 0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125, + 6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719, + 3.74609, 4.54688, 5.83203, 5.91406, 2.85547, 7.46875, 4.31641, 2.71094, 9.71484, 1.14062, + 6.55078, 0.257812, 4.32422, 9.5625, 8.53516, 0.554688, 8.68359, 2.73438, 6.26953, 5.60156, + 2.79297, 8.65625, 5.75391, 5.39844, 2.65234, 7.32812, 8.98828, 7.94531, 6.26172, 4.75, + 7.97266, 1.24219, 5.62109, 8.92188, 2.70703, 1.28906, 4.73047, 7.84375, 5.19141, 6.08594, + 7.58984, 9.51562, 7.42578, 5.63281, 6.19922, 7.9375, 5.41016, 9.92969, 2.55859, 1.10938, + 1.14453, 8.97656, 4.66797, 9.03125, 4.62891, 0.773438, 4.52734, 1.70312, 9.86328, 1.32031, + 0.136719, 9.125, 2.84766, 4.61719, 9.49609, 5.29688, 5.58203, 0.664062, 2.60547, 6.21875, + 8.06641, 5.46094, 1.46484, 7.89062, 
0.300781, 5.00781, 0.0742188, 0.3125, 6.28516, 3.30469, + 4.43359, 1.48438, 2.01953, 8.35156, 8.54297, 7.40625, 9.50391, 2.14844, 2.40234, 2.07812, + 2.73828, 2.69531, 4.01172, 9.5, 7.72266, 9.99219, 1.37109, 3.67188, 2.45703, 2.03906, + 0.480469, 4.59375, 2.94141, 4.83594, 1.33984, 0.265625, 1.17578, 4.38281, 5.94922, 8.6875, + 5.16016, 0.679688, 4.30859, 5.85938, 4.89453, 7.72656, 4.41797, 5.78125, 4.37891, 1.52344, + 8.27734, 4.45312, 3.61328, 4.07031, 7.88672, 9.875, 4.59766, 1.36719, 7.24609, 8.04688, + 5.33203, 5.41406, 4.35547, 0.96875, 1.81641, 8.21094, 3.21484, 4.64062, 4.05078, 9.75781, + 7.82422, 3.0625, 4.03516, 0.0546875, 8.18359, 8.23438, 1.76953, 1.10156, 2.29297, 8.15625, + 9.25391, 0.898438, 6.15234, 8.82812, 6.48828, 7.44531, 1.76172, 2.25, 9.47266, 0.742188, + }), + + // scores + getValues({ + 4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188, + 6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281, + 0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125, + 6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719, + 3.74609, 4.54688, 5.83203, 5.91406, 2.85547, 7.46875, 4.31641, 2.71094, 9.71484, 1.14062, + }), + + // im_info + getValues({ + 4.90234, + 6.57812, + 5.23828, + }), + + // out_boxes + getValues({ + 0, 2.97829, 6.57812, 4.90234, 0, 4.90234, 6.57812, 4.90234, 4.37184, 4.90234, + 6.03075, 4.90234, 5.95093, 3.66966, 6.57812, 4.90234, 0, 4.90234, 6.57812, 4.90234, + 1.31075, 4.90234, 6.57812, 4.90234, 3.24829, 4.90234, 6.57812, 4.90234, 0, 0, + 6.57812, 4.90234, 4.20346, 0, 6.57812, 4.90234, 0, 0, 6.57812, 4.90234, + }), + + // out_classes + std::vector({ + 4, + 3, + 3, + 4, + 2, + 0, + 1, + 0, + 2, + 3, + }), + + // out_scores + getValues({ + 9.86719, + 9.71484, + 9.19531, + 8.51172, + 8.375, + 7.46875, + 6.57812, + 6.57031, + 5.23828, + 5.09766, + }), + }, + }; + return params; +} + 
+INSTANTIATE_TEST_SUITE_P(experimental_detectron_detection_output_gpu_test, + experimental_detectron_detection_output_test_f32, + ::testing::ValuesIn(getExperimentalDetectronDetectionOutputParams())); + +INSTANTIATE_TEST_SUITE_P(experimental_detectron_detection_output_gpu_test, + experimental_detectron_detection_output_test_f16, + ::testing::ValuesIn(getExperimentalDetectronDetectionOutputParams())); diff --git a/src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/experimental_detectron_detection_output.cpp b/src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/experimental_detectron_detection_output.cpp new file mode 100644 index 00000000000..4a1a4b8b67a --- /dev/null +++ b/src/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/experimental_detectron_detection_output.cpp @@ -0,0 +1,66 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "single_layer_tests/experimental_detectron_detection_output.hpp" + +#include + +#include "common_test_utils/ov_tensor_utils.hpp" + +using namespace ov::test; +using namespace ov::test::subgraph; + +namespace { + +const std::vector netPrecisions = { + ov::element::Type_t::f16, + ov::element::Type_t::f32, +}; + +const std::vector score_threshold = {0.01f, 0.8f}; + +const std::vector nms_threshold = {0.2f, 0.5f}; + +// specifies maximal delta of logarithms for width and height +const std::vector max_delta_log_wh = {2.0f, 5.0f}; + +// specifies number of detected classes +const std::vector num_classes = {2}; + +// specifies maximal number of detections per class +const std::vector post_nms_count = {5, 25}; + +// specifies maximual number of detections per image +const std::vector max_detections_per_image = {5, 25}; + +// a flag specifies whether to delete background classes or not +// `true` means background classes should be deleted, +// `false` means background classes shouldn't be deleted. 
+const std::vector class_agnostic_box_regression = {true, false}; + +// specifies deltas of weights +const std::vector> deltas_weights = {{10.0f, 10.0f, 5.0f, 5.0f}}; + +const std::vector> inputShapes = { + // inputRois / inputDeltas / inputScores / inputImInfos + static_shapes_to_test_representation({{16, 4}, {16, 8}, {16, 2}, {1, 3}}), +}; + +INSTANTIATE_TEST_SUITE_P(smoke_ExperimentalDetectronDetectionOutput, + ExperimentalDetectronDetectionOutputLayerTest, + ::testing::Combine(::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(score_threshold), + ::testing::ValuesIn(nms_threshold), + ::testing::ValuesIn(max_delta_log_wh), + ::testing::ValuesIn(num_classes), + ::testing::ValuesIn(post_nms_count), + ::testing::ValuesIn(max_detections_per_image), + ::testing::ValuesIn(class_agnostic_box_regression), + ::testing::ValuesIn(deltas_weights), + ::testing::ValuesIn(netPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU)), + ExperimentalDetectronDetectionOutputLayerTest::getTestCaseName); + + +} // namespace diff --git a/src/tests/functional/shared_test_classes/src/single_layer/experimental_detectron_detection_output.cpp b/src/tests/functional/shared_test_classes/src/single_layer/experimental_detectron_detection_output.cpp index f4314827338..de81d356673 100644 --- a/src/tests/functional/shared_test_classes/src/single_layer/experimental_detectron_detection_output.cpp +++ b/src/tests/functional/shared_test_classes/src/single_layer/experimental_detectron_detection_output.cpp @@ -79,6 +79,9 @@ void ExperimentalDetectronDetectionOutputLayerTest::SetUp() { netPrecision, targetName) = this->GetParam(); + if (netPrecision == element::f16) + abs_threshold = 0.01; + inType = outType = netPrecision; targetDevice = targetName; @@ -93,42 +96,66 @@ void ExperimentalDetectronDetectionOutputLayerTest::SetUp() { params[2], // input_scores params[3], // input_im_info attributes); - function = std::make_shared( - ov::OutputVector{experimentalDetectron->output(0), 
experimentalDetectron->output(1)}, - "ExperimentalDetectronDetectionOutput"); + function = std::make_shared(ov::OutputVector{experimentalDetectron->output(0), + experimentalDetectron->output(1), + experimentalDetectron->output(2)}, + "ExperimentalDetectronDetectionOutput"); } +namespace { + +template +std::vector getValues(const std::vector& values) { + std::vector result(values.begin(), values.end()); + return result; +} + +template +std::vector generateInputTensors() { + const auto netPrecision = ov::element::from(); + std::vector inputTensors = { + // 16 x 4 = 64 + ov::test::utils::create_tensor( + netPrecision, + Shape{16, 4}, + getValues({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, + 1.0f, 8.0f, 5.0f, 1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f})), + // 16 x 8 + ov::test::utils::create_tensor( + netPrecision, + Shape{16, 8}, + getValues({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f})), + // 16 x 2 = 32 + ov::test::utils::create_tensor( + netPrecision, + 
Shape{16, 2}, + getValues({0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f})), + // 1 x 3 = 3 + ov::test::utils::create_tensor(netPrecision, Shape{1, 3}, getValues({16.0f, 12.0f, 1.0f}))}; + + return inputTensors; +} +} // namespace + void ExperimentalDetectronDetectionOutputLayerTest::generate_inputs( const std::vector& targetInputStaticShapes) { - static const std::vector inputTensors = { - // 16 x 4 = 64 - ov::test::utils::create_tensor( - ov::element::f32, - Shape{16, 4}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 8.0f, 5.0f, - 1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}), - // 16 x 8 - ov::test::utils::create_tensor( - ov::element::f32, - Shape{16, 8}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}), - // 16 x 2 = 32 - ov::test::utils::create_tensor( - ov::element::f32, - Shape{16, 2}, - {0.5f, 0.5f, 
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}), - // 1 x 3 = 3 - ov::test::utils::create_tensor(ov::element::f32, Shape{1, 3}, {16.0f, 12.0f, 1.0f})}; + const auto netPrecision = std::get<9>(GetParam()); + + const std::vector inputTensors = + (netPrecision == element::f16) ? generateInputTensors() : generateInputTensors(); inputs.clear(); const auto& funcInputs = function->inputs();