[GPU] TopK optimizations for value mode and OPERATION_NUM = 1 (#9170)
[GPU] Cycle unrolling and minimization of unused thread number
This commit is contained in:
parent
4342473120
commit
06865a252a
@ -77,11 +77,13 @@ KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const opti
|
||||
}
|
||||
const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
|
||||
|
||||
size_t ops_size = getOperationNumber(orgParams);
|
||||
ops_size = ops_size > 1 ? Align(ops_size, 32) : 1;
|
||||
size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;
|
||||
|
||||
DispatchData dispatchData;
|
||||
|
||||
dispatchData.gws = { Align(getOperationNumber(orgParams), 32), sort_size, 1 };
|
||||
dispatchData.gws = { ops_size, sort_size, 1 };
|
||||
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
|
||||
|
||||
KernelData kd = KernelData::Default<arg_max_min_params>(params);
|
||||
|
@ -90,6 +90,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
||||
const uint last_group_offset = (group_num - 1) * group_size;
|
||||
#endif // SORT_BY_VALUE
|
||||
|
||||
#if OPERATION_NUM > 1
|
||||
const uint output_idx = (uint)get_global_id(0);
|
||||
|
||||
if (output_idx >= OPERATION_NUM)
|
||||
@ -100,13 +101,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_FEATURE_NUM); // Y
|
||||
const uint out_second_dim = output_idx / INPUT0_FEATURE_NUM % INPUT0_SIZE_X; // X
|
||||
const uint out_fourth_dim = output_idx % INPUT0_FEATURE_NUM; // F
|
||||
uint indices[] = {0, out_fourth_dim, 0, out_first_dim, out_second_dim}; // BFZYX
|
||||
uint indices[] = { 0, out_fourth_dim, 0, out_first_dim, out_second_dim }; // BFZYX
|
||||
#else
|
||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // F
|
||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
|
||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||
uint indices[] = {0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim};
|
||||
uint indices[] = { 0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim };
|
||||
#endif
|
||||
#endif
|
||||
#ifdef FEATURE_AXIS
|
||||
@ -114,13 +115,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_BATCH_NUM); // Y
|
||||
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_SIZE_X; // X
|
||||
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
||||
uint indices[] = {out_fourth_dim, 0, 0, out_first_dim, out_second_dim}; // BFZYX
|
||||
uint indices[] = { out_fourth_dim, 0, 0, out_first_dim, out_second_dim }; // BFZYX
|
||||
#else
|
||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // B
|
||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
|
||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||
uint indices[] = {out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim};
|
||||
uint indices[] = { out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim };
|
||||
#endif
|
||||
#endif
|
||||
#ifdef Z_AXIS
|
||||
@ -128,20 +129,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
|
||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||
uint indices[] = {out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim};
|
||||
uint indices[] = { out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim };
|
||||
#endif
|
||||
#ifdef Y_AXIS
|
||||
#ifdef OUTPUT_LAYOUT_YXFB
|
||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // X
|
||||
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
|
||||
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
||||
uint indices[] = {out_fourth_dim, out_second_dim, 0, 0, out_first_dim}; // BFZYX
|
||||
uint indices[] = { out_fourth_dim, out_second_dim, 0, 0, out_first_dim }; // BFZYX
|
||||
#else
|
||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_X); // B
|
||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
|
||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Z; // Z
|
||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||
uint indices[] = {out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim};
|
||||
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim };
|
||||
#endif
|
||||
#endif
|
||||
#ifdef X_AXIS
|
||||
@ -149,16 +150,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // Y
|
||||
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
|
||||
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
||||
uint indices[] = {out_fourth_dim, out_second_dim, 0, out_first_dim, 0}; // BFZYX
|
||||
uint indices[] = { out_fourth_dim, out_second_dim, 0, out_first_dim, 0 }; // BFZYX
|
||||
#else
|
||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_Y); // B
|
||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y) % INPUT0_FEATURE_NUM; // F
|
||||
const uint out_third_dim = output_idx / INPUT0_SIZE_Y % INPUT0_SIZE_Z; // Z
|
||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_Y; // Y
|
||||
uint indices[] = {out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0};
|
||||
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0 };
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#else // OPERATION_NUM > 1
|
||||
uint indices[] = { 0, 0, 0, 0, 0 };
|
||||
#endif // OPERATION_NUM > 1
|
||||
|
||||
// Using parallel sorting for sorting by values
|
||||
#if SORT_BY_VALUE
|
||||
uint sort_position = 0;
|
||||
@ -168,15 +173,54 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
||||
result.value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
result.index = sort_idx;
|
||||
|
||||
for (uint i = 0; i < sort_idx; i++) {
|
||||
indices[AXIS] = i;
|
||||
for (uint i = 0; i < sort_idx / 8; i++) {
|
||||
uint index_offset = i * 8;
|
||||
indices[AXIS] = index_offset;
|
||||
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
indices[AXIS] = index_offset + 1;
|
||||
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
indices[AXIS] = index_offset + 2;
|
||||
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
indices[AXIS] = index_offset + 3;
|
||||
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
indices[AXIS] = index_offset + 4;
|
||||
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
indices[AXIS] = index_offset + 5;
|
||||
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
indices[AXIS] = index_offset + 6;
|
||||
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
indices[AXIS] = index_offset + 7;
|
||||
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
if (sort_position >= TOP_K)
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) {
|
||||
indices[AXIS] = i;
|
||||
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||
sort_position++;
|
||||
}
|
||||
|
||||
if (sort_position >= TOP_K)
|
||||
return;
|
||||
|
||||
for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
|
||||
indices[AXIS] = i;
|
||||
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
|
||||
|
Loading…
Reference in New Issue
Block a user