From 06865a252a1e384cad530f2562aaefbaadbd9116 Mon Sep 17 00:00:00 2001 From: Ilya Znamenskiy Date: Mon, 13 Dec 2021 19:11:18 +0300 Subject: [PATCH] [GPU] TopK optimizations for value mode and OPERATION_NUM = 1 (#9170) [GPU] Cycle unrolling and minimization of unused thread number --- .../arg_max_min/arg_max_min_kernel_axis.cpp | 4 +- .../core/cl_kernels/arg_max_min_axis.cl | 66 +++++++++++++++---- 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/arg_max_min/arg_max_min_kernel_axis.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/arg_max_min/arg_max_min_kernel_axis.cpp index c23599ce74f..0cd765cc643 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/arg_max_min/arg_max_min_kernel_axis.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/arg_max_min/arg_max_min_kernel_axis.cpp @@ -77,11 +77,13 @@ KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const opti } const arg_max_min_params& orgParams = static_cast(params); + size_t ops_size = getOperationNumber(orgParams); + ops_size = ops_size > 1 ? Align(ops_size, 32) : 1; size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1; DispatchData dispatchData; - dispatchData.gws = { Align(getOperationNumber(orgParams), 32), sort_size, 1 }; + dispatchData.gws = { ops_size, sort_size, 1 }; dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); KernelData kd = KernelData::Default(params); diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/arg_max_min_axis.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/arg_max_min_axis.cl index 68d47ff8989..aa4373d889c 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/arg_max_min_axis.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/arg_max_min_axis.cl @@ -90,6 +90,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint last_group_offset = (group_num - 1) * group_size; #endif // SORT_BY_VALUE +#if OPERATION_NUM > 1 const uint output_idx = (uint)get_global_id(0); if (output_idx >= OPERATION_NUM) @@ -100,13 +101,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_FEATURE_NUM); // Y const uint out_second_dim = output_idx / INPUT0_FEATURE_NUM % INPUT0_SIZE_X; // X const uint out_fourth_dim = output_idx % INPUT0_FEATURE_NUM; // F - uint indices[] = {0, out_fourth_dim, 0, out_first_dim, out_second_dim}; // BFZYX + uint indices[] = { 0, out_fourth_dim, 0, out_first_dim, out_second_dim }; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // F const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = {0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim}; + uint indices[] = { 0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim }; #endif #endif #ifdef FEATURE_AXIS @@ -114,13 +115,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_BATCH_NUM); // Y const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_SIZE_X; // X const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B - uint indices[] = {out_fourth_dim, 0, 0, out_first_dim, out_second_dim}; // BFZYX + uint indices[] = { out_fourth_dim, 0, 0, out_first_dim, out_second_dim }; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // B const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = {out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim}; + uint indices[] = { out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim }; #endif #endif #ifdef Z_AXIS @@ -128,20 +129,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = {out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim}; + uint indices[] = { out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim }; #endif #ifdef Y_AXIS #ifdef OUTPUT_LAYOUT_YXFB const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // X const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B - uint indices[] = {out_fourth_dim, out_second_dim, 0, 0, out_first_dim}; // BFZYX + uint indices[] = { out_fourth_dim, out_second_dim, 0, 0, out_first_dim }; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_X); // B const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Z; // Z const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = {out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim}; + uint indices[] = { out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim }; #endif #endif #ifdef X_AXIS @@ -149,16 +150,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // Y const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B - uint indices[] = {out_fourth_dim, out_second_dim, 0, out_first_dim, 0}; // BFZYX + uint indices[] = { out_fourth_dim, out_second_dim, 0, out_first_dim, 0 }; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_Y); // B const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y) % INPUT0_FEATURE_NUM; // F const uint out_third_dim = output_idx / INPUT0_SIZE_Y % INPUT0_SIZE_Z; // Z const uint out_fourth_dim = output_idx % INPUT0_SIZE_Y; // Y - uint indices[] = {out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0}; + uint indices[] = { out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0 }; #endif #endif +#else // OPERATION_NUM > 1 + uint indices[] = { 0, 0, 0, 0, 0 }; +#endif // OPERATION_NUM > 1 + // Using parallel sorting for sorting by values #if SORT_BY_VALUE uint sort_position = 0; @@ -168,15 +173,54 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input result.value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; result.index = sort_idx; - for (uint i = 0; i < sort_idx; i++) { - indices[AXIS] = i; + for (uint i = 0; i < sort_idx / 8; i++) { + uint index_offset = i * 8; + indices[AXIS] = index_offset; INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + indices[AXIS] = index_offset + 1; + test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + indices[AXIS] = index_offset + 2; + test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + indices[AXIS] = index_offset + 3; + test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + indices[AXIS] = index_offset + 4; + test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + indices[AXIS] = index_offset + 5; + test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + indices[AXIS] = index_offset + 6; + test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + indices[AXIS] = index_offset + 7; + test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; if (sort_position >= TOP_K) return; } + for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) { + indices[AXIS] = i; + INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])]; + if (result.value COMPARE_PARALLEL_SIGN_1 test_value) + sort_position++; + } + + if (sort_position >= TOP_K) + return; + for (uint i = sort_idx + 1; i < VALUES_NUM; i++) { indices[AXIS] = i; INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];