[GPU] Fix for SoftmaxKernel_bf in dynamic case (#20769)

This commit is contained in:
Roman Lyamin 2023-10-31 09:02:03 +04:00 committed by GitHub
parent fc4fe07a0e
commit 50b6c5f0d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -113,16 +113,26 @@ JitConstants SoftmaxKernel_bf::GetJitConstants(const softmax_params& params, Dis
const auto& input = params.inputs[0];
DimensionAccessHelper dims(input);
auto softmax_dim_y_bfyx = (params.dim == SoftmaxDim::Y && input.GetLayout() == DataLayout::bfyx);
const std::string flatten_bf = "(SOFTMAX_DIM_Y_BFYX&&(" + dims.f() + ">1))";
auto softmax_dim_x_bfyx = (params.dim == SoftmaxDim::X && input.GetLayout() == DataLayout::bfyx);
const std::string lws_0 = "get_local_size(0)";
const std::string data_set_count = "(FLATTEN_BF?" + toVectorMulString({dims.f(), dims.b()}) + ":" + dims.b() + ")";
const std::string data_set_size = "(FLATTEN_BF?" + dims.y() + ":" + toVectorMulString({dims.x(), dims.y(), dims.z(), dims.f()}) + ")";
std::string data_set_count;
std::string data_set_size;
if (softmax_dim_y_bfyx) {
data_set_count = toVectorMulString({dims.f(), dims.b()});
data_set_size = dims.y();
} else if (softmax_dim_x_bfyx) {
data_set_count = toVectorMulString({dims.f(), dims.b(), dims.y()});
data_set_size = dims.x();
} else {
data_set_count = dims.b();
data_set_size = toVectorMulString({dims.x(), dims.y(), dims.z(), dims.f()});
}
// It can be expected that the maximum possible itemsNum will not exceed 32
// Therefore, in dynamic shape, stack_size including additional buffer is set to 33
constexpr size_t stack_size = 33; // The size of stack for my_chunk
jit.AddConstants({
MakeJitConstant("SOFTMAX_DIM_Y_BFYX", softmax_dim_y_bfyx),
MakeJitConstant("FLATTEN_BF", flatten_bf),
MakeJitConstant("LWS", lws_0),
MakeJitConstant("SLM_SIZE", dispatchData.maxSlmSize),
MakeJitConstant("DATA_SETS_COUNT", data_set_count),