[GPU] TopK optimizations for value mode and OPERATION_NUM = 1 (#9170)

[GPU] Cycle unrolling and minimization of unused thread number
This commit is contained in:
Ilya Znamenskiy 2021-12-13 19:11:18 +03:00 committed by GitHub
parent 4342473120
commit 06865a252a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 12 deletions

View File

@ -77,11 +77,13 @@ KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const opti
}
const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
size_t ops_size = getOperationNumber(orgParams);
ops_size = ops_size > 1 ? Align(ops_size, 32) : 1;
size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;
DispatchData dispatchData;
dispatchData.gws = { Align(getOperationNumber(orgParams), 32), sort_size, 1 };
dispatchData.gws = { ops_size, sort_size, 1 };
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
KernelData kd = KernelData::Default<arg_max_min_params>(params);

View File

@ -90,6 +90,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint last_group_offset = (group_num - 1) * group_size;
#endif // SORT_BY_VALUE
#if OPERATION_NUM > 1
const uint output_idx = (uint)get_global_id(0);
if (output_idx >= OPERATION_NUM)
@ -100,13 +101,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_FEATURE_NUM); // Y
const uint out_second_dim = output_idx / INPUT0_FEATURE_NUM % INPUT0_SIZE_X; // X
const uint out_fourth_dim = output_idx % INPUT0_FEATURE_NUM; // F
uint indices[] = {0, out_fourth_dim, 0, out_first_dim, out_second_dim}; // BFZYX
uint indices[] = { 0, out_fourth_dim, 0, out_first_dim, out_second_dim }; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // F
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = {0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim};
uint indices[] = { 0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim };
#endif
#endif
#ifdef FEATURE_AXIS
@ -114,13 +115,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_BATCH_NUM); // Y
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_SIZE_X; // X
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
uint indices[] = {out_fourth_dim, 0, 0, out_first_dim, out_second_dim}; // BFZYX
uint indices[] = { out_fourth_dim, 0, 0, out_first_dim, out_second_dim }; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // B
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = {out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim};
uint indices[] = { out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim };
#endif
#endif
#ifdef Z_AXIS
@ -128,20 +129,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = {out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim};
uint indices[] = { out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim };
#endif
#ifdef Y_AXIS
#ifdef OUTPUT_LAYOUT_YXFB
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // X
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
uint indices[] = {out_fourth_dim, out_second_dim, 0, 0, out_first_dim}; // BFZYX
uint indices[] = { out_fourth_dim, out_second_dim, 0, 0, out_first_dim }; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_X); // B
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Z; // Z
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = {out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim};
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim };
#endif
#endif
#ifdef X_AXIS
@ -149,16 +150,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // Y
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
uint indices[] = {out_fourth_dim, out_second_dim, 0, out_first_dim, 0}; // BFZYX
uint indices[] = { out_fourth_dim, out_second_dim, 0, out_first_dim, 0 }; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_Y); // B
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y) % INPUT0_FEATURE_NUM; // F
const uint out_third_dim = output_idx / INPUT0_SIZE_Y % INPUT0_SIZE_Z; // Z
const uint out_fourth_dim = output_idx % INPUT0_SIZE_Y; // Y
uint indices[] = {out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0};
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0 };
#endif
#endif
#else // OPERATION_NUM > 1
uint indices[] = { 0, 0, 0, 0, 0 };
#endif // OPERATION_NUM > 1
// Using parallel sorting for sorting by values
#if SORT_BY_VALUE
uint sort_position = 0;
@ -168,15 +173,54 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
result.value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
result.index = sort_idx;
for (uint i = 0; i < sort_idx; i++) {
indices[AXIS] = i;
for (uint i = 0; i < sort_idx / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = index_offset;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 1;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 2;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 3;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 4;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 5;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 6;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 7;
test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
if (sort_position >= TOP_K)
return;
}
for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
}
if (sort_position >= TOP_K)
return;
for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];