[GPU] Implement ExperimentalDetectronDetectionOutput operation (#11772)
* ExperimentalDetectronDetectionOutput: refine sorting criteria for NMS stage. This ensures the operation produces stable, predictable results across the possible sorting algorithm implementations. This property is useful for testing the operation. * [GPU] Implement ExperimentalDetectronDetectionOutput operation * [GPU] ExperimentalDetectronDetectionOutput: use vector types and operations in kernel * Reformat changed files to make clang format checker happy * [GPU] ExperimentalDetectronDetectionOutput: add another test case to the unit test * [GPU] ExperimentalDetectronDetectionOutput: Add f16 test * ExperimentalDetectronDetectionOutput: single-layer test: use all three outputs * [GPU] ExperimentalDetectronDetectionOutput: increase single layer test coverage. More attribute permutations were added.
This commit is contained in:
parent
95a297ed68
commit
8a21e4e062
@ -203,7 +203,7 @@ void nms_cf(const float* conf_data,
|
||||
|
||||
template <typename T>
|
||||
bool SortScorePairDescend(const std::pair<float, T>& pair1, const std::pair<float, T>& pair2) {
|
||||
return pair1.first > pair2.first;
|
||||
return (pair1.first > pair2.first) || ((pair1.first == pair2.first) && (pair1.second.second < pair2.second.second));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -118,9 +118,9 @@ static void refine_boxes(const float* boxes,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool SortScorePairDescend(const std::pair<float, T>& pair1, const std::pair<float, T>& pair2) {
|
||||
return pair1.first > pair2.first;
|
||||
static bool SortScorePairDescend(const std::pair<float, std::pair<int, int>>& pair1,
|
||||
const std::pair<float, std::pair<int, int>>& pair2) {
|
||||
return (pair1.first > pair2.first) || ((pair1.first == pair2.first) && (pair1.second.second < pair2.second.second));
|
||||
}
|
||||
|
||||
struct ConfidenceComparator {
|
||||
@ -362,7 +362,7 @@ void ExperimentalDetectronDetectionOutput::execute(dnnl::stream strm) {
|
||||
std::partial_sort(conf_index_class_map.begin(),
|
||||
conf_index_class_map.begin() + max_detections_per_image_,
|
||||
conf_index_class_map.end(),
|
||||
SortScorePairDescend<std::pair<int, int>>);
|
||||
SortScorePairDescend);
|
||||
conf_index_class_map.resize(max_detections_per_image_);
|
||||
total_detections_num = max_detections_per_image_;
|
||||
}
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#ifndef REGISTER_FACTORY
|
||||
#error "REGISTER_FACTORY is not defined"
|
||||
# error "REGISTER_FACTORY is not defined"
|
||||
#endif
|
||||
|
||||
// ------------------------------ Supported v0 ops ------------------------------ //
|
||||
@ -190,7 +190,7 @@ REGISTER_FACTORY(v4, Swish);
|
||||
REGISTER_FACTORY(v5, HSigmoid);
|
||||
REGISTER_FACTORY(v5, LogSoftmax);
|
||||
REGISTER_FACTORY(v5, LSTMSequence);
|
||||
//REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion
|
||||
// REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion
|
||||
REGISTER_FACTORY(v5, Round);
|
||||
REGISTER_FACTORY(v5, GatherND);
|
||||
REGISTER_FACTORY(v5, Loop);
|
||||
@ -208,6 +208,7 @@ REGISTER_FACTORY(v6, ExperimentalDetectronPriorGridGenerator);
|
||||
REGISTER_FACTORY(v6, ExperimentalDetectronROIFeatureExtractor);
|
||||
REGISTER_FACTORY(v6, ExperimentalDetectronTopKROIs)
|
||||
REGISTER_FACTORY(v6, ExperimentalDetectronGenerateProposalsSingleImage);
|
||||
REGISTER_FACTORY(v6, ExperimentalDetectronDetectionOutput);
|
||||
|
||||
// ------------------------------ Supported v7 ops ------------------------------ //
|
||||
REGISTER_FACTORY(v7, DFT);
|
||||
|
@ -0,0 +1,98 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#pragma once
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "primitive.hpp"
|
||||
|
||||
namespace cldnn {
|
||||
/// @addtogroup cpp_api C++ API
|
||||
/// @{
|
||||
/// @addtogroup cpp_topology Network Topology
|
||||
/// @{
|
||||
/// @addtogroup cpp_primitives Primitives
|
||||
/// @{
|
||||
|
||||
/// @brief experimental detectron detection output
|
||||
struct experimental_detectron_detection_output : public primitive_base<experimental_detectron_detection_output> {
|
||||
CLDNN_DECLARE_PRIMITIVE(experimental_detectron_detection_output)
|
||||
|
||||
/// @brief Constructs experimental_detectron_detection_output primitive
|
||||
/// @param id This primitive id
|
||||
/// @param input_rois input rois
|
||||
/// @param input_deltas input deltas
|
||||
/// @param input_scores input scores
|
||||
/// @param input_im_info image info
|
||||
/// @param output_classes ROI scores
|
||||
/// @param output_scores minimum box width and height
|
||||
/// @param score_threshold a threshold to consider only detections whose score are larger than the threshold
|
||||
/// @param nms_threshold a threshold to be used in the NMS stage
|
||||
/// @param num_classes the number of detected classes
|
||||
/// @param post_nms_count the maximum number of detections per class
|
||||
/// @param max_detections_per_image the maximum number of detections per image
|
||||
/// @param class_agnostic_box_regression specifies whether to delete background classes or not
|
||||
/// @param max_delta_log_wh the maximum delta of logarithms for width and height
|
||||
/// @param deltas_weights the weights for bounding boxes sizes deltas
|
||||
experimental_detectron_detection_output(const primitive_id& id,
|
||||
const primitive_id& input_rois,
|
||||
const primitive_id& input_deltas,
|
||||
const primitive_id& input_scores,
|
||||
const primitive_id& input_im_info,
|
||||
const primitive_id& output_classes,
|
||||
const primitive_id& output_scores,
|
||||
float score_threshold,
|
||||
float nms_threshold,
|
||||
int num_classes,
|
||||
int post_nms_count,
|
||||
int max_detections_per_image,
|
||||
bool class_agnostic_box_regression,
|
||||
float max_delta_log_wh,
|
||||
std::vector<float> deltas_weights,
|
||||
const primitive_id& ext_prim_id = "",
|
||||
const padding& output_padding = {})
|
||||
: primitive_base{id,
|
||||
{input_rois, input_deltas, input_scores, input_im_info, output_classes, output_scores},
|
||||
ext_prim_id,
|
||||
output_padding},
|
||||
output_classes{output_classes},
|
||||
output_scores{output_scores},
|
||||
score_threshold{score_threshold},
|
||||
nms_threshold{nms_threshold},
|
||||
num_classes{num_classes},
|
||||
post_nms_count{post_nms_count},
|
||||
class_agnostic_box_regression{class_agnostic_box_regression},
|
||||
max_detections_per_image{max_detections_per_image},
|
||||
max_delta_log_wh{max_delta_log_wh},
|
||||
deltas_weights{std::move(deltas_weights)} {}
|
||||
|
||||
primitive_id output_classes;
|
||||
primitive_id output_scores;
|
||||
float score_threshold;
|
||||
float nms_threshold;
|
||||
int num_classes;
|
||||
int post_nms_count;
|
||||
int max_detections_per_image;
|
||||
bool class_agnostic_box_regression;
|
||||
float max_delta_log_wh;
|
||||
std::vector<float> deltas_weights;
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
if (!output_classes.empty())
|
||||
ret.emplace_back(output_classes);
|
||||
|
||||
if (!output_scores.empty())
|
||||
ret.emplace_back(output_scores);
|
||||
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
/// @}
|
||||
/// @}
|
||||
/// @}
|
||||
} // namespace cldnn
|
@ -0,0 +1,49 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "experimental_detectron_detection_output_inst.hpp"
|
||||
#include "intel_gpu/runtime/error_handler.hpp"
|
||||
#include "json_object.h"
|
||||
#include "primitive_type_base.h"
|
||||
|
||||
namespace cldnn {
|
||||
primitive_type_id experimental_detectron_detection_output::type_id() {
    // One function-local descriptor object; every call returns the same address.
    static primitive_type_base<experimental_detectron_detection_output> type_descriptor;
    return &type_descriptor;
}
|
||||
|
||||
layout experimental_detectron_detection_output_inst::calc_output_layout(
    const experimental_detectron_detection_output_node& node) {
    // The output takes its data type from the first input's layout.
    const layout input_layout = node.input().get_output_layout();
    const auto primitive = node.get_primitive();

    // bfyx tensor with dimensions {max_detections_per_image, 4, 1, 1}.
    const auto detection_rows = static_cast<int>(primitive->max_detections_per_image);
    return layout(input_layout.data_type, format::bfyx, {detection_rows, 4, 1, 1});
}
|
||||
|
||||
std::string experimental_detectron_detection_output_inst::to_string(
|
||||
const experimental_detectron_detection_output_node& node) {
|
||||
auto desc = node.get_primitive();
|
||||
|
||||
std::stringstream primitive_description;
|
||||
|
||||
json_composite ed_info;
|
||||
ed_info.add("score_threshold", desc->score_threshold);
|
||||
ed_info.add("nms_threshold", desc->nms_threshold);
|
||||
ed_info.add("score_threshold", desc->score_threshold);
|
||||
ed_info.add("max_delta_log_wh", desc->max_delta_log_wh);
|
||||
ed_info.add("num_classes", desc->num_classes);
|
||||
ed_info.add("post_nms_count", desc->post_nms_count);
|
||||
ed_info.add("max_detections_per_image", desc->max_detections_per_image);
|
||||
ed_info.add("class_agnostic_box_regression", desc->class_agnostic_box_regression);
|
||||
ed_info.add("deltas_weights", desc->deltas_weights);
|
||||
|
||||
auto node_info = node.desc_to_json();
|
||||
node_info->add("experimental_detectron_detection_output_info", ed_info);
|
||||
node_info->dump(primitive_description);
|
||||
|
||||
return primitive_description.str();
|
||||
}
|
||||
} // namespace cldnn
|
@ -0,0 +1,80 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "eddo/experimental_detectron_detection_output_kernel_ref.h"
|
||||
#include "eddo/experimental_detectron_detection_output_kernel_selector.h"
|
||||
#include "experimental_detectron_detection_output_inst.hpp"
|
||||
#include "impls/implementation_map.hpp"
|
||||
#include "kernel_selector_helper.h"
|
||||
#include "primitive_base.hpp"
|
||||
|
||||
namespace cldnn {
|
||||
namespace ocl {
|
||||
struct experimental_detectron_detection_output_impl
|
||||
: public typed_primitive_impl_ocl<experimental_detectron_detection_output> {
|
||||
using parent = typed_primitive_impl_ocl<experimental_detectron_detection_output>;
|
||||
using parent::parent;
|
||||
|
||||
std::unique_ptr<primitive_impl> clone() const override {
|
||||
return make_unique<experimental_detectron_detection_output_impl>(*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
kernel_arguments_data get_arguments(typed_primitive_inst<experimental_detectron_detection_output>& instance,
|
||||
int32_t unused) const override {
|
||||
kernel_arguments_data args = parent::get_arguments(instance, unused);
|
||||
args.inputs.push_back(instance.output_classes_memory());
|
||||
args.inputs.push_back(instance.output_scores_memory());
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
public:
|
||||
static primitive_impl* create(const experimental_detectron_detection_output_node& arg) {
|
||||
auto params = get_default_params<kernel_selector::experimental_detectron_detection_output_params>(arg);
|
||||
auto optional_params =
|
||||
get_default_optional_params<kernel_selector::experimental_detectron_detection_output_optional_params>(
|
||||
arg.get_program());
|
||||
|
||||
const auto& primitive = arg.get_primitive();
|
||||
|
||||
params.score_threshold = primitive->score_threshold;
|
||||
params.nms_threshold = primitive->nms_threshold;
|
||||
params.max_delta_log_wh = primitive->max_delta_log_wh;
|
||||
params.num_classes = primitive->num_classes;
|
||||
params.post_nms_count = primitive->post_nms_count;
|
||||
params.max_detections_per_image = primitive->max_detections_per_image;
|
||||
params.class_agnostic_box_regression = primitive->class_agnostic_box_regression;
|
||||
params.deltas_weights = primitive->deltas_weights;
|
||||
|
||||
params.inputs.push_back(convert_data_tensor(arg.deltas().get_output_layout()));
|
||||
params.inputs.push_back(convert_data_tensor(arg.scores().get_output_layout()));
|
||||
params.inputs.push_back(convert_data_tensor(arg.image_size_info().get_output_layout()));
|
||||
|
||||
params.inputs.push_back(convert_data_tensor(arg.output_classes_node().get_output_layout()));
|
||||
params.inputs.push_back(convert_data_tensor(arg.output_scores_node().get_output_layout()));
|
||||
|
||||
const auto& kernel_selector =
|
||||
kernel_selector::experimental_detectron_detection_output_kernel_selector::Instance();
|
||||
const auto best_kernels = kernel_selector.GetBestKernels(params, optional_params);
|
||||
|
||||
CLDNN_ERROR_BOOL(arg.id(),
|
||||
"best_kernels.empty()",
|
||||
best_kernels.empty(),
|
||||
"Cannot find a proper kernel with this arguments");
|
||||
|
||||
return new experimental_detectron_detection_output_impl(arg, best_kernels[0]);
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Registers the OCL implementation factory for f16 and f32 data in bfyx format.
attach_experimental_detectron_detection_output_impl::attach_experimental_detectron_detection_output_impl() {
    implementation_map<experimental_detectron_detection_output>::add(
        impl_types::ocl,
        experimental_detectron_detection_output_impl::create,
        {
            std::make_tuple(data_types::f16, format::bfyx),
            std::make_tuple(data_types::f32, format::bfyx),
        });
}
|
||||
} // namespace detail
|
||||
} // namespace ocl
|
||||
} // namespace cldnn
|
@ -8,8 +8,7 @@
|
||||
namespace cldnn {
|
||||
namespace ocl {
|
||||
|
||||
#define REGISTER_OCL(prim) \
|
||||
static detail::attach_##prim##_impl attach_##prim
|
||||
#define REGISTER_OCL(prim) static detail::attach_##prim##_impl attach_##prim
|
||||
|
||||
void register_implementations() {
|
||||
REGISTER_OCL(activation);
|
||||
@ -31,6 +30,7 @@ void register_implementations() {
|
||||
REGISTER_OCL(detection_output);
|
||||
REGISTER_OCL(dft);
|
||||
REGISTER_OCL(batch_to_space);
|
||||
REGISTER_OCL(experimental_detectron_detection_output);
|
||||
REGISTER_OCL(experimental_detectron_generate_proposals_single_image);
|
||||
REGISTER_OCL(experimental_detectron_prior_grid_generator);
|
||||
REGISTER_OCL(experimental_detectron_roi_feature_extractor);
|
||||
|
@ -5,6 +5,7 @@
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#pragma once
|
||||
|
||||
#include "generic_layer.hpp"
|
||||
#include "intel_gpu/primitives/activation.hpp"
|
||||
#include "intel_gpu/primitives/arg_max_min.hpp"
|
||||
#include "intel_gpu/primitives/average_unpooling.hpp"
|
||||
@ -14,8 +15,10 @@
|
||||
#include "intel_gpu/primitives/broadcast.hpp"
|
||||
#include "intel_gpu/primitives/bucketize.hpp"
|
||||
#include "intel_gpu/primitives/concatenation.hpp"
|
||||
#include "intel_gpu/primitives/convert_color.hpp"
|
||||
#include "intel_gpu/primitives/convolution.hpp"
|
||||
#include "intel_gpu/primitives/crop.hpp"
|
||||
#include "intel_gpu/primitives/ctc_greedy_decoder.hpp"
|
||||
#include "intel_gpu/primitives/custom_gpu_primitive.hpp"
|
||||
#include "intel_gpu/primitives/deconvolution.hpp"
|
||||
#include "intel_gpu/primitives/depth_to_space.hpp"
|
||||
@ -26,12 +29,16 @@
|
||||
#include "intel_gpu/primitives/experimental_detectron_topk_rois.hpp"
|
||||
#include "intel_gpu/primitives/fully_connected.hpp"
|
||||
#include "intel_gpu/primitives/gather.hpp"
|
||||
#include "intel_gpu/primitives/gather_nd.hpp"
|
||||
#include "intel_gpu/primitives/gather_elements.hpp"
|
||||
#include "intel_gpu/primitives/gather_nd.hpp"
|
||||
#include "intel_gpu/primitives/gather_tree.hpp"
|
||||
#include "intel_gpu/primitives/gemm.hpp"
|
||||
#include "intel_gpu/primitives/grn.hpp"
|
||||
#include "intel_gpu/primitives/lrn.hpp"
|
||||
#include "intel_gpu/primitives/lstm.hpp"
|
||||
#include "intel_gpu/primitives/lstm_dynamic.hpp"
|
||||
#include "intel_gpu/primitives/lstm_dynamic_input.hpp"
|
||||
#include "intel_gpu/primitives/lstm_dynamic_timeloop.hpp"
|
||||
#include "intel_gpu/primitives/max_unpooling.hpp"
|
||||
#include "intel_gpu/primitives/mutable_data.hpp"
|
||||
#include "intel_gpu/primitives/mvn.hpp"
|
||||
@ -48,15 +55,16 @@
|
||||
#include "intel_gpu/primitives/region_yolo.hpp"
|
||||
#include "intel_gpu/primitives/reorder.hpp"
|
||||
#include "intel_gpu/primitives/reorg_yolo.hpp"
|
||||
#include "intel_gpu/primitives/resample.hpp"
|
||||
#include "intel_gpu/primitives/reshape.hpp"
|
||||
#include "intel_gpu/primitives/reverse_sequence.hpp"
|
||||
#include "intel_gpu/primitives/roi_align.hpp"
|
||||
#include "intel_gpu/primitives/roi_pooling.hpp"
|
||||
#include "intel_gpu/primitives/roll.hpp"
|
||||
#include "intel_gpu/primitives/scale.hpp"
|
||||
#include "intel_gpu/primitives/scatter_update.hpp"
|
||||
#include "intel_gpu/primitives/scatter_elements_update.hpp"
|
||||
#include "intel_gpu/primitives/scatter_nd_update.hpp"
|
||||
#include "intel_gpu/primitives/scatter_update.hpp"
|
||||
#include "intel_gpu/primitives/select.hpp"
|
||||
#include "intel_gpu/primitives/shape_of.hpp"
|
||||
#include "intel_gpu/primitives/shuffle_channels.hpp"
|
||||
@ -65,15 +73,6 @@
|
||||
#include "intel_gpu/primitives/space_to_batch.hpp"
|
||||
#include "intel_gpu/primitives/strided_slice.hpp"
|
||||
#include "intel_gpu/primitives/tile.hpp"
|
||||
#include "intel_gpu/primitives/resample.hpp"
|
||||
#include "intel_gpu/primitives/gather_tree.hpp"
|
||||
#include "intel_gpu/primitives/lstm_dynamic_input.hpp"
|
||||
#include "intel_gpu/primitives/lstm_dynamic_timeloop.hpp"
|
||||
#include "intel_gpu/primitives/grn.hpp"
|
||||
#include "intel_gpu/primitives/ctc_greedy_decoder.hpp"
|
||||
#include "intel_gpu/primitives/convert_color.hpp"
|
||||
#include "generic_layer.hpp"
|
||||
|
||||
|
||||
namespace cldnn {
|
||||
namespace ocl {
|
||||
@ -106,6 +105,7 @@ REGISTER_OCL(deformable_interp);
|
||||
REGISTER_OCL(depth_to_space);
|
||||
REGISTER_OCL(detection_output);
|
||||
REGISTER_OCL(dft);
|
||||
REGISTER_OCL(experimental_detectron_detection_output);
|
||||
REGISTER_OCL(experimental_detectron_generate_proposals_single_image);
|
||||
REGISTER_OCL(experimental_detectron_prior_grid_generator);
|
||||
REGISTER_OCL(experimental_detectron_roi_feature_extractor);
|
||||
|
@ -0,0 +1,66 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#pragma once
|
||||
#include "intel_gpu/primitives/experimental_detectron_detection_output.hpp"
|
||||
#include "primitive_inst.h"
|
||||
|
||||
namespace cldnn {
|
||||
template <>
|
||||
struct typed_program_node<experimental_detectron_detection_output>
|
||||
: public typed_program_node_base<experimental_detectron_detection_output> {
|
||||
using parent = typed_program_node_base<experimental_detectron_detection_output>;
|
||||
|
||||
public:
|
||||
using parent::parent;
|
||||
|
||||
program_node& input() const {
|
||||
return get_dependency(0);
|
||||
}
|
||||
|
||||
program_node& deltas() const {
|
||||
return get_dependency(1);
|
||||
}
|
||||
program_node& scores() const {
|
||||
return get_dependency(2);
|
||||
}
|
||||
program_node& image_size_info() const {
|
||||
return get_dependency(3);
|
||||
}
|
||||
|
||||
program_node& output_classes_node() const {
|
||||
return get_dependency(4);
|
||||
}
|
||||
program_node& output_scores_node() const {
|
||||
return get_dependency(5);
|
||||
}
|
||||
};
|
||||
|
||||
using experimental_detectron_detection_output_node = typed_program_node<experimental_detectron_detection_output>;
|
||||
|
||||
template <>
|
||||
class typed_primitive_inst<experimental_detectron_detection_output>
|
||||
: public typed_primitive_inst_base<experimental_detectron_detection_output> {
|
||||
using parent = typed_primitive_inst_base<experimental_detectron_detection_output>;
|
||||
|
||||
public:
|
||||
static layout calc_output_layout(const experimental_detectron_detection_output_node& node);
|
||||
static std::string to_string(const experimental_detectron_detection_output_node& node);
|
||||
|
||||
typed_primitive_inst(network& network, const experimental_detectron_detection_output_node& node)
|
||||
: parent(network, node) {}
|
||||
|
||||
memory::ptr output_classes_memory() const {
|
||||
return dep_memory_ptr(4);
|
||||
}
|
||||
memory::ptr output_scores_memory() const {
|
||||
return dep_memory_ptr(5);
|
||||
}
|
||||
};
|
||||
|
||||
using experimental_detectron_detection_output_inst = typed_primitive_inst<experimental_detectron_detection_output>;
|
||||
|
||||
} // namespace cldnn
|
@ -79,6 +79,7 @@ enum class KernelType {
|
||||
LOOP,
|
||||
NON_MAX_SUPPRESSION,
|
||||
DETECTION_OUTPUT,
|
||||
EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT,
|
||||
EXPERIMENTAL_DETECTRON_GENERATE_PROPOSALS_SINGLE_IMAGE,
|
||||
EXPERIMENTAL_DETECTRON_PRIOR_GRID_GENERATOR,
|
||||
EXPERIMENTAL_DETECTRON_ROI_FEATURE_EXTRACTOR,
|
||||
|
@ -0,0 +1,199 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "experimental_detectron_detection_output_kernel_ref.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "kernel_selector_utils.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
// Builds the key of supported configurations: F16/F32/INT32/INT64 on both
// inputs and outputs, bfyx layout, batching, and differing data types.
ParamsKey ExperimentalDetectronDetectionOutputKernelRef::GetSupportedKey() const {
    ParamsKey key;

    for (const auto data_type : {Datatype::F16, Datatype::F32, Datatype::INT32, Datatype::INT64}) {
        key.EnableInputDataType(data_type);
        key.EnableOutputDataType(data_type);
    }

    key.EnableInputLayout(DataLayout::bfyx);
    key.EnableOutputLayout(DataLayout::bfyx);

    key.EnableBatching();
    key.EnableDifferentTypes();

    return key;
}
|
||||
|
||||
// Always reports DONT_USE_IF_HAVE_SOMETHING_ELSE; both arguments are ignored.
KernelsPriority ExperimentalDetectronDetectionOutputKernelRef::GetKernelsPriority(
    const Params& /*params*/,
    const optional_params& /*options*/) const {
    return DONT_USE_IF_HAVE_SOMETHING_ELSE;
}
|
||||
|
||||
bool ExperimentalDetectronDetectionOutputKernelRef::Validate(const Params& p, const optional_params& o) const {
|
||||
if (p.GetType() != KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT ||
|
||||
o.GetType() != KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Indices into the params.inputs list.
constexpr int kBoxesInputIdx = 0;
constexpr int kDeltasInputIdx = 1;
constexpr int kScoresInputIdx = 2;
constexpr int kImInfoInputIdx = 3;
// The classes/scores buffers are passed to the kernels as inputs 4 and 5.
constexpr int kOutputClassesInputIdx = 4;
constexpr int kOutputScoresInputIdx = 5;

// Indices into the internal buffers list.
constexpr int kRefinedBoxesBufferIdx = 0;
constexpr int kRefinedBoxAreasBufferIdx = 1;
constexpr int kRefinedScoresBufferIdx = 2;
constexpr int kScoreClassIndexBufferIdx = 3;
constexpr int kDetectionCountBufferIdx = 4;

// Number of internal buffers (one past the last buffer index above).
constexpr int kBufferCount = 5;

// Index of the primitive's output.
constexpr int kOutputIdx = 0;
|
||||
|
||||
// Builds the JIT constants shared by all stages: the base constants plus the
// primitive attributes, the ROI count and the type alias for the indices output.
// NOTE(review): params.class_agnostic_box_regression is not forwarded here - confirm that is intended.
JitConstants ExperimentalDetectronDetectionOutputKernelRef::GetJitConstants(
    const experimental_detectron_detection_output_params& params) const {
    JitConstants constants = MakeBaseParamsJitConstants(params);

    // Caller must guarantee at least four weights: x, y, log(w), log(h).
    const auto& weights = params.deltas_weights;

    constants.AddConstants({
        MakeJitConstant("SCORE_THRESHOLD", params.score_threshold),
        MakeJitConstant("NMS_THRESHOLD", params.nms_threshold),
        MakeJitConstant("NUM_CLASSES", params.num_classes),
        MakeJitConstant("POST_NMS_COUNT", params.post_nms_count),
        MakeJitConstant("MAX_DETECTIONS_PER_IMAGE", params.max_detections_per_image),
        MakeJitConstant("MAX_DELTA_LOG_WH", params.max_delta_log_wh),
        MakeJitConstant("DELTA_WEIGHT_X", weights[0]),
        MakeJitConstant("DELTA_WEIGHT_Y", weights[1]),
        MakeJitConstant("DELTA_WEIGHT_LOG_W", weights[2]),
        MakeJitConstant("DELTA_WEIGHT_LOG_H", weights[3]),

        // The ROI count is the batch dimension of the scores input.
        MakeJitConstant("ROI_COUNT", params.inputs[kScoresInputIdx].Batch().v),

        MakeJitConstant("OUTPUT_INDICES_TYPE", "INPUT4_TYPE"),
    });

    return constants;
}
|
||||
|
||||
using DispatchData = CommonDispatchData;
|
||||
|
||||
void ExperimentalDetectronDetectionOutputKernelRef::PrepareKernelCommon(
|
||||
const experimental_detectron_detection_output_params& params,
|
||||
const optional_params& options,
|
||||
std::vector<size_t> gws,
|
||||
const std::string& stage_name,
|
||||
size_t stage_index,
|
||||
clKernelData& kernel) const {
|
||||
DispatchData dispatch_data;
|
||||
dispatch_data.gws = std::move(gws);
|
||||
dispatch_data.lws = GetOptimalLocalWorkGroupSizes(dispatch_data.gws, params.engineInfo);
|
||||
|
||||
const auto entry_point = GetEntryPoint(kernelName, params.layerID, params, options, stage_index);
|
||||
auto cldnn_jit = GetJitConstants(params);
|
||||
cldnn_jit.AddConstant(MakeJitConstant(stage_name, "true"));
|
||||
|
||||
const auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||
KernelBase::CheckDispatchData(kernelName, dispatch_data, params.engineInfo.maxWorkGroupSize);
|
||||
kernel.params.workGroups.global = dispatch_data.gws;
|
||||
kernel.params.workGroups.local = dispatch_data.lws;
|
||||
kernel.code.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo);
|
||||
}
|
||||
|
||||
void ExperimentalDetectronDetectionOutputKernelRef::PrepareRefineBoxesKernel(
|
||||
const experimental_detectron_detection_output_params& params,
|
||||
const optional_params& options,
|
||||
clKernelData& kernel) const {
|
||||
const size_t roi_count = params.inputs[kScoresInputIdx].Batch().v;
|
||||
const size_t class_count = params.num_classes;
|
||||
|
||||
PrepareKernelCommon(params, options, {roi_count, class_count, 1}, "EDDO_STAGE_0_REFINE_BOXES", 0, kernel);
|
||||
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kBoxesInputIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kDeltasInputIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kScoresInputIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kImInfoInputIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxAreasBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedScoresBufferIdx});
|
||||
}
|
||||
|
||||
void ExperimentalDetectronDetectionOutputKernelRef::PrepareNmsClassWiseKernel(
|
||||
const experimental_detectron_detection_output_params& params,
|
||||
const optional_params& options,
|
||||
clKernelData& kernel) const {
|
||||
PrepareKernelCommon(params, options, {1, 1, 1}, "EDDO_STAGE_1_NMS", 1, kernel);
|
||||
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedScoresBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxAreasBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx});
|
||||
}
|
||||
|
||||
void ExperimentalDetectronDetectionOutputKernelRef::PrepareTopKDetectionsKernel(
|
||||
const experimental_detectron_detection_output_params& params,
|
||||
const optional_params& options,
|
||||
clKernelData& kernel) const {
|
||||
PrepareKernelCommon(params, options, {1, 1, 1}, "EDDO_STAGE_2_TOPK", 2, kernel);
|
||||
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx});
|
||||
}
|
||||
|
||||
void ExperimentalDetectronDetectionOutputKernelRef::PrepareCopyOutputKernel(
|
||||
const experimental_detectron_detection_output_params& params,
|
||||
const optional_params& options,
|
||||
clKernelData& kernel) const {
|
||||
PrepareKernelCommon(params,
|
||||
options,
|
||||
{static_cast<size_t>(params.max_detections_per_image), 1, 1},
|
||||
"EDDO_STAGE_3_COPY_OUTPUT",
|
||||
3,
|
||||
kernel);
|
||||
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kScoreClassIndexBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kDetectionCountBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, kRefinedBoxesBufferIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, kOutputIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kOutputClassesInputIdx});
|
||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, kOutputScoresInputIdx});
|
||||
}
|
||||
|
||||
// Produces one KernelData holding the four stages and the sizes of the five
// internal buffers. Returns an empty list when Validate() rejects the arguments.
KernelsData ExperimentalDetectronDetectionOutputKernelRef::GetKernelsData(const Params& params,
                                                                          const optional_params& options) const {
    if (!Validate(params, options)) {
        return {};
    }

    constexpr size_t kKernelCount = 4;
    // Bytes per entry of the score/class/index buffer (sizeof ScoreClassIndex).
    constexpr size_t kScoreClassIndexSize = 12;

    KernelData kernel_data = KernelData::Default<experimental_detectron_detection_output_params>(params, kKernelCount);
    const auto& eddo_params = static_cast<const experimental_detectron_detection_output_params&>(params);

    const auto roi_count = eddo_params.inputs[kScoresInputIdx].Batch().v;
    const auto class_count = static_cast<size_t>(eddo_params.num_classes);
    // Each buffer below has one slot per (class, ROI) pair, except the detection counter.
    const auto candidate_count = class_count * roi_count;

    kernel_data.internalBufferDataType = Datatype::F32;

    auto& buffer_sizes = kernel_data.internalBufferSizes;
    buffer_sizes.resize(kBufferCount);
    buffer_sizes[kRefinedBoxesBufferIdx] = candidate_count * 4 * sizeof(float);
    buffer_sizes[kRefinedBoxAreasBufferIdx] = candidate_count * sizeof(float);
    buffer_sizes[kRefinedScoresBufferIdx] = candidate_count * sizeof(float);
    buffer_sizes[kScoreClassIndexBufferIdx] = candidate_count * kScoreClassIndexSize;
    buffer_sizes[kDetectionCountBufferIdx] = sizeof(uint32_t);

    // Stage order matches the stage indices passed to PrepareKernelCommon.
    PrepareRefineBoxesKernel(eddo_params, options, kernel_data.kernels[0]);
    PrepareNmsClassWiseKernel(eddo_params, options, kernel_data.kernels[1]);
    PrepareTopKDetectionsKernel(eddo_params, options, kernel_data.kernels[2]);
    PrepareCopyOutputKernel(eddo_params, options, kernel_data.kernels[3]);

    return {kernel_data};
}
|
||||
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernel_base_opencl.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
struct experimental_detectron_detection_output_params : public base_params {
|
||||
experimental_detectron_detection_output_params()
|
||||
: base_params(KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {}
|
||||
|
||||
float score_threshold;
|
||||
float nms_threshold;
|
||||
float max_delta_log_wh;
|
||||
int num_classes;
|
||||
int post_nms_count;
|
||||
int max_detections_per_image;
|
||||
bool class_agnostic_box_regression;
|
||||
std::vector<float> deltas_weights;
|
||||
};
|
||||
|
||||
struct experimental_detectron_detection_output_optional_params : public optional_params {
|
||||
experimental_detectron_detection_output_optional_params()
|
||||
: optional_params(KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT) {}
|
||||
};
|
||||
|
||||
class ExperimentalDetectronDetectionOutputKernelRef : public KernelBaseOpenCL {
|
||||
public:
|
||||
ExperimentalDetectronDetectionOutputKernelRef() : KernelBaseOpenCL("experimental_detectron_detection_output_ref") {}
|
||||
|
||||
~ExperimentalDetectronDetectionOutputKernelRef() = default;
|
||||
|
||||
protected:
|
||||
bool Validate(const Params& p, const optional_params& o) const override;
|
||||
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
|
||||
private:
|
||||
JitConstants GetJitConstants(const experimental_detectron_detection_output_params& params) const;
|
||||
void PrepareKernelCommon(const experimental_detectron_detection_output_params& params,
|
||||
const optional_params& options,
|
||||
std::vector<size_t> gws,
|
||||
const std::string& stage_name,
|
||||
size_t stage_index,
|
||||
clKernelData& kernel) const;
|
||||
void PrepareRefineBoxesKernel(const experimental_detectron_detection_output_params&,
|
||||
const optional_params&,
|
||||
clKernelData&) const;
|
||||
void PrepareNmsClassWiseKernel(const experimental_detectron_detection_output_params&,
|
||||
const optional_params&,
|
||||
clKernelData&) const;
|
||||
void PrepareTopKDetectionsKernel(const experimental_detectron_detection_output_params&,
|
||||
const optional_params&,
|
||||
clKernelData&) const;
|
||||
void PrepareCopyOutputKernel(const experimental_detectron_detection_output_params&,
|
||||
const optional_params&,
|
||||
clKernelData&) const;
|
||||
};
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "experimental_detectron_detection_output_kernel_selector.h"
|
||||
|
||||
#include "experimental_detectron_detection_output_kernel_ref.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
experimental_detectron_detection_output_kernel_selector::experimental_detectron_detection_output_kernel_selector() {
|
||||
Attach<ExperimentalDetectronDetectionOutputKernelRef>();
|
||||
}
|
||||
|
||||
experimental_detectron_detection_output_kernel_selector&
|
||||
experimental_detectron_detection_output_kernel_selector::Instance() {
|
||||
static experimental_detectron_detection_output_kernel_selector instance_;
|
||||
return instance_;
|
||||
}
|
||||
|
||||
KernelsData experimental_detectron_detection_output_kernel_selector::GetBestKernels(
|
||||
const Params& params,
|
||||
const optional_params& options) const {
|
||||
return GetNaiveBestKernel(params, options, KernelType::EXPERIMENTAL_DETECTRON_DETECTION_OUTPUT);
|
||||
}
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,19 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernel_selector.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
class experimental_detectron_detection_output_kernel_selector : public kernel_selector_base {
|
||||
public:
|
||||
static experimental_detectron_detection_output_kernel_selector& Instance();
|
||||
|
||||
experimental_detectron_detection_output_kernel_selector();
|
||||
|
||||
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
|
||||
};
|
||||
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,319 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "include/batch_headers/data_types.cl"
|
||||
|
||||
// Scalar, 2-wide and 4-wide aliases of the first input's element type.
#define INPUT_TYPE  INPUT0_TYPE
#define INPUT_TYPE2 MAKE_VECTOR_TYPE(INPUT0_TYPE, 2)
#define INPUT_TYPE4 MAKE_VECTOR_TYPE(INPUT0_TYPE, 4)

// Literals are spelled with the suffix that matches the input precision.
#if INPUT0_TYPE_SIZE == 2  // f16
#    define HALF_ONE (INPUT_TYPE2)(0.5h)
#    define ZERO     0.0h
#    define ONE      1.0h
#else
#    define HALF_ONE (INPUT_TYPE2)(0.5f)
#    define ZERO     0.0f
#    define ONE      1.0f
#endif

// Vector-valued zeros.
#define ZERO2 (INPUT_TYPE2)(ZERO)
#define ZERO4 (INPUT_TYPE4)(ZERO)

// Added to (max - min) when a box extent is computed.
#define COORDINATE_OFFSET (INPUT_TYPE2)(ONE)

// Per-component divisors applied to the raw deltas (dx, dy, d(log w), d(log h)).
#define DELTA_WEIGHTS (INPUT_TYPE4)(DELTA_WEIGHT_X, DELTA_WEIGHT_Y, DELTA_WEIGHT_LOG_W, DELTA_WEIGHT_LOG_H)

// Upper bound applied to the log-size deltas before exponentiation.
#define MAX_DELTA_LOG_SIZE (INPUT_TYPE2)(TO_INPUT0_TYPE(MAX_DELTA_LOG_WH))
|
||||
|
||||
// One detection candidate: its score plus the class and box it belongs to.
// The struct is packed and the score is aligned to 4 bytes, so the record has
// the same layout for both input precisions.
typedef struct __attribute__((packed)) {
    INPUT_TYPE score __attribute__((aligned(4)));
    uint class_idx;
    uint box_idx;
} FUNC(SCI);

// Readable alias for the kernel-specific mangled type name.
#define ScoreClassIndex FUNC(SCI)
|
||||
|
||||
// Exchanges the two candidate records pointed to by `a` and `b`.
inline void FUNC(swap_info)(__global ScoreClassIndex* a, __global ScoreClassIndex* b) {
    const ScoreClassIndex saved_b = *b;
    *b = *a;
    *a = saved_b;
}
|
||||
|
||||
// Partitions arr[l..h] around the pivot arr[h] and returns the pivot's final
// index. Records that must precede the pivot (higher score, or equal score
// with a smaller box index) end up to its left.
inline int FUNC(partition)(__global ScoreClassIndex* arr, int l, int h) {
    const INPUT_TYPE pivot_score = arr[h].score;
    const size_t pivot_box_idx = arr[h].box_idx;

    // First slot that is not yet known to hold a record preceding the pivot.
    int boundary = l;
    for (int cur = l; cur < h; ++cur) {
        const bool higher_score = arr[cur].score > pivot_score;
        const bool tie_lower_box = arr[cur].score == pivot_score && arr[cur].box_idx < pivot_box_idx;
        if (higher_score || tie_lower_box) {
            FUNC_CALL(swap_info)(&arr[boundary], &arr[cur]);
            ++boundary;
        }
    }

    // Drop the pivot right after the last record that precedes it.
    FUNC_CALL(swap_info)(&arr[boundary], &arr[h]);
    return boundary;
}
|
||||
|
||||
// Sorts arr[l..h] (both bounds inclusive) in place with bubble sort.
//
// The resulting order is the same one FUNC(partition) establishes: descending
// by score, and for equal scores ascending by box index. Two neighbours are
// therefore swapped when the right one has to come first, i.e. when it has a
// higher score, or the same score and a smaller box index.
inline void FUNC(bubbleSortIterative)(__global ScoreClassIndex* arr, int l, int h) {
    for (int i = 0; i < h - l; i++) {
        bool swapped = false;
        // After pass i the last i + 1 records of the range are in their final place.
        for (int j = l; j < h - i; j++) {
            if ((arr[j].score < arr[j + 1].score) ||
                (arr[j].score == arr[j + 1].score && arr[j].box_idx > arr[j + 1].box_idx)) {
                FUNC_CALL(swap_info)(&arr[j], &arr[j + 1]);
                swapped = true;
            }
        }

        // A pass without swaps means the range is already ordered.
        if (!swapped)
            break;
    }
}
|
||||
|
||||
// Sorts arr[l..h] (both bounds inclusive) in place with a non-recursive
// quicksort built on FUNC(partition).
//
// Pending sub-ranges are kept on a fixed-size explicit stack. When the stack
// is full, the sub-range is sorted directly with FUNC(bubbleSortIterative)
// instead of being pushed.
inline void FUNC(quickSortIterative)(__global ScoreClassIndex* arr, int l, int h) {
    // Ranges holding fewer than two records need no work. This also covers
    // callers that pass (0, count - 1) with count == 0: partitioning such a
    // range would read and swap arr[h], which lies outside the range.
    if (l >= h) {
        return;
    }

    // Create an auxiliary stack
    const int kStackSize = 100;
    int stack[kStackSize];

    // initialize top of stack
    int top = -1;

    // push initial values of l and h to stack
    stack[++top] = l;
    stack[++top] = h;

    // Bounds are pushed and popped in pairs, so `top` is always -1 or odd and
    // the `top >= kStackSize - 1` checks below leave room for one more pair.

    // Keep popping from the stack while it is not empty
    while (top >= 0) {
        // Pop h and l
        h = stack[top--];
        l = stack[top--];

        // Set pivot element at its correct position
        // in sorted array
        int p = FUNC_CALL(partition)(arr, l, h);

        // If there are elements on left side of pivot,
        // then push left side to stack
        if (p - 1 > l) {
            if (top >= (kStackSize - 1)) {
                FUNC_CALL(bubbleSortIterative)(arr, l, p - 1);
            } else {
                stack[++top] = l;
                stack[++top] = p - 1;
            }
        }

        // If there are elements on right side of pivot,
        // then push right side to stack
        if (p + 1 < h) {
            if (top >= (kStackSize - 1)) {
                FUNC_CALL(bubbleSortIterative)(arr, p + 1, h);
            } else {
                stack[++top] = p + 1;
                stack[++top] = h;
            }
        }
    }
}
|
||||
|
||||
// FIXME: rename stages accordingly
|
||||
#ifdef EDDO_STAGE_0_REFINE_BOXES
|
||||
|
||||
// 0. Refine boxes
|
||||
// Stage 0: refine one (ROI, class) pair per work item.
//
// Dimension 0 of the NDRange indexes ROIs and dimension 1 indexes classes.
// Results are written class-major: element (class, roi) lives at
// roi_count * class + roi in each of the three output buffers.
KERNEL(eddo_ref_stage_0)
(const __global INPUT_TYPE* boxes,
 const __global INPUT_TYPE* deltas,
 const __global INPUT_TYPE* scores,
 const __global INPUT_TYPE* im_info,
 __global INPUT_TYPE* refined_boxes,
 __global INPUT_TYPE* refined_box_areas,
 __global INPUT_TYPE* refined_scores) {
    const size_t roi_count = get_global_size(0);
    const size_t roi_idx = get_global_id(0);
    const size_t class_idx = get_global_id(1);

    const int refined_offset = roi_count * class_idx + roi_idx;

    const INPUT_TYPE4 box = vload4(roi_idx, boxes);

    // A box with non-positive width or height only gets a zero score;
    // its refined box and area are left untouched.
    if (any(islessequal(box.hi - box.lo, ZERO2))) {
        refined_scores[refined_offset] = ZERO;
        return;
    }

    // Inputs are ROI-major: element (roi, class) lives at NUM_CLASSES * roi + class.
    const int offset = NUM_CLASSES * roi_idx + class_idx;

    // width & height of box
    const INPUT_TYPE2 box_size = (box.hi - box.lo + COORDINATE_OFFSET);

    // center location of box
    const INPUT_TYPE2 center = box.lo + HALF_ONE * box_size;

    const INPUT_TYPE4 delta = vload4(offset, deltas) / DELTA_WEIGHTS;

    // new center location according to deltas (dx, dy)
    const INPUT_TYPE2 new_center = delta.lo * box_size + center;
    // new width & height according to deltas d(log w), d(log h)
    const INPUT_TYPE2 new_size = exp(min(delta.hi, MAX_DELTA_LOG_SIZE)) * box_size;

    // update upper-left corner and lower-right corners respectively
    INPUT_TYPE4 new_box =
        (INPUT_TYPE4)(new_center - HALF_ONE * new_size, new_center + HALF_ONE * new_size - COORDINATE_OFFSET);

    // adjust new corner locations to be within the image region;
    // the first two im_info values are loaded and swapped to form the upper bound
    const INPUT_TYPE2 img_size = vload2(0, im_info).s10;
    new_box = clamp(new_box, ZERO4, img_size.xyxy);

    // recompute new width & height
    const INPUT_TYPE2 new_box_size = new_box.hi - new_box.lo + COORDINATE_OFFSET;

    vstore4(new_box, refined_offset, refined_boxes);
    refined_box_areas[refined_offset] = new_box_size.x * new_box_size.y;
    refined_scores[refined_offset] = scores[offset];
}
|
||||
|
||||
#endif /* EDDO_STAGE_0_REFINE_BOXES */
|
||||
|
||||
#ifdef EDDO_STAGE_1_NMS
|
||||
|
||||
// Returns intersection-over-union of boxes idx1 and idx2.
// Box areas are taken from refined_box_areas rather than recomputed.
// The result is ZERO when the boxes do not overlap.
inline INPUT_TYPE FUNC(jaccard_overlap)(const __global INPUT_TYPE* refined_boxes,
                                        const __global INPUT_TYPE* refined_box_areas,
                                        size_t idx1,
                                        size_t idx2) {
    const INPUT_TYPE4 first = vload4(idx1, refined_boxes);
    const INPUT_TYPE4 second = vload4(idx2, refined_boxes);

    // Disjoint if either box starts past the end of the other on any axis.
    if (any(isgreater((INPUT_TYPE4)(first.lo, second.lo), (INPUT_TYPE4)(second.hi, first.hi)))) {
        return ZERO;
    }

    const INPUT_TYPE2 intersect_min = max(first.lo, second.lo);
    const INPUT_TYPE2 intersect_max = min(first.hi, second.hi);
    const INPUT_TYPE2 intersect_size = intersect_max - intersect_min + COORDINATE_OFFSET;

    if (any(islessequal(intersect_size, ZERO2))) {
        return ZERO;
    }

    const INPUT_TYPE intersect_area = intersect_size.x * intersect_size.y;
    const INPUT_TYPE first_area = refined_box_areas[idx1];
    const INPUT_TYPE second_area = refined_box_areas[idx2];

    return intersect_area / (first_area + second_area - intersect_area);
}
|
||||
|
||||
// Class-wise non-maximum suppression over one class slice.
//
// refined_scores / refined_boxes / refined_box_areas point at the slice of the
// given class and hold roi_count entries each. Kept candidates are written to
// the front of score_class_index_map. *detection_count receives the number of
// kept candidates, capped at POST_NMS_COUNT.
inline void FUNC(nms_cf)(const __global INPUT_TYPE* refined_scores,
                         const __global INPUT_TYPE* refined_boxes,
                         const __global INPUT_TYPE* refined_box_areas,
                         size_t class_idx,
                         size_t roi_count,
                         __global ScoreClassIndex* score_class_index_map,
                         __global uint* detection_count) {
    // 1. Collect every box whose score is above the threshold.
    size_t count = 0;
    for (size_t i = 0; i < roi_count; ++i) {
        if (refined_scores[i] > SCORE_THRESHOLD) {
            score_class_index_map[count] = (ScoreClassIndex){refined_scores[i], class_idx, i};
            count++;
        }
    }

    // 2. Order the candidates. With fewer than two candidates there is nothing
    //    to sort; skipping the call also avoids evaluating `count - 1` for
    //    count == 0, which would make the sort access memory in front of
    //    score_class_index_map.
    if (count > 1) {
        FUNC_CALL(quickSortIterative)(score_class_index_map, 0, (int)count - 1);
    }

    // 3. Greedy suppression: a candidate is kept unless it overlaps an already
    //    kept one by more than NMS_THRESHOLD. Kept records are compacted to the
    //    front of the map; detections <= i always holds, so no unread record
    //    is overwritten.
    int detections = 0;
    for (size_t i = 0; i < count; ++i) {
        const size_t idx = score_class_index_map[i].box_idx;

        bool keep = true;
        for (int k = 0; k < detections; ++k) {
            const size_t kept_idx = score_class_index_map[k].box_idx;
            INPUT_TYPE overlap = FUNC_CALL(jaccard_overlap)(refined_boxes, refined_box_areas, idx, kept_idx);
            if (overlap > NMS_THRESHOLD) {
                keep = false;
                break;
            }
        }
        if (keep) {
            score_class_index_map[detections] = score_class_index_map[i];
            detections++;
        }
    }

    *detection_count = min(POST_NMS_COUNT, detections);
}
|
||||
|
||||
// Stage 1: run class-wise NMS for every class in turn.
//
// Each class appends its kept candidates right after those of the previous
// classes in score_class_index_map. On exit *detection_count holds the total
// number of kept candidates over all classes.
KERNEL(eddo_ref_stage_1)
(const __global INPUT_TYPE* refined_scores,
 const __global INPUT_TYPE* refined_boxes,
 const __global INPUT_TYPE* refined_box_areas,
 __global ScoreClassIndex* score_class_index_map,
 __global uint* detection_count) {
    size_t kept_so_far = 0;
    // FIXME: figure out how to parallelize this!!!
    for (int class_idx = 0; class_idx < NUM_CLASSES; ++class_idx) {
        // detection_count doubles as the per-class result slot of nms_cf.
        FUNC_CALL(nms_cf)
        (refined_scores + ROI_COUNT * class_idx,
         refined_boxes + ROI_COUNT * 4 * class_idx,
         refined_box_areas + ROI_COUNT * class_idx,
         class_idx,
         ROI_COUNT,
         score_class_index_map + kept_so_far,
         detection_count);
        kept_so_far += *detection_count;
    }

    *detection_count = kept_so_far;
}
|
||||
|
||||
#endif /* EDDO_STAGE_1_NMS */
|
||||
|
||||
#ifdef EDDO_STAGE_2_TOPK
|
||||
|
||||
// Stage 2: when more candidates were kept than MAX_DETECTIONS_PER_IMAGE,
// re-sort the whole candidate list; otherwise leave it untouched.
KERNEL(eddo_ref_stage_2)
(__global ScoreClassIndex* score_class_index_map, const __global uint* detection_count) {
    if (*detection_count <= MAX_DETECTIONS_PER_IMAGE) {
        return;
    }

    FUNC_CALL(quickSortIterative)(score_class_index_map, 0, *detection_count - 1);
}
|
||||
|
||||
#endif /* EDDO_STAGE_2_TOPK */
|
||||
|
||||
#ifdef EDDO_STAGE_3_COPY_OUTPUT
|
||||
|
||||
// Stage 3: one work item per output slot. Slots below *detection_count receive
// the box, class and score of the corresponding candidate; the remaining slots
// are zero-filled.
KERNEL(eddo_ref_stage_3)
(const __global ScoreClassIndex* score_class_index_map,
 const __global uint* detection_count,
 const __global INPUT_TYPE* refined_boxes,
 __global OUTPUT_TYPE* output_boxes,
 __global OUTPUT_INDICES_TYPE* output_classes,
 __global OUTPUT_TYPE* output_scores) {
    const size_t slot = get_global_id(0);

    if (slot >= *detection_count) {
        vstore4(ZERO4, slot, output_boxes);
        output_scores[slot] = ZERO;
        output_classes[slot] = 0;
        return;
    }

    const OUTPUT_TYPE score = score_class_index_map[slot].score;
    const OUTPUT_INDICES_TYPE cls = score_class_index_map[slot].class_idx;
    const OUTPUT_INDICES_TYPE idx = score_class_index_map[slot].box_idx;

    // Refined boxes are stored class-major, ROI_COUNT boxes per class.
    vstore4(vload4(ROI_COUNT * cls + idx, refined_boxes), slot, output_boxes);
    output_scores[slot] = score;
    output_classes[slot] = cls;
}
|
||||
|
||||
#endif /* EDDO_STAGE_3_COPY_OUTPUT */
|
||||
|
||||
// Drop every helper macro defined by this file so that none of them leaks
// into whatever source follows it in the same compilation unit.
#undef INPUT_TYPE
#undef INPUT_TYPE2
#undef INPUT_TYPE4
#undef HALF_ONE
#undef ZERO
#undef ONE
#undef ZERO2
#undef ZERO4
#undef COORDINATE_OFFSET
#undef DELTA_WEIGHTS
#undef MAX_DELTA_LOG_SIZE
#undef ScoreClassIndex
|
@ -0,0 +1,101 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/experimental_detectron_detection_output.hpp"
|
||||
|
||||
#include "intel_gpu/plugin/common_utils.hpp"
|
||||
#include "intel_gpu/plugin/program.hpp"
|
||||
#include "intel_gpu/primitives/experimental_detectron_detection_output.hpp"
|
||||
#include "intel_gpu/primitives/mutable_data.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace runtime {
|
||||
namespace intel_gpu {
|
||||
|
||||
// Translates an ngraph ExperimentalDetectronDetectionOutput node into GPU
// plugin primitives.
//
// The main primitive is registered as "<layer>.0". Outputs 1 and 2 are wired
// through shared memory: each gets a writable mutable_data primitive that is
// appended to the main primitive's inputs, and a second mutable_data primitive
// ("<layer>.1" / "<layer>.2") over the same memory that depends on the main
// primitive.
static void CreateExperimentalDetectronDetectionOutputOp(
    Program& p,
    const std::shared_ptr<ngraph::op::v6::ExperimentalDetectronDetectionOutput>& op) {
    p.ValidateInputs(op, {4});

    if (op->get_output_size() != 3) {
        IE_THROW() << "ExperimentalDetectronDetectionOutput requires 3 outputs";
    }

    auto inputs = p.GetInputPrimitiveIDs(op);
    const auto& attrs = op->get_attrs();
    const auto op_friendly_name = op->get_friendly_name();
    const auto layer_type_name = layer_type_name_ID(op);
    const auto layer_name = layer_type_name + ".0";

    // Allocates memory shaped like output `output_idx`, registers a writable
    // mutable_data primitive over it and appends that primitive to `inputs`.
    const auto add_write_buffer = [&](size_t output_idx, const char* id_suffix) {
        const auto precision = op->get_output_element_type(output_idx);
        const auto shape = op->get_output_shape(output_idx);
        const cldnn::layout buffer_layout{DataTypeFromPrecision(precision),
                                          DefaultFormatForDims(shape.size()),
                                          tensor_from_dims(shape)};
        cldnn::memory::ptr buffer{p.GetEngine().allocate_memory(buffer_layout)};

        const auto write_id = layer_type_name + id_suffix;
        const cldnn::mutable_data write_prim{write_id, buffer, op_friendly_name};
        p.primitiveIDs[write_id] = write_id;
        p.AddPrimitive(write_prim);
        inputs.push_back(write_id);

        return buffer;
    };

    cldnn::memory::ptr shared_memory1 = add_write_buffer(1, "_md_write.1");
    cldnn::memory::ptr shared_memory2 = add_write_buffer(2, "_md_write.2");

    const size_t expected_prim_input_count = 4 + 2;  // 4 operation inputs plus 2 input-outputs
    if (inputs.size() != expected_prim_input_count) {
        IE_THROW() << "experimental_detectron_detection_output primitive requires 6 inputs";
    }

    const cldnn::experimental_detectron_detection_output prim{layer_name,
                                                              inputs[0],
                                                              inputs[1],
                                                              inputs[2],
                                                              inputs[3],
                                                              inputs[4],  // output classes
                                                              inputs[5],  // output scores
                                                              attrs.score_threshold,
                                                              attrs.nms_threshold,
                                                              static_cast<int>(attrs.num_classes),
                                                              static_cast<int>(attrs.post_nms_count),
                                                              static_cast<int>(attrs.max_detections_per_image),
                                                              attrs.class_agnostic_box_regression,
                                                              attrs.max_delta_log_wh,
                                                              attrs.deltas_weights,
                                                              op_friendly_name};

    p.AddPrimitive(prim);

    // Registers a mutable_data primitive over `buffer` that depends on the main primitive.
    const auto add_read_view = [&](const char* id_suffix, const cldnn::memory::ptr& buffer) {
        const auto read_id = layer_type_name + id_suffix;
        const cldnn::mutable_data read_prim{read_id, {layer_name}, buffer, op_friendly_name};
        p.primitiveIDs[read_id] = read_id;
        p.AddPrimitive(read_prim);
    };

    add_read_view(".1", shared_memory1);
    add_read_view(".2", shared_memory2);

    p.AddPrimitiveToProfiler(prim, op);
}
|
||||
|
||||
REGISTER_FACTORY_IMPL(v6, ExperimentalDetectronDetectionOutput);
|
||||
|
||||
} // namespace intel_gpu
|
||||
} // namespace runtime
|
||||
} // namespace ov
|
@ -0,0 +1,322 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <intel_gpu/primitives/experimental_detectron_detection_output.hpp>
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/mutable_data.hpp>
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts a list of float literals into a vector of the tested element type.
template <typename T>
std::vector<T> getValues(const std::vector<float>& values) {
    return std::vector<T>(values.cbegin(), values.cend());
}
|
||||
|
||||
template <typename T>
|
||||
float getError();
|
||||
|
||||
template <>
|
||||
float getError<float>() {
|
||||
return 0.001;
|
||||
}
|
||||
|
||||
template <>
|
||||
float getError<half_t>() {
|
||||
return 0.2;
|
||||
}
|
||||
|
||||
}; // namespace
|
||||
|
||||
// One parameterized test case: operation attributes, input tensors and the
// outputs they are expected to produce. The struct is initialized positionally,
// so the member order must stay as it is.
template <typename T>
struct ExperimentalDetectronDetectionOutputParams {
    // --- operation attributes ---
    float score_threshold;
    float nms_threshold;
    float max_delta_log_wh;
    int num_classes;
    int post_nms_count;
    int max_detections_per_image;
    bool class_agnostic_box_regression;
    std::vector<float> deltas_weights;

    // --- number of input ROIs ---
    size_t roi_count;

    // --- input tensors, flattened ---
    std::vector<T> boxes;
    std::vector<T> deltas;
    std::vector<T> scores;
    std::vector<T> im_info;

    // --- expected outputs, flattened ---
    std::vector<T> expected_boxes;
    std::vector<int32_t> expected_classes;
    std::vector<T> expected_scores;
};
|
||||
|
||||
template <typename T>
|
||||
struct experimental_detectron_detection_output_test
|
||||
: public ::testing::TestWithParam<ExperimentalDetectronDetectionOutputParams<T>> {
|
||||
public:
|
||||
void test() {
|
||||
const ExperimentalDetectronDetectionOutputParams<T> param =
|
||||
testing::TestWithParam<ExperimentalDetectronDetectionOutputParams<T>>::GetParam();
|
||||
auto data_type = type_to_data_type<T>::value;
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
const primitive_id input_boxes_id = "InputBoxes";
|
||||
const auto input_boxes =
|
||||
engine.allocate_memory({data_type, format::bfyx, tensor{batch(param.roi_count), feature(4)}});
|
||||
set_values(input_boxes, param.boxes);
|
||||
|
||||
const primitive_id input_deltas_id = "InputDeltas";
|
||||
auto input_deltas = engine.allocate_memory(
|
||||
{data_type, format::bfyx, tensor{batch(param.roi_count), feature(param.num_classes * 4)}});
|
||||
set_values(input_deltas, param.deltas);
|
||||
|
||||
const primitive_id input_scores_id = "InputScores";
|
||||
auto input_scores = engine.allocate_memory(
|
||||
{data_type, format::bfyx, tensor{batch(param.roi_count), feature(param.num_classes)}});
|
||||
set_values(input_scores, param.scores);
|
||||
|
||||
const primitive_id input_im_info_id = "InputImInfo";
|
||||
const auto input_im_info = engine.allocate_memory({data_type, format::bfyx, tensor{batch(1), feature(3)}});
|
||||
set_values(input_im_info, param.im_info);
|
||||
|
||||
const primitive_id output_scores_id = "OutputScores";
|
||||
auto output_scores =
|
||||
engine.allocate_memory({data_type, format::bfyx, tensor{batch(param.max_detections_per_image)}});
|
||||
|
||||
const primitive_id output_classes_id = "OutputClasses";
|
||||
auto output_classes =
|
||||
engine.allocate_memory({data_types::i32, format::bfyx, tensor{batch(param.max_detections_per_image)}});
|
||||
|
||||
topology topology;
|
||||
|
||||
topology.add(input_layout(input_boxes_id, input_boxes->get_layout()));
|
||||
topology.add(input_layout(input_deltas_id, input_deltas->get_layout()));
|
||||
topology.add(input_layout(input_scores_id, input_scores->get_layout()));
|
||||
topology.add(input_layout(input_im_info_id, input_im_info->get_layout()));
|
||||
topology.add(mutable_data(output_classes_id, output_classes));
|
||||
topology.add(mutable_data(output_scores_id, output_scores));
|
||||
|
||||
const primitive_id eddo_id = "experimental_detectron_detection_output";
|
||||
const auto eddo_primitive = experimental_detectron_detection_output{
|
||||
eddo_id,
|
||||
input_boxes_id,
|
||||
input_deltas_id,
|
||||
input_scores_id,
|
||||
input_im_info_id,
|
||||
output_classes_id,
|
||||
output_scores_id,
|
||||
param.score_threshold,
|
||||
param.nms_threshold,
|
||||
param.num_classes,
|
||||
param.post_nms_count,
|
||||
param.max_detections_per_image,
|
||||
param.class_agnostic_box_regression,
|
||||
param.max_delta_log_wh,
|
||||
param.deltas_weights,
|
||||
};
|
||||
|
||||
topology.add(eddo_primitive);
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data(input_boxes_id, input_boxes);
|
||||
network.set_input_data(input_deltas_id, input_deltas);
|
||||
network.set_input_data(input_scores_id, input_scores);
|
||||
network.set_input_data(input_im_info_id, input_im_info);
|
||||
|
||||
const auto outputs = network.execute();
|
||||
|
||||
const auto output_boxes = outputs.at(eddo_id).get_memory();
|
||||
|
||||
const cldnn::mem_lock<T> output_boxes_ptr(output_boxes, get_test_stream());
|
||||
ASSERT_EQ(output_boxes_ptr.size(), param.max_detections_per_image * 4);
|
||||
|
||||
const cldnn::mem_lock<int32_t> output_classes_ptr(output_classes, get_test_stream());
|
||||
ASSERT_EQ(output_classes_ptr.size(), param.max_detections_per_image);
|
||||
|
||||
const cldnn::mem_lock<T> output_scores_ptr(output_scores, get_test_stream());
|
||||
ASSERT_EQ(output_scores_ptr.size(), param.max_detections_per_image);
|
||||
|
||||
const auto& expected_boxes = param.expected_boxes;
|
||||
const auto& expected_classes = param.expected_classes;
|
||||
const auto& expected_scores = param.expected_scores;
|
||||
for (size_t i = 0; i < param.max_detections_per_image; ++i) {
|
||||
EXPECT_NEAR(expected_scores[i], output_scores_ptr[i], 0.001) << "i=" << i;
|
||||
for (size_t coord = 0; coord < 4; ++coord) {
|
||||
const auto roi_idx = i * 4 + coord;
|
||||
EXPECT_NEAR(expected_boxes[roi_idx], output_boxes_ptr[roi_idx], getError<T>())
|
||||
<< "i=" << i << ", coord=" << coord;
|
||||
}
|
||||
|
||||
EXPECT_EQ(expected_classes[i], output_classes_ptr[i]) << "i=" << i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using experimental_detectron_detection_output_test_f32 = experimental_detectron_detection_output_test<float>;
|
||||
using experimental_detectron_detection_output_test_f16 = experimental_detectron_detection_output_test<half_t>;
|
||||
|
||||
TEST_P(experimental_detectron_detection_output_test_f32, basic) {
|
||||
ASSERT_NO_FATAL_FAILURE(test());
|
||||
}
|
||||
|
||||
TEST_P(experimental_detectron_detection_output_test_f16, basic) {
|
||||
ASSERT_NO_FATAL_FAILURE(test());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ExperimentalDetectronDetectionOutputParams<T>> getExperimentalDetectronDetectionOutputParams() {
|
||||
std::vector<ExperimentalDetectronDetectionOutputParams<T>> params = {
|
||||
{
|
||||
0.01000000074505806f, // score_threshold
|
||||
0.2f, // nms_threshold
|
||||
2.0f, // max_delta_log_wh
|
||||
2, // num_classes
|
||||
500, // post_nms_count
|
||||
5, // max_detections_per_image
|
||||
true, // class_agnostic_box_regression
|
||||
{10.0f, 10.0f, 5.0f, 5.0f}, // deltas_weights
|
||||
16, // roi count
|
||||
|
||||
// boxes
|
||||
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f,
|
||||
1.0f, 8.0f, 5.0f, 1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
|
||||
|
||||
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
|
||||
|
||||
getValues<T>({0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}),
|
||||
|
||||
getValues<T>({16.0f, 12.0f, 1.0f}),
|
||||
getValues<T>({4.8929863f, 0.892986298f, 12.0f, 12.1070137f, 0.0f, 0.892986298f, 10.1070137f,
|
||||
12.1070137f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}),
|
||||
std::vector<int32_t>{0, 1, 0, 0, 0},
|
||||
getValues<T>({0.8f, 0.9f, 0.0f, 0.0f, 0.0f}),
|
||||
},
|
||||
{
|
||||
0.0500000007, // score_threshold
|
||||
0.5, // nms_threshold
|
||||
4.13516665, // max_delta_log_wh
|
||||
5, // num_classes
|
||||
10, // post_nms_count
|
||||
10, // max_detections_per_image
|
||||
false, // class_agnostic_box_regression
|
||||
{10.0f, 10.0f, 5.0f, 5.0f}, // deltas_weights
|
||||
10, // roi count
|
||||
|
||||
// boxes
|
||||
getValues<T>({
|
||||
4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188,
|
||||
6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281,
|
||||
0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125,
|
||||
6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719,
|
||||
}),
|
||||
|
||||
// deltas
|
||||
getValues<T>({
|
||||
4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188,
|
||||
6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281,
|
||||
0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125,
|
||||
6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719,
|
||||
3.74609, 4.54688, 5.83203, 5.91406, 2.85547, 7.46875, 4.31641, 2.71094, 9.71484, 1.14062,
|
||||
6.55078, 0.257812, 4.32422, 9.5625, 8.53516, 0.554688, 8.68359, 2.73438, 6.26953, 5.60156,
|
||||
2.79297, 8.65625, 5.75391, 5.39844, 2.65234, 7.32812, 8.98828, 7.94531, 6.26172, 4.75,
|
||||
7.97266, 1.24219, 5.62109, 8.92188, 2.70703, 1.28906, 4.73047, 7.84375, 5.19141, 6.08594,
|
||||
7.58984, 9.51562, 7.42578, 5.63281, 6.19922, 7.9375, 5.41016, 9.92969, 2.55859, 1.10938,
|
||||
1.14453, 8.97656, 4.66797, 9.03125, 4.62891, 0.773438, 4.52734, 1.70312, 9.86328, 1.32031,
|
||||
0.136719, 9.125, 2.84766, 4.61719, 9.49609, 5.29688, 5.58203, 0.664062, 2.60547, 6.21875,
|
||||
8.06641, 5.46094, 1.46484, 7.89062, 0.300781, 5.00781, 0.0742188, 0.3125, 6.28516, 3.30469,
|
||||
4.43359, 1.48438, 2.01953, 8.35156, 8.54297, 7.40625, 9.50391, 2.14844, 2.40234, 2.07812,
|
||||
2.73828, 2.69531, 4.01172, 9.5, 7.72266, 9.99219, 1.37109, 3.67188, 2.45703, 2.03906,
|
||||
0.480469, 4.59375, 2.94141, 4.83594, 1.33984, 0.265625, 1.17578, 4.38281, 5.94922, 8.6875,
|
||||
5.16016, 0.679688, 4.30859, 5.85938, 4.89453, 7.72656, 4.41797, 5.78125, 4.37891, 1.52344,
|
||||
8.27734, 4.45312, 3.61328, 4.07031, 7.88672, 9.875, 4.59766, 1.36719, 7.24609, 8.04688,
|
||||
5.33203, 5.41406, 4.35547, 0.96875, 1.81641, 8.21094, 3.21484, 4.64062, 4.05078, 9.75781,
|
||||
7.82422, 3.0625, 4.03516, 0.0546875, 8.18359, 8.23438, 1.76953, 1.10156, 2.29297, 8.15625,
|
||||
9.25391, 0.898438, 6.15234, 8.82812, 6.48828, 7.44531, 1.76172, 2.25, 9.47266, 0.742188,
|
||||
}),
|
||||
|
||||
// scores
|
||||
getValues<T>({
|
||||
4.90234, 6.57812, 5.23828, 9.19531, 8.51172, 2, 8.22266, 0.492188, 9.87109, 4.17188,
|
||||
6.95703, 8.53906, 0.980469, 9.09375, 3.44141, 5.33594, 9.83984, 6.76562, 1.67578, 6.88281,
|
||||
0.449219, 9.1875, 7.66016, 7.17969, 8.80859, 2.35938, 5.39453, 8.22656, 0.917969, 0.28125,
|
||||
6.87891, 6.02344, 6.77734, 6.95312, 6.11328, 6.57031, 0.386719, 8.375, 5.09766, 9.86719,
|
||||
3.74609, 4.54688, 5.83203, 5.91406, 2.85547, 7.46875, 4.31641, 2.71094, 9.71484, 1.14062,
|
||||
}),
|
||||
|
||||
// im_info
|
||||
getValues<T>({
|
||||
4.90234,
|
||||
6.57812,
|
||||
5.23828,
|
||||
}),
|
||||
|
||||
// out_boxes
|
||||
getValues<T>({
|
||||
0, 2.97829, 6.57812, 4.90234, 0, 4.90234, 6.57812, 4.90234, 4.37184, 4.90234,
|
||||
6.03075, 4.90234, 5.95093, 3.66966, 6.57812, 4.90234, 0, 4.90234, 6.57812, 4.90234,
|
||||
1.31075, 4.90234, 6.57812, 4.90234, 3.24829, 4.90234, 6.57812, 4.90234, 0, 0,
|
||||
6.57812, 4.90234, 4.20346, 0, 6.57812, 4.90234, 0, 0, 6.57812, 4.90234,
|
||||
}),
|
||||
|
||||
// out_classes
|
||||
std::vector<int32_t>({
|
||||
4,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
2,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
2,
|
||||
3,
|
||||
}),
|
||||
|
||||
// out_scores
|
||||
getValues<T>({
|
||||
9.86719,
|
||||
9.71484,
|
||||
9.19531,
|
||||
8.51172,
|
||||
8.375,
|
||||
7.46875,
|
||||
6.57812,
|
||||
6.57031,
|
||||
5.23828,
|
||||
5.09766,
|
||||
}),
|
||||
},
|
||||
};
|
||||
return params;
|
||||
}
|
||||
|
||||
// Instantiates the f32 flavour of the test fixture with every parameter set
// returned by getExperimentalDetectronDetectionOutputParams<float>().
INSTANTIATE_TEST_SUITE_P(
    experimental_detectron_detection_output_gpu_test,
    experimental_detectron_detection_output_test_f32,
    ::testing::ValuesIn(getExperimentalDetectronDetectionOutputParams<float>()));
|
||||
|
||||
// Instantiates the f16 flavour of the test fixture with every parameter set
// returned by getExperimentalDetectronDetectionOutputParams<half_t>().
INSTANTIATE_TEST_SUITE_P(
    experimental_detectron_detection_output_gpu_test,
    experimental_detectron_detection_output_test_f16,
    ::testing::ValuesIn(getExperimentalDetectronDetectionOutputParams<half_t>()));
|
@ -0,0 +1,66 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "single_layer_tests/experimental_detectron_detection_output.hpp"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "common_test_utils/ov_tensor_utils.hpp"
|
||||
|
||||
using namespace ov::test;
|
||||
using namespace ov::test::subgraph;
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<ov::test::ElementType> netPrecisions = {
|
||||
ov::element::Type_t::f16,
|
||||
ov::element::Type_t::f32,
|
||||
};
|
||||
|
||||
const std::vector<float> score_threshold = {0.01f, 0.8f};
|
||||
|
||||
const std::vector<float> nms_threshold = {0.2f, 0.5f};
|
||||
|
||||
// specifies maximal delta of logarithms for width and height
|
||||
const std::vector<float> max_delta_log_wh = {2.0f, 5.0f};
|
||||
|
||||
// specifies number of detected classes
|
||||
const std::vector<int64_t> num_classes = {2};
|
||||
|
||||
// specifies maximal number of detections per class
|
||||
const std::vector<int64_t> post_nms_count = {5, 25};
|
||||
|
||||
// specifies maximual number of detections per image
|
||||
const std::vector<size_t> max_detections_per_image = {5, 25};
|
||||
|
||||
// a flag specifies whether to delete background classes or not
|
||||
// `true` means background classes should be deleted,
|
||||
// `false` means background classes shouldn't be deleted.
|
||||
const std::vector<bool> class_agnostic_box_regression = {true, false};
|
||||
|
||||
// specifies deltas of weights
|
||||
const std::vector<std::vector<float>> deltas_weights = {{10.0f, 10.0f, 5.0f, 5.0f}};
|
||||
|
||||
const std::vector<std::vector<InputShape>> inputShapes = {
|
||||
// inputRois / inputDeltas / inputScores / inputImInfos
|
||||
static_shapes_to_test_representation({{16, 4}, {16, 8}, {16, 2}, {1, 3}}),
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_ExperimentalDetectronDetectionOutput,
|
||||
ExperimentalDetectronDetectionOutputLayerTest,
|
||||
::testing::Combine(::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn(score_threshold),
|
||||
::testing::ValuesIn(nms_threshold),
|
||||
::testing::ValuesIn(max_delta_log_wh),
|
||||
::testing::ValuesIn(num_classes),
|
||||
::testing::ValuesIn(post_nms_count),
|
||||
::testing::ValuesIn(max_detections_per_image),
|
||||
::testing::ValuesIn(class_agnostic_box_regression),
|
||||
::testing::ValuesIn(deltas_weights),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
ExperimentalDetectronDetectionOutputLayerTest::getTestCaseName);
|
||||
|
||||
|
||||
} // namespace
|
@ -79,6 +79,9 @@ void ExperimentalDetectronDetectionOutputLayerTest::SetUp() {
|
||||
netPrecision,
|
||||
targetName) = this->GetParam();
|
||||
|
||||
if (netPrecision == element::f16)
|
||||
abs_threshold = 0.01;
|
||||
|
||||
inType = outType = netPrecision;
|
||||
targetDevice = targetName;
|
||||
|
||||
@ -93,42 +96,66 @@ void ExperimentalDetectronDetectionOutputLayerTest::SetUp() {
|
||||
params[2], // input_scores
|
||||
params[3], // input_im_info
|
||||
attributes);
|
||||
function = std::make_shared<ov::Model>(
|
||||
ov::OutputVector{experimentalDetectron->output(0), experimentalDetectron->output(1)},
|
||||
function = std::make_shared<ov::Model>(ov::OutputVector{experimentalDetectron->output(0),
|
||||
experimentalDetectron->output(1),
|
||||
experimentalDetectron->output(2)},
|
||||
"ExperimentalDetectronDetectionOutput");
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts a sequence of floats into a vector of element type T.
// Every element of `values` is converted to T; order and length are preserved.
template <typename T>
std::vector<T> getValues(const std::vector<float>& values) {
    return std::vector<T>(values.begin(), values.end());
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<ov::Tensor> generateInputTensors() {
|
||||
const auto netPrecision = ov::element::from<T>();
|
||||
std::vector<ov::Tensor> inputTensors = {
|
||||
// 16 x 4 = 64
|
||||
ov::test::utils::create_tensor<T>(
|
||||
netPrecision,
|
||||
Shape{16, 4},
|
||||
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f,
|
||||
1.0f, 8.0f, 5.0f, 1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f})),
|
||||
// 16 x 8
|
||||
ov::test::utils::create_tensor<T>(
|
||||
netPrecision,
|
||||
Shape{16, 8},
|
||||
getValues<T>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f})),
|
||||
// 16 x 2 = 32
|
||||
ov::test::utils::create_tensor<T>(
|
||||
netPrecision,
|
||||
Shape{16, 2},
|
||||
getValues<T>({0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f})),
|
||||
// 1 x 3 = 3
|
||||
ov::test::utils::create_tensor<T>(netPrecision, Shape{1, 3}, getValues<T>({16.0f, 12.0f, 1.0f}))};
|
||||
|
||||
return inputTensors;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ExperimentalDetectronDetectionOutputLayerTest::generate_inputs(
|
||||
const std::vector<ngraph::Shape>& targetInputStaticShapes) {
|
||||
static const std::vector<ov::Tensor> inputTensors = {
|
||||
// 16 x 4 = 64
|
||||
ov::test::utils::create_tensor<float>(
|
||||
ov::element::f32,
|
||||
Shape{16, 4},
|
||||
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 8.0f, 5.0f,
|
||||
1.0f, 1.0f, 10.0f, 10.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
|
||||
// 16 x 8
|
||||
ov::test::utils::create_tensor<float>(
|
||||
ov::element::f32,
|
||||
Shape{16, 8},
|
||||
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
5.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
|
||||
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}),
|
||||
// 16 x 2 = 32
|
||||
ov::test::utils::create_tensor<float>(
|
||||
ov::element::f32,
|
||||
Shape{16, 2},
|
||||
{0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.8f, 0.9f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
|
||||
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}),
|
||||
// 1 x 3 = 3
|
||||
ov::test::utils::create_tensor<float>(ov::element::f32, Shape{1, 3}, {16.0f, 12.0f, 1.0f})};
|
||||
const auto netPrecision = std::get<9>(GetParam());
|
||||
|
||||
const std::vector<ov::Tensor> inputTensors =
|
||||
(netPrecision == element::f16) ? generateInputTensors<ov::float16>() : generateInputTensors<float>();
|
||||
|
||||
inputs.clear();
|
||||
const auto& funcInputs = function->inputs();
|
||||
|
Loading…
Reference in New Issue
Block a user