[GPU] Fix detection output stage-0 kernel (#10262)
- Change the constant value to the maximum work group size
- Add a CLK_GLOBAL_MEM_FENCE barrier to synchronize storing results in the intermediate buffer
- Add a condition to prevent out-of-range access to the local array

Signed-off-by: Andrew Kwangwoong Park <andrew.kwangwoong.park@intel.com>
This commit is contained in:
committed by
GitHub
parent
89c3a18f83
commit
51c89dff26
@@ -210,6 +210,7 @@ KernelsData DetectionOutputKernelRef::GetKernelsData(const Params& params, const
|
||||
auto num_classes = detectOutParams.detectOutParams.num_classes;
|
||||
auto num_loc_classes = (detectOutParams.detectOutParams.share_location) ? 1 : num_classes;
|
||||
auto num_prior_boxes = (loc_feature_num / (num_loc_classes * prior_box_size));
|
||||
auto max_wg = detectOutParams.engineInfo.maxWorkGroupSize;
|
||||
|
||||
constexpr size_t buffer_bytes = 10; // The size of struct Scores in detection_output_gpu_ref.cl
|
||||
size_t buffer_stride = num_prior_boxes * buffer_bytes;
|
||||
@@ -238,7 +239,7 @@ KernelsData DetectionOutputKernelRef::GetKernelsData(const Params& params, const
|
||||
cldnnJit.AddConstants({MakeJitConstant("DO_STAGE_" + std::to_string(i) + "_CAFFE_OPT", "true")});
|
||||
}
|
||||
size_t num_bit_mask = CeilDiv(num_prior_boxes, 8);
|
||||
size_t num_score_per_item = RoundUp(CeilDiv(num_prior_boxes, 256), 8);
|
||||
size_t num_score_per_item = RoundUp(CeilDiv(num_prior_boxes, max_wg), 8);
|
||||
size_t num_score_block = CeilDiv(num_prior_boxes, num_score_per_item);
|
||||
cldnnJit.AddConstants({MakeJitConstant("NUM_BIT_MASK", num_bit_mask),
|
||||
MakeJitConstant("NUM_PRIORS_PER_ITEM", num_score_per_item),
|
||||
|
||||
@@ -241,9 +241,10 @@ KERNEL (detection_output_stage_0_scores_caffe)(__global INPUT1_TYPE* input_confi
|
||||
__local char4 bit_mask[NUM_BIT_MASK];
|
||||
__local int4 block_num[NUM_PRIOR_BLOCKS];
|
||||
|
||||
block_num[box_gid] = (int4)(0, 0, 0, 0);
|
||||
|
||||
{
|
||||
// to prevent access array out of range
|
||||
if (start_bid < end_bid)
|
||||
block_num[box_gid] = (int4)(0, 0, 0, 0);
|
||||
int mask_id = start_bid / 8;
|
||||
for (int i = start_bid; i < end_bid; i += 8) {
|
||||
bit_mask[mask_id] = (char4)(0, 0, 0, 0);
|
||||
@@ -277,10 +278,12 @@ KERNEL (detection_output_stage_0_scores_caffe)(__global INPUT1_TYPE* input_confi
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
barrier(CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
{
|
||||
int4 write_offsets = block_num[box_gid];
|
||||
int4 write_offsets = (int4)(0, 0, 0, 0);
|
||||
if (start_bid < end_bid)
|
||||
write_offsets = block_num[box_gid];
|
||||
int mask_id = start_bid >> 3;
|
||||
for (int i = start_bid; i < end_bid; i += 8) {
|
||||
for (int bi = 0; bi < 8; bi++) {
|
||||
@@ -319,9 +322,10 @@ KERNEL (detection_output_stage_0_scores_caffe)(__global INPUT1_TYPE* input_confi
|
||||
__local char bit_mask[NUM_BIT_MASK];
|
||||
__local int block_num[NUM_PRIOR_BLOCKS];
|
||||
|
||||
block_num[box_gid] = 0;
|
||||
|
||||
{
|
||||
// to prevent access array out of range
|
||||
if (start_bid < end_bid)
|
||||
block_num[box_gid] = 0;
|
||||
int mask_id = start_bid / 8;
|
||||
for (int i = start_bid; i < end_bid; i += 8) {
|
||||
bit_mask[mask_id] = 0;
|
||||
@@ -336,7 +340,6 @@ KERNEL (detection_output_stage_0_scores_caffe)(__global INPUT1_TYPE* input_confi
|
||||
mask_id++;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
{
|
||||
@@ -350,11 +353,12 @@ KERNEL (detection_output_stage_0_scores_caffe)(__global INPUT1_TYPE* input_confi
|
||||
buffer1[batchId * NUM_CLASSES_ACC + classId] = acc_num;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
barrier(CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
{
|
||||
int write_offset = block_num[box_gid];
|
||||
int write_offset = 0;
|
||||
if (start_bid < end_bid)
|
||||
write_offset = block_num[box_gid];
|
||||
int mask_id = start_bid >> 3;
|
||||
for (int i = start_bid; i < end_bid; i += 8) {
|
||||
for (int bi = 0; bi < 8; bi++) {
|
||||
|
||||
Reference in New Issue
Block a user