diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/detection_output/detection_output_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/detection_output/detection_output_kernel_ref.cpp index 3e82d684d08..a72823d7dd5 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/detection_output/detection_output_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/detection_output/detection_output_kernel_ref.cpp @@ -238,17 +238,23 @@ KernelsData DetectionOutputKernelRef::GetKernelsData(const Params& params, const if (detectOutParams.detectOutParams.decrease_label_id) { cldnnJit.AddConstant(MakeJitConstant("DO_STAGE_" + std::to_string(i) + "_MXNET", "true")); } else { - if (detectOutParams.detectOutParams.conf_padding_x || detectOutParams.detectOutParams.conf_padding_y) { - cldnnJit.AddConstants({MakeJitConstant("DO_STAGE_" + std::to_string(i) + "_CAFFE", "true")}); - } else { - cldnnJit.AddConstants({MakeJitConstant("DO_STAGE_" + std::to_string(i) + "_CAFFE_OPT", "true")}); - } size_t num_bit_mask = CeilDiv(num_prior_boxes, 8); size_t num_score_per_item = RoundUp(CeilDiv(num_prior_boxes, max_wg), 8); size_t num_score_block = CeilDiv(num_prior_boxes, num_score_per_item); cldnnJit.AddConstants({MakeJitConstant("NUM_BIT_MASK", num_bit_mask), MakeJitConstant("NUM_PRIORS_PER_ITEM", num_score_per_item), MakeJitConstant("NUM_PRIOR_BLOCKS", num_score_block)}); + + std::string kernel_name_suffix = "_CAFFE"; + if (detectOutParams.detectOutParams.conf_padding_x == 0 && detectOutParams.detectOutParams.conf_padding_y == 0) { + size_t req_local_mem_size = num_bit_mask * 4 * BytesPerElement(kernel_selector::Datatype::INT8) + + num_score_block * 4 * BytesPerElement(kernel_selector::Datatype::INT32); + // Check local mem size used in DO_STAGE_0_CAFFE_OPT. + if (req_local_mem_size < detectOutParams.engineInfo.maxLocalMemSize) { + kernel_name_suffix = "_CAFFE_OPT"; + } + } + cldnnJit.AddConstants({MakeJitConstant("DO_STAGE_" + std::to_string(i) + kernel_name_suffix, "true")}); } } else if (i == 1) { if (detectOutParams.detectOutParams.decrease_label_id) {