[GPU] TopK optimizations for value mode and OPERATION_NUM = 1 (#9170)

[GPU] Cycle unrolling and minimization of unused thread number
This commit is contained in:
Ilya Znamenskiy 2021-12-13 19:11:18 +03:00 committed by GitHub
parent 4342473120
commit 06865a252a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 12 deletions

View File

@ -77,11 +77,13 @@ KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const opti
}
const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
size_t ops_size = getOperationNumber(orgParams);
ops_size = ops_size > 1 ? Align(ops_size, 32) : 1;
size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;
DispatchData dispatchData;
dispatchData.gws = { Align(getOperationNumber(orgParams), 32), sort_size, 1 };
dispatchData.gws = { ops_size, sort_size, 1 };
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
KernelData kd = KernelData::Default<arg_max_min_params>(params);

View File

@ -90,6 +90,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint last_group_offset = (group_num - 1) * group_size;
#endif // SORT_BY_VALUE
#if OPERATION_NUM > 1
const uint output_idx = (uint)get_global_id(0);
if (output_idx >= OPERATION_NUM)
@ -159,6 +160,10 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
#endif
#endif
#else // OPERATION_NUM > 1
uint indices[] = { 0, 0, 0, 0, 0 };
#endif // OPERATION_NUM > 1
// Using parallel sorting for sorting by values
#if SORT_BY_VALUE
uint sort_position = 0;
@ -168,15 +173,54 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
result.value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
result.index = sort_idx;
for (uint i = 0; i < sort_idx; i++) {
indices[AXIS] = i;
for (uint i = 0; i < sort_idx / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = index_offset;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 1;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 2;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 3;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 4;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 5;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 6;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 7;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
if (sort_position >= TOP_K)
return;
}
for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
}
if (sort_position >= TOP_K)
return;
for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];