[GPU] Fix detection output kernel build error on dGPU (#17150)
+ Check local memory size used in the kernel and choose proper kernel. + Select DO_STAGE_0_CAFFE instead of DO_STAGE_0_CAFFE_OPT
This commit is contained in:
parent
a6b1544acf
commit
5c21dcec4d
@ -238,17 +238,23 @@ KernelsData DetectionOutputKernelRef::GetKernelsData(const Params& params, const
|
||||
if (detectOutParams.detectOutParams.decrease_label_id) {
|
||||
cldnnJit.AddConstant(MakeJitConstant("DO_STAGE_" + std::to_string(i) + "_MXNET", "true"));
|
||||
} else {
|
||||
if (detectOutParams.detectOutParams.conf_padding_x || detectOutParams.detectOutParams.conf_padding_y) {
|
||||
cldnnJit.AddConstants({MakeJitConstant("DO_STAGE_" + std::to_string(i) + "_CAFFE", "true")});
|
||||
} else {
|
||||
cldnnJit.AddConstants({MakeJitConstant("DO_STAGE_" + std::to_string(i) + "_CAFFE_OPT", "true")});
|
||||
}
|
||||
size_t num_bit_mask = CeilDiv(num_prior_boxes, 8);
|
||||
size_t num_score_per_item = RoundUp(CeilDiv(num_prior_boxes, max_wg), 8);
|
||||
size_t num_score_block = CeilDiv(num_prior_boxes, num_score_per_item);
|
||||
cldnnJit.AddConstants({MakeJitConstant("NUM_BIT_MASK", num_bit_mask),
|
||||
MakeJitConstant("NUM_PRIORS_PER_ITEM", num_score_per_item),
|
||||
MakeJitConstant("NUM_PRIOR_BLOCKS", num_score_block)});
|
||||
|
||||
std::string kernel_name_suffix = "_CAFFE";
|
||||
if (detectOutParams.detectOutParams.conf_padding_x == 0 && detectOutParams.detectOutParams.conf_padding_y == 0) {
|
||||
size_t req_local_mem_size = num_bit_mask * 4 * BytesPerElement(kernel_selector::Datatype::INT8)
|
||||
+ num_score_block * 4 * BytesPerElement(kernel_selector::Datatype::INT32);
|
||||
// Check local mem size used in DO_STAGE_0_CAFFE_OPT.
|
||||
if (req_local_mem_size < detectOutParams.engineInfo.maxLocalMemSize) {
|
||||
kernel_name_suffix = "_CAFFE_OPT";
|
||||
}
|
||||
}
|
||||
cldnnJit.AddConstants({MakeJitConstant("DO_STAGE_" + std::to_string(i) + kernel_name_suffix, "true")});
|
||||
}
|
||||
} else if (i == 1) {
|
||||
if (detectOutParams.detectOutParams.decrease_label_id) {
|
||||
|
Loading…
Reference in New Issue
Block a user