diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp index d5304e4c784..c3e8f267c40 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/softmax/softmax_kernel_bf.cpp @@ -113,16 +113,26 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis const auto& input = params.inputs[0]; DimensionAccessHelper dims(input); auto softmax_dim_y_bfyx = (params.dim == SoftmaxDim::Y && input.GetLayout() == DataLayout::bfyx); - const std::string flatten_bf = "(SOFTMAX_DIM_Y_BFYX&&(" + dims.f() + ">1))"; + auto softmax_dim_x_bfyx = (params.dim == SoftmaxDim::X && input.GetLayout() == DataLayout::bfyx); const std::string lws_0 = "get_local_size(0)"; - const std::string data_set_count = "(FLATTEN_BF?" + toVectorMulString({dims.f(), dims.b()}) + ":" + dims.b() + ")"; - const std::string data_set_size = "(FLATTEN_BF?" + dims.y() + ":" + toVectorMulString({dims.x(), dims.y(), dims.z(), dims.f()}) + ")"; + + std::string data_set_count; + std::string data_set_size; + if (softmax_dim_y_bfyx) { + data_set_count = toVectorMulString({dims.f(), dims.b()}); + data_set_size = dims.y(); + } else if (softmax_dim_x_bfyx) { + data_set_count = toVectorMulString({dims.f(), dims.b(), dims.y()}); + data_set_size = dims.x(); + } else { + data_set_count = dims.b(); + data_set_size = toVectorMulString({dims.x(), dims.y(), dims.z(), dims.f()}); + } + // It can be expected that the maximum possible itemsNum will not exceed 32 // Therefore, in dynamic shape, stack_size including additional buffer is set to 33 constexpr size_t stack_size = 33; // The size of stack for my_chunk jit.AddConstants({ - MakeJitConstant("SOFTMAX_DIM_Y_BFYX", softmax_dim_y_bfyx), - MakeJitConstant("FLATTEN_BF", flatten_bf), MakeJitConstant("LWS", lws_0), MakeJitConstant("SLM_SIZE", dispatchData.maxSlmSize), MakeJitConstant("DATA_SETS_COUNT", data_set_count),