[GPU] fix bug on resample_opt (#15434)
* fix bug: wrong feature slice num
This commit is contained in:
@@ -227,7 +227,6 @@ KERNEL (resample_opt)(__global INPUT0_TYPE* input,
|
||||
#endif
|
||||
const int f_block = get_group_id(1);
|
||||
const int b = get_global_id(2);
|
||||
const int feature_num = f_block * FEATURE_SLICE_SIZE + get_sub_group_local_id();
|
||||
const uint feature_block = f_block * FEATURE_SLICE_SIZE;
|
||||
|
||||
typedef IN_VEC_TYPE in_vec_t;
|
||||
|
||||
@@ -81,6 +81,18 @@ DeviceFeaturesKey ResampleKernelOpt::get_required_device_features_key(const Para
|
||||
return get_common_subgroups_device_features_key(params, options);
|
||||
}
|
||||
|
||||
static size_t get_vec_size(const resample_params ¶ms) {
|
||||
if (params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32) {
|
||||
return 2;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
static int get_feature_slice_size(const resample_params ¶ms) {
|
||||
return 16 * get_vec_size(params);
|
||||
}
|
||||
|
||||
ResampleKernelBase::DispatchData ResampleKernelOpt::SetDefault(const kernel_selector::resample_params &arg) const {
|
||||
DispatchData dispatchData;
|
||||
auto in_layout = arg.inputs[0].GetLayout();
|
||||
@@ -110,7 +122,7 @@ ResampleKernelBase::DispatchData ResampleKernelOpt::SetDefault(const kernel_sele
|
||||
} else {
|
||||
dispatchData.gws[0] = CeilDiv(out.X().v, opt_x_block_size) * out.Y().v;
|
||||
}
|
||||
dispatchData.gws[1] = Align(out.Feature().v, sub_group_size);
|
||||
dispatchData.gws[1] = Align(CeilDiv(out.Feature().v, get_vec_size(arg)), sub_group_size);
|
||||
dispatchData.gws[2] = arg.outputs[0].Batch().v;
|
||||
|
||||
dispatchData.lws[0] = 1;
|
||||
@@ -165,14 +177,8 @@ JitConstants ResampleKernelOpt::GetJitConstants(const resample_params &params) c
|
||||
jit.AddConstant(MakeJitConstant("X_BLOCKS", CeilDiv(params.outputs[0].X().v, opt_x_block_size)));
|
||||
jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", sub_group_size));
|
||||
|
||||
size_t vec_size = 0;
|
||||
if (params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32) {
|
||||
vec_size = 2;
|
||||
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 32));
|
||||
} else {
|
||||
vec_size = 1;
|
||||
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 16));
|
||||
}
|
||||
const size_t vec_size = get_vec_size(params);
|
||||
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", get_feature_slice_size(params)));
|
||||
jit.AddConstant(MakeJitConstant("VEC_SIZE", vec_size));
|
||||
|
||||
if (!params.fused_ops.empty()) {
|
||||
|
||||
Reference in New Issue
Block a user