[GPU] fix bug on resample_opt (#15434)

* fix bug: wrong feature slice num
This commit is contained in:
Sungeun Kim
2023-02-07 16:29:18 +09:00
committed by GitHub
parent b341d641d2
commit 00d9ed0da4
2 changed files with 15 additions and 10 deletions

View File

@@ -227,7 +227,6 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input,
#endif
const int f_block = get_group_id(1);
const int b = get_global_id(2);
const int feature_num = f_block * FEATURE_SLICE_SIZE + get_sub_group_local_id();
const uint feature_block = f_block * FEATURE_SLICE_SIZE;
typedef IN_VEC_TYPE in_vec_t;

View File

@@ -81,6 +81,18 @@ DeviceFeaturesKey ResampleKernelOpt::get_required_device_features_key(const Para
return get_common_subgroups_device_features_key(params, options);
}
// Selects the per-layout vector size for the optimized resample kernel.
// Yields 2 when the first input uses the fs_b_yx_fsv32 layout and 1 for
// every other layout.
static size_t get_vec_size(const resample_params &params) {
    const bool is_fsv32 = params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32;
    return is_fsv32 ? size_t{2} : size_t{1};
}
// Computes the feature slice size as 16 multiplied by the vector size
// reported by get_vec_size() for the given parameters.
// NOTE(review): 16 presumably corresponds to the kernel's sub-group size —
// confirm against the sub_group_size constant used elsewhere in this file.
static int get_feature_slice_size(const resample_params &params) {
    constexpr size_t base_slice_size = 16;
    const size_t slice_size = base_slice_size * get_vec_size(params);
    // The product is computed in size_t; narrow explicitly to the int return type.
    return static_cast<int>(slice_size);
}
ResampleKernelBase::DispatchData ResampleKernelOpt::SetDefault(const kernel_selector::resample_params &arg) const {
DispatchData dispatchData;
auto in_layout = arg.inputs[0].GetLayout();
@@ -110,7 +122,7 @@ ResampleKernelBase::DispatchData ResampleKernelOpt::SetDefault(const kernel_sele
} else {
dispatchData.gws[0] = CeilDiv(out.X().v, opt_x_block_size) * out.Y().v;
}
dispatchData.gws[1] = Align(out.Feature().v, sub_group_size);
dispatchData.gws[1] = Align(CeilDiv(out.Feature().v, get_vec_size(arg)), sub_group_size);
dispatchData.gws[2] = arg.outputs[0].Batch().v;
dispatchData.lws[0] = 1;
@@ -165,14 +177,8 @@ JitConstants ResampleKernelOpt::GetJitConstants(const resample_params &params) c
jit.AddConstant(MakeJitConstant("X_BLOCKS", CeilDiv(params.outputs[0].X().v, opt_x_block_size)));
jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", sub_group_size));
size_t vec_size = 0;
if (params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32) {
vec_size = 2;
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 32));
} else {
vec_size = 1;
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 16));
}
const size_t vec_size = get_vec_size(params);
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", get_feature_slice_size(params)));
jit.AddConstant(MakeJitConstant("VEC_SIZE", vec_size));
if (!params.fused_ops.empty()) {