From 2956717118ba1397eb152471bc08d3cd94a14b4b Mon Sep 17 00:00:00 2001 From: Andrew Kwangwoong Park Date: Sun, 26 Mar 2023 14:32:17 +0900 Subject: [PATCH] [GPU] Added shape agnostic TopK kernel (#16161) * [GPU] Added shape agnostic TopK kernel implementation Signed-off-by: Andrew Park * Update kernel to use internal buffers for shape agnostic kernel Signed-off-by: Andrew Park * Add WA to compile_graph for shape agnostic arg_max_min_axis with non-const k input Signed-off-by: Andrew Park * Fix is_dynamic parameter for FillCLKernelData for the case where the output is static shape Signed-off-by: Andrew Park * Fix corner case where inbuf size becomes 0 when ops_size is 1 Signed-off-by: Andrew Park --------- Signed-off-by: Andrew Park --- .../graph/graph_optimizer/compile_graph.cpp | 6 + .../src/graph/impls/ocl/arg_max_min.cpp | 43 +- .../cl_kernels/arg_max_min_axis.cl | 445 ++++++++++-------- .../cl_kernels/arg_max_min_gpu_ref.cl | 5 +- .../arg_max_min/arg_max_min_kernel_axis.cpp | 103 +++- .../arg_max_min/arg_max_min_kernel_axis.h | 1 + .../arg_max_min/arg_max_min_kernel_base.cpp | 27 +- .../arg_max_min_kernel_gpu_ref.cpp | 1 + .../tests/test_cases/arg_max_gpu_test.cpp | 51 ++ 9 files changed, 441 insertions(+), 241 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/compile_graph.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/compile_graph.cpp index 995c6f965b5..83d2fb224af 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/compile_graph.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/compile_graph.cpp @@ -7,6 +7,7 @@ #include "mutable_data_inst.h" #include "reshape_inst.h" #include "quantize_inst.h" +#include "arg_max_min_inst.h" #include "program_node.h" #include "intel_gpu/runtime/engine.hpp" #include "intel_gpu/runtime/itt.hpp" @@ -51,6 +52,11 @@ void compile_graph::run(program& p) { if (node->is_type() && node->is_dynamic() && node->get_output_layout().get_partial_shape().size() > 3) can_select_impl = false; + // TODO: Remove this WA once we have shape agnostic arg_max_min_axis kernel with non-const k input + if (node->is_type<arg_max_min>() && node->is_dynamic() && node->as<arg_max_min>().get_primitive()->top_k == 0) { + can_select_impl = false; + } + bool is_planar = node->get_output_layout().format == format::bfyx || node->get_output_layout().format == format::bfzyx || node->get_output_layout().format == format::bfwzyx; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/arg_max_min.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/arg_max_min.cpp index 2e91cd92ab4..094c3194cc7 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/arg_max_min.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/arg_max_min.cpp @@ -59,7 +59,7 @@ protected: } public: - static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) { + static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) { const auto& primitive = impl_param.typed_desc<arg_max_min>(); const auto& axis = primitive->axis; const auto& top_k = primitive->top_k; @@ -68,7 +68,7 @@ public: const auto& values_first = primitive->values_first; const auto& outputs_num = primitive->input_size() == 3 ? 
2 : static_cast<uint32_t>(primitive->output_size()); - auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(impl_param); + auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(impl_param, is_shape_agnostic); auto argm_optional_params = get_default_optional_params<kernel_selector::arg_max_min_optional_params>(impl_param.get_program()); @@ -76,7 +76,7 @@ public: argm_params.argMaxMinAxis = GetArgMaxMinAxis(axis, impl_param.get_output_layout().get_rank()); auto& constant_mem = impl_param.memory_deps; - if (constant_mem.count(1)) { + if (constant_mem.count(1) && !argm_params.has_dynamic_outputs()) { // The topK value can be obtained by reading impl_param.memory_deps.at(1). // However, here we utilize output_layout and axis information to minimize mem_lock. auto output_layout = impl_param.get_output_layout(0); @@ -110,26 +110,45 @@ public: return {argm_params, argm_optional_params}; } + + void update_dispatch_data(const kernel_impl_params& impl_param) override { + auto kernel_params = get_kernel_params(impl_param, true); + (_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data); + update_kernels_list_to_skip(); + } }; namespace detail { attach_arg_max_min_impl::attach_arg_max_min_impl() { auto types = {data_types::f16, data_types::f32, data_types::i8, data_types::i32}; - auto formats = {format::bfyx, - format::yxfb, - format::b_fs_yx_fsv16, - format::b_fs_yx_fsv32, - format::bs_fs_yx_bsv16_fsv16, - format::bs_fs_yx_bsv32_fsv16, - format::bs_fs_yx_bsv32_fsv32, - - format::bfzyx}; + auto formats = { + format::bfyx, + format::yxfb, + format::b_fs_yx_fsv16, + format::b_fs_yx_fsv32, + format::bs_fs_yx_bsv16_fsv16, + format::bs_fs_yx_bsv32_fsv16, + format::bs_fs_yx_bsv32_fsv32, + format::bfzyx + }; implementation_map<arg_max_min>::add(impl_types::ocl, + shape_types::static_shape, typed_primitive_impl_ocl<arg_max_min>::create, types, formats); + + auto dyn_formats = { + format::bfyx, + format::bfzyx + }; + + implementation_map<arg_max_min>::add(impl_types::ocl, + shape_types::dynamic_shape, + typed_primitive_impl_ocl<arg_max_min>::create, + types, + dyn_formats); } } // namespace detail } // namespace ocl diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_axis.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_axis.cl index 1f0af43f1c9..dc65e37a244 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_axis.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_axis.cl @@ -43,48 +43,22 @@ #define MINIMUM_NUMBER_FOR_PARTIAL_SORTING 100 -KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input - ,__global OUTPUT_TYPE* output #ifdef SECOND_OUTPUT_EXIST #ifdef MULTIPLE_OUTPUTS - ,__global OUTPUT1_TYPE* second_output #else - ,__global INPUT1_TYPE* second_output #endif #endif - ) +inline void FUNC(get_indices_from_dims)(OPTIONAL_SHAPE_INFO_ARG + const uint output_idx, + uint* indices) { -#include "include/arg_max_min_common.cl" -#if SORT_BY_VALUE - const uint sort_idx = (uint)get_global_id(1); -#elif TOP_K == 1 - iav_type result[TOP_K]; -#else - iav_type result[VALUES_NUM], temp_buf[VALUES_NUM]; - const uint group_size = TOP_K >= 8 ? TOP_K : 8; - const uint group_num = ((VALUES_NUM - 1) / group_size) + 1; - const uint last_group_size = (VALUES_NUM % group_size > 0) ? 
(VALUES_NUM % group_size) : group_size; - const uint last_group_offset = (group_num - 1) * group_size; -#endif // SORT_BY_VALUE - -#if OPERATION_NUM > 1 - const uint output_idx = (uint)get_global_id(0); - - if (output_idx >= OPERATION_NUM) - return; - #ifdef BATCH_AXIS #ifdef OUTPUT_LAYOUT_YXFB const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_FEATURE_NUM); // Y const uint out_second_dim = output_idx / INPUT0_FEATURE_NUM % INPUT0_SIZE_X; // X const uint out_fourth_dim = output_idx % INPUT0_FEATURE_NUM; // F - uint indices[] = { 0, out_fourth_dim, 0, out_first_dim, out_second_dim }; // BFZYX + indices[1] = out_fourth_dim; indices[3] = out_first_dim; indices[4] = out_second_dim; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // F const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = { 0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim }; + indices[1] = out_first_dim; indices[2] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim; #endif #endif #ifdef FEATURE_AXIS @@ -92,13 +66,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_BATCH_NUM); // Y const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_SIZE_X; // X const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B - uint indices[] = { out_fourth_dim, 0, 0, out_first_dim, out_second_dim }; // BFZYX + indices[0] = out_fourth_dim; indices[3] = out_first_dim; indices[4] = out_second_dim; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // B const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = { out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim }; + indices[0] = out_first_dim; indices[2] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim; #endif #endif #ifdef Z_AXIS @@ -106,20 +80,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = { out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim }; + indices[0] = out_first_dim; indices[1] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim; #endif #ifdef Y_AXIS #ifdef OUTPUT_LAYOUT_YXFB const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // X const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B - uint indices[] = { out_fourth_dim, out_second_dim, 0, 0, out_first_dim }; // BFZYX + indices[0] = out_fourth_dim; indices[1] = out_second_dim; indices[4] = out_first_dim; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_X); // B const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F const uint out_third_dim = 
output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Z; // Z const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X - uint indices[] = { out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim }; + indices[0] = out_first_dim; indices[1] = out_second_dim; indices[2] = out_third_dim; indices[4] = out_fourth_dim; #endif #endif #ifdef X_AXIS @@ -127,19 +101,63 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // Y const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B - uint indices[] = { out_fourth_dim, out_second_dim, 0, out_first_dim, 0 }; // BFZYX + indices[0] = out_fourth_dim; indices[1] = out_second_dim; indices[3] = out_first_dim; // BFZYX #else const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_Y); // B const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y) % INPUT0_FEATURE_NUM; // F const uint out_third_dim = output_idx / INPUT0_SIZE_Y % INPUT0_SIZE_Z; // Z const uint out_fourth_dim = output_idx % INPUT0_SIZE_Y; // Y - uint indices[] = { out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0 }; + indices[0] = out_first_dim; indices[1] = out_second_dim; indices[2] = out_third_dim; indices[3] = out_fourth_dim; #endif #endif +} -#else // OPERATION_NUM > 1 +KERNEL(arg_max_min_modified)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* input + ,__global OUTPUT_TYPE* output +#ifdef SECOND_OUTPUT_EXIST +#ifdef MULTIPLE_OUTPUTS + ,__global OUTPUT1_TYPE* second_output +#else + ,__global INPUT1_TYPE* second_output +#endif +#endif +#ifdef IS_DYNAMIC + ,__global INPUT0_TYPE* tmp_buffer0 + ,__global INPUT0_TYPE* tmp_buffer1 + ,__global INPUT0_TYPE* tmp_buffer2 +#endif +) +{ +#include "include/arg_max_min_common.cl" + const uint output_idx = (uint)get_global_id(0); +#if SORT_BY_VALUE + const uint sort_idx = (uint)get_global_id(1); +#elif TOP_K == 1 + iav_type result[TOP_K]; +#else +#ifdef IS_DYNAMIC + const uint iav_type_size = INPUT0_TYPE_SIZE + 4; + const uint buffer_size = iav_type_size * VALUES_NUM; + const uint buffer_offset = buffer_size * OPERATION_NUM; + __global iav_type *result = OFFSET_GLOBAL_PTR(iav_type, tmp_buffer0, output_idx * buffer_size); + __global iav_type *temp_buf = OFFSET_GLOBAL_PTR(iav_type, tmp_buffer0, buffer_offset + output_idx * buffer_size); +#else + iav_type result[VALUES_NUM], temp_buf[VALUES_NUM]; +#endif + const uint group_size = TOP_K >= 8 ? TOP_K : 8; + const uint group_num = ((VALUES_NUM - 1) / group_size) + 1; + const uint last_group_size = (VALUES_NUM % group_size > 0) ? 
(VALUES_NUM % group_size) : group_size; + const uint last_group_offset = (group_num - 1) * group_size; +#endif // SORT_BY_VALUE uint indices[] = { 0, 0, 0, 0, 0 }; -#endif // OPERATION_NUM > 1 + + if (OPERATION_NUM > 1) { + if (output_idx >= OPERATION_NUM) + return; + FUNC_CALL(get_indices_from_dims)(OPTIONAL_SHAPE_INFO_TENSOR output_idx, indices); + } // Using parallel sorting for sorting by values #if SORT_BY_VALUE @@ -147,41 +165,41 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input indices[AXIS] = sort_idx; iav_type result; - result.value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + result.value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; result.index = sort_idx; for (uint i = 0; i < sort_idx / 8; i++) { uint index_offset = i * 8; indices[AXIS] = index_offset; - INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; indices[AXIS] = index_offset + 1; - test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; indices[AXIS] = index_offset + 2; - test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; indices[AXIS] = index_offset + 3; - test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; indices[AXIS] = index_offset + 4; - test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; indices[AXIS] = index_offset + 5; - test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; indices[AXIS] = index_offset + 6; - test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; indices[AXIS] = index_offset + 7; - test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], 
indices[4])]; + test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; if (sort_position >= TOP_K) @@ -190,7 +208,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) { indices[AXIS] = i; - INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_1 test_value) sort_position++; } @@ -200,7 +218,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input for (uint i = sort_idx + 1; i < VALUES_NUM; i++) { indices[AXIS] = i; - INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (result.value COMPARE_PARALLEL_SIGN_2 test_value) sort_position++; if (sort_position >= TOP_K) @@ -209,7 +227,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input // Using simple sorting for sorting by indices and when TOP_K == 1 #elif TOP_K == 1 - INPUT0_TYPE val = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + INPUT0_TYPE val = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; result[0].index = 0; result[0].value = val; bool already_exist = false; @@ -228,7 +246,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input } indices[AXIS] = i; - INPUT0_TYPE in_data = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + INPUT0_TYPE in_data = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; if (val COMPARE_SIGN in_data) { result[top_k].index = i; result[top_k].value = in_data; @@ -237,197 +255,206 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input } val = INPUT0_FILL_VAL; } - -// Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING) -#elif ((TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)) - for (uint i = 0; i < VALUES_NUM / 8; i++) { - uint index_offset = i * 8; - indices[AXIS] = result[index_offset].index = index_offset; - result[index_offset].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = result[index_offset + 1].index = index_offset + 1; - result[index_offset + 1].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = result[index_offset + 2].index = index_offset + 2; - result[index_offset + 2].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = result[index_offset + 3].index = index_offset + 3; - result[index_offset + 3].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = result[index_offset + 4].index = index_offset + 4; - 
result[index_offset + 4].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = result[index_offset + 5].index = index_offset + 5; - result[index_offset + 5].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = result[index_offset + 6].index = index_offset + 6; - result[index_offset + 6].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = result[index_offset + 7].index = index_offset + 7; - result[index_offset + 7].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - } - - for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) { - indices[AXIS] = result[i].index = i; - result[i].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - } - - for (uint k = 1; k < VALUES_NUM; k *= 2) { - for (uint left = 0; left + k < VALUES_NUM; left += k * 2) { - uint i, j, m; - uint right = left + k; - uint right_end = right + k; - if (right_end > VALUES_NUM) right_end = VALUES_NUM; - m = i = left; j = right; - while ((i < right) && (j < right_end)) { - if (result[i].value COMPARE_PARTIAL_SIGN result[j].value) { - temp_buf[m++] = result[i++]; - } else { - temp_buf[m++] = result[j++]; - } - } - while (i < right) - temp_buf[m++] = result[i++]; - while (j < right_end) - temp_buf[m++] = result[j++]; - for (m = left; m < right_end; m++) - result[m] = temp_buf[m]; - } - } - -// In other cases for sorting by indices using mixed partial/merge sorting #else // SORT_BY_VALUE - for (uint i = 0; i < VALUES_NUM / 8; i++) { - uint index_offset = i * 8; - indices[AXIS] = temp_buf[index_offset].index = index_offset; - temp_buf[index_offset].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = temp_buf[index_offset + 1].index = index_offset + 1; - temp_buf[index_offset + 1].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = temp_buf[index_offset + 2].index = index_offset + 2; - temp_buf[index_offset + 2].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = temp_buf[index_offset + 3].index = index_offset + 3; - temp_buf[index_offset + 3].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = temp_buf[index_offset + 4].index = index_offset + 4; - temp_buf[index_offset + 4].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = temp_buf[index_offset + 5].index = index_offset + 5; - temp_buf[index_offset + 5].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = temp_buf[index_offset + 6].index = index_offset + 6; - temp_buf[index_offset + 6].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - indices[AXIS] = temp_buf[index_offset + 7].index = index_offset + 7; - temp_buf[index_offset + 7].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - } + // Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING) + bool 
use_merge_sorting = (TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING); + if (use_merge_sorting) { + for (uint i = 0; i < VALUES_NUM / 8; i++) { + uint index_offset = i * 8; + indices[AXIS] = result[index_offset].index = index_offset; + result[index_offset].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = result[index_offset + 1].index = index_offset + 1; + result[index_offset + 1].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = result[index_offset + 2].index = index_offset + 2; + result[index_offset + 2].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = result[index_offset + 3].index = index_offset + 3; + result[index_offset + 3].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = result[index_offset + 4].index = index_offset + 4; + result[index_offset + 4].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = result[index_offset + 5].index = index_offset + 5; + result[index_offset + 5].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = result[index_offset + 6].index = index_offset + 6; + result[index_offset + 6].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = result[index_offset + 7].index = index_offset + 7; + result[index_offset + 7].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + } - for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) { - indices[AXIS] = temp_buf[i].index = i; - temp_buf[i].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; - } + for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) { + indices[AXIS] = result[i].index = i; + result[i].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + } - for (uint group = 0; group < group_num - 1; group++) { - uint group_offset = group * group_size; - for (uint k = 1; k < group_size; k *= 2) { - for (uint left = 0; left + k < group_size; left += k * 2) { + for (uint k = 1; k < VALUES_NUM; k *= 2) { + for (uint left = 0; left + k < VALUES_NUM; left += k * 2) { uint i, j, m; uint right = left + k; uint right_end = right + k; - if (right_end > group_size) right_end = group_size; + if (right_end > VALUES_NUM) right_end = VALUES_NUM; m = i = left; j = right; while ((i < right) && (j < right_end)) { - if (temp_buf[group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[group_offset + j].value) { - result[group_offset + (m++)] = temp_buf[group_offset + (i++)]; + if (result[i].value COMPARE_PARTIAL_SIGN result[j].value) { + temp_buf[m++] = result[i++]; } else { - result[group_offset + (m++)] = temp_buf[group_offset + (j++)]; + temp_buf[m++] = result[j++]; } } while (i < right) - result[group_offset + (m++)] = temp_buf[group_offset + (i++)]; + temp_buf[m++] = result[i++]; while (j < right_end) 
- result[group_offset + (m++)] = temp_buf[group_offset + (j++)]; + temp_buf[m++] = result[j++]; for (m = left; m < right_end; m++) - temp_buf[group_offset + m] = result[group_offset + m]; + result[m] = temp_buf[m]; } } - } + } else { + // In other cases for sorting by indices using mixed partial/merge sorting + for (uint i = 0; i < VALUES_NUM / 8; i++) { + uint index_offset = i * 8; + indices[AXIS] = temp_buf[index_offset].index = index_offset; + temp_buf[index_offset].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = temp_buf[index_offset + 1].index = index_offset + 1; + temp_buf[index_offset + 1].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = temp_buf[index_offset + 2].index = index_offset + 2; + temp_buf[index_offset + 2].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = temp_buf[index_offset + 3].index = index_offset + 3; + temp_buf[index_offset + 3].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = temp_buf[index_offset + 4].index = index_offset + 4; + temp_buf[index_offset + 4].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = temp_buf[index_offset + 5].index = index_offset + 5; + temp_buf[index_offset + 5].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = temp_buf[index_offset + 6].index = index_offset + 6; + temp_buf[index_offset + 6].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + indices[AXIS] = temp_buf[index_offset + 7].index = index_offset + 7; + temp_buf[index_offset + 7].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + } - for (uint k = 1; k < last_group_size; k *= 2) { - for (uint left = 0; left + k < last_group_size; left += k * 2) { - uint i, j, m; - uint right = left + k; - uint right_end = right + k; - if (right_end > last_group_size) right_end = last_group_size; - m = i = left; j = right; - while ((i < right) && (j < right_end)) { - if (temp_buf[last_group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[last_group_offset + j].value) { - result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)]; - } else { - result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)]; + for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) { + indices[AXIS] = temp_buf[i].index = i; + temp_buf[i].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])]; + } + + for (uint group = 0; group < group_num - 1; group++) { + uint group_offset = group * group_size; + for (uint k = 1; k < group_size; k *= 2) { + for (uint left = 0; left + k < group_size; left += k * 2) { + uint i, j, m; + uint right = left + k; + uint right_end = right + k; + if (right_end > group_size) right_end = group_size; + m = i = left; j = right; + while ((i < right) && (j < right_end)) { + if (temp_buf[group_offset + i].value COMPARE_PARTIAL_SIGN 
temp_buf[group_offset + j].value) { + result[group_offset + (m++)] = temp_buf[group_offset + (i++)]; + } else { + result[group_offset + (m++)] = temp_buf[group_offset + (j++)]; + } + } + while (i < right) + result[group_offset + (m++)] = temp_buf[group_offset + (i++)]; + while (j < right_end) + result[group_offset + (m++)] = temp_buf[group_offset + (j++)]; + for (m = left; m < right_end; m++) + temp_buf[group_offset + m] = result[group_offset + m]; } } - while (i < right) - result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)]; - while (j < right_end) - result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)]; - for (m = left; m < right_end; m++) - temp_buf[last_group_offset + m] = result[last_group_offset + m]; } - } - uint merge_counter[group_num]; - uint max_merge_counter[group_num]; - iav_type merge_buf; - bool subgroup_done[group_num]; - - unroll_for (uint i = 0; i < group_num - 1; i++) { - merge_counter[i] = 0; - max_merge_counter[i] = group_size; - subgroup_done[i] = false; - } - - merge_counter[group_num - 1] = 0; - max_merge_counter[group_num - 1] = last_group_size; - subgroup_done[group_num - 1] = false; - - for (uint i = 0; i < TOP_K; i++) { - bool merge_buf_done = false; - uint merge_buf_index = 0; - for (uint j = 0; j < group_num; j++) { - if (subgroup_done[j]) - continue; - - uint test_index = j * group_size + merge_counter[j]; - - if (!merge_buf_done) { - merge_buf = temp_buf[test_index]; - merge_buf_done = true; - merge_buf_index = j; - continue; - } - - if (temp_buf[test_index].value COMPARE_MERGE_SIGN merge_buf.value) { - merge_buf = temp_buf[test_index]; - merge_buf_index = j; + for (uint k = 1; k < last_group_size; k *= 2) { + for (uint left = 0; left + k < last_group_size; left += k * 2) { + uint i, j, m; + uint right = left + k; + uint right_end = right + k; + if (right_end > last_group_size) right_end = last_group_size; + m = i = left; j = right; + while ((i < right) && (j < right_end)) { + if (temp_buf[last_group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[last_group_offset + j].value) { + result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)]; + } else { + result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)]; + } + } + while (i < right) + result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)]; + while (j < right_end) + result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)]; + for (m = left; m < right_end; m++) + temp_buf[last_group_offset + m] = result[last_group_offset + m]; } } - merge_counter[merge_buf_index]++; - if (merge_counter[merge_buf_index] == max_merge_counter[merge_buf_index]) - subgroup_done[merge_buf_index] = true; + #ifdef IS_DYNAMIC + const uint counter_size = group_num * 4; + const uint counter_offset = counter_size * OPERATION_NUM; + __global uint* merge_counter = OFFSET_GLOBAL_PTR(uint, tmp_buffer1, output_idx * counter_size); + __global uint* max_merge_counter = OFFSET_GLOBAL_PTR(uint, tmp_buffer1, counter_offset + output_idx * counter_size); + __global bool* subgroup_done = OFFSET_GLOBAL_PTR(bool, tmp_buffer2, output_idx); + #else + uint merge_counter[group_num]; + uint max_merge_counter[group_num]; + bool subgroup_done[group_num]; + #endif + iav_type merge_buf; - result[i] = merge_buf; + unroll_for (uint i = 0; i < group_num - 1; i++) { + merge_counter[i] = 0; + max_merge_counter[i] = group_size; + subgroup_done[i] = false; + } + + merge_counter[group_num - 1] = 0; + max_merge_counter[group_num - 1] = last_group_size; + 
subgroup_done[group_num - 1] = false; + + for (uint i = 0; i < TOP_K; i++) { + bool merge_buf_done = false; + uint merge_buf_index = 0; + for (uint j = 0; j < group_num; j++) { + if (subgroup_done[j]) + continue; + + uint test_index = j * group_size + merge_counter[j]; + + if (!merge_buf_done) { + merge_buf = temp_buf[test_index]; + merge_buf_done = true; + merge_buf_index = j; + continue; + } + + if (temp_buf[test_index].value COMPARE_MERGE_SIGN merge_buf.value) { + merge_buf = temp_buf[test_index]; + merge_buf_index = j; + } + } + + merge_counter[merge_buf_index]++; + if (merge_counter[merge_buf_index] == max_merge_counter[merge_buf_index]) + subgroup_done[merge_buf_index] = true; + + result[i] = merge_buf; + } } #endif // SORT_BY_VALUE #if SORT_BY_VALUE indices[AXIS] = sort_position; #ifdef TOP_K_ORDER - output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value); + output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value); #else - output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index); + output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index); #endif #ifdef SECOND_OUTPUT_EXIST #ifdef MULTIPLE_OUTPUTS #ifdef TOP_K_ORDER - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.index); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.index); #else - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.value); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.value); #endif #else #ifdef TOP_K_ORDER - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index); #else - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value); #endif #endif #endif @@ -445,22 +472,22 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input indices[AXIS] = out_position; #ifdef TOP_K_ORDER - output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value); + output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value); #else - output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index); + output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index); #endif #ifdef 
SECOND_OUTPUT_EXIST #ifdef MULTIPLE_OUTPUTS #ifdef TOP_K_ORDER - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].index); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].index); #else - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].value); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].value); #endif #else #ifdef TOP_K_ORDER - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index); #else - second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value); + second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value); #endif #endif #endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_gpu_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_gpu_ref.cl index 8732adf90e4..f5703ccd36c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_gpu_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/arg_max_min_gpu_ref.cl @@ -16,7 +16,10 @@ #endif __attribute__((reqd_work_group_size(LOCAL_SIZE, 1, 1))) -KERNEL(arg_max_gpu_top_k)(const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output) +KERNEL(arg_max_gpu_top_k)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* input, + __global OUTPUT_TYPE* output) { #include "include/arg_max_min_common.cl" uint results[TOP_K]; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp index d9b2bbe392f..2e4ffc212b8 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.cpp @@ -20,6 +20,25 @@ size_t getOperationNumber(const arg_max_min_params& params) { } } +std::string getOperationNumberString(const arg_max_min_params& params) { + const auto& output = params.outputs[0]; + auto x = toCodeString(output.X(), 11); + auto y = toCodeString(output.Y(), 10); + auto z = toCodeString(output.Z(), 9); + auto w = toCodeString(output.W(), 8); + auto f = toCodeString(output.Feature(), 7); + auto b = toCodeString(output.Batch(), 6); + switch (params.argMaxMinAxis) { + case ArgMaxMinAxis::BATCH: return toVectorMulString({x, y, z, f}); + case ArgMaxMinAxis::FEATURE: return toVectorMulString({x, y, z, b}); + case ArgMaxMinAxis::Z: return toVectorMulString({x, y, f, b}); + case ArgMaxMinAxis::Y: return toVectorMulString({x, z, f, b}); + case ArgMaxMinAxis::X: return toVectorMulString({y, z, f, b}); + default: + throw std::invalid_argument("Unsupported axis"); + } +} + size_t getSortSize(const arg_max_min_params& params) { switch (params.argMaxMinAxis) { 
case ArgMaxMinAxis::BATCH: return params.inputs[0].Batch().v; @@ -65,6 +84,7 @@ ParamsKey ArgMaxMinKernelAxis::GetSupportedKey() const { k.EnableBatching(); k.EnableTensorPitches(); k.EnableTensorOffset(); + k.EnableDynamicShapesSupport(); return k; } @@ -83,38 +103,83 @@ bool ArgMaxMinKernelAxis::Validate(const Params& p, const optional_params& o) co return true; } +ArgMaxMinKernelBase::DispatchData ArgMaxMinKernelAxis::SetDefault(const arg_max_min_params& params) const { + DispatchData dispatchData; + + if (!params.has_dynamic_tensors()) { + size_t ops_size = getOperationNumber(params); + ops_size = ops_size > 1 ? Align(ops_size, 32) : 1; + size_t sort_size = params.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(params) : 1; + + dispatchData.gws = { ops_size, sort_size, 1 }; + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); + } + + return dispatchData; +} + KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const optional_params& options) const { if (!Validate(params, options)) { return {}; } const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params); + bool is_dynamic = orgParams.has_dynamic_tensors(); - size_t ops_size = getOperationNumber(orgParams); - ops_size = ops_size > 1 ? Align(ops_size, 32) : 1; - size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1; - - DispatchData dispatchData; - - dispatchData.gws = { ops_size, sort_size, 1 }; - dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); - + auto dispatchData = SetDefault(orgParams); KernelData kd = KernelData::Default(params); + kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) { + const auto& prim_params = static_cast<const arg_max_min_params&>(params); + auto dispatchData = SetDefault(prim_params); + OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); + kd.kernels[0].params.workGroups.global = dispatchData.gws; + kd.kernels[0].params.workGroups.local = dispatchData.lws; + + const size_t elem_size = prim_params.inputs[0].ElementSize(); + const size_t iav_type_size = elem_size + 4; + const size_t sort_size = getSortSize(prim_params); + const size_t ops_size = getOperationNumber(prim_params); + const size_t group_size = prim_params.topK >= 8 ? prim_params.topK : 8; + const size_t group_num = ((sort_size - 1) / group_size) + 1; + + kd.internalBufferSizes.clear(); + kd.internalBufferSizes.push_back(iav_type_size * sort_size * ops_size * 2); + kd.internalBufferSizes.push_back(4 * group_num * ops_size * 2); + kd.internalBufferSizes.push_back(ops_size * elem_size); + kd.internalBufferDataType = prim_params.inputs[0].GetDType(); + }; auto cldnn_jit = GetJitConstants(orgParams); auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options); auto jit = CreateJit(kernelName, cldnn_jit, entry_point); auto& kernel = kd.kernels[0]; - if (!orgParams.use_multiple_outputs) { - FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point); - } else { - FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, - "", false, false, 1, GetFusedPrimitiveInputsCount(params), 2); - } + FillCLKernelData(kernel, + dispatchData, + params.engineInfo, + kernelName, + jit, + entry_point, + EXE_MODE_DEFAULT, + false, + false, + 1, + GetFusedPrimitiveInputsCount(params), + orgParams.use_multiple_outputs ? 
2 : 1, + is_dynamic); if (orgParams.has_second_output && !orgParams.use_multiple_outputs) kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1}); + if (is_dynamic) { + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); + kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes()); + kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes()); + kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes()); + kd.internalBufferDataType = orgParams.inputs[0].GetDType(); + } + return {kd}; } @@ -125,7 +190,13 @@ KernelsPriority ArgMaxMinKernelAxis::GetKernelsPriority(const Params& /*params*/ JitConstants ArgMaxMinKernelAxis::GetJitConstants(const arg_max_min_params& params) const { auto jit = ArgMaxMinKernelBase::GetJitConstants(params); - jit.AddConstant(MakeJitConstant("OPERATION_NUM", getOperationNumber(params))); + if (params.has_dynamic_tensors()) { + const std::string operation_num = getOperationNumberString(params); + jit.AddConstant(MakeJitConstant("OPERATION_NUM", operation_num)); + } else { + const size_t operation_num = getOperationNumber(params); + jit.AddConstant(MakeJitConstant("OPERATION_NUM", operation_num)); + } if (params.argMaxMinSortType == ArgMaxMinSortType::VALUE) jit.AddConstant(MakeJitConstant("SORT_BY_VALUE", 1)); else diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.h index d0c717689d0..dd1afa41d1b 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_axis.h @@ -13,6 +13,7 @@ public: virtual ~ArgMaxMinKernelAxis() {} JitConstants GetJitConstants(const arg_max_min_params& params) const override; + DispatchData SetDefault(const arg_max_min_params& params) const override; KernelsData GetKernelsData(const Params& params, const optional_params& options) const override; KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override; ParamsKey GetSupportedKey() const override; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_base.cpp index 90b9fa9a5e8..59af1a8712c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_base.cpp @@ -27,8 +27,10 @@ JitConstants ArgMaxMinKernelBase::GetJitConstants(const arg_max_min_params& para ArgMaxMinKernelBase::DispatchData ArgMaxMinKernelBase::SetDefault(const arg_max_min_params& params) const { DispatchData dispatchData; - dispatchData.gws = { 128, params.inputs[0].Batch().v, 1 }; - dispatchData.lws = { 128, 1, 1 }; + if (!params.has_dynamic_inputs()) { + dispatchData.gws = { 128, params.inputs[0].Batch().v, 1 }; + dispatchData.lws = { 128, 1, 1 }; + } return dispatchData; } @@ -43,13 +45,32 @@ KernelsData ArgMaxMinKernelBase::GetCommonKernelsData(const Params& params, cons DispatchData dispatchData = SetDefault(orgParams); KernelData kd = KernelData::Default(params); + kd.update_dispatch_data_func 
= [this](const Params& params, KernelData& kd) { + const auto& prim_params = static_cast<const arg_max_min_params&>(params); + auto dispatchData = SetDefault(prim_params); + OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); + kd.kernels[0].params.workGroups.global = dispatchData.gws; + kd.kernels[0].params.workGroups.local = dispatchData.lws; + }; auto cldnn_jit = GetJitConstants(orgParams); auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options); auto jit = CreateJit(kernelName, cldnn_jit, entry_point); auto& kernel = kd.kernels[0]; - FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point); + FillCLKernelData(kernel, + dispatchData, + params.engineInfo, + kernelName, + jit, + entry_point, + EXE_MODE_DEFAULT, + false, + false, + (uint32_t)orgParams.inputs.size(), + GetFusedPrimitiveInputsCount(params), + (uint32_t)orgParams.outputs.size(), + orgParams.has_dynamic_tensors()); return {kd}; } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_gpu_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_gpu_ref.cpp index af341efa818..297475dce5e 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_gpu_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/arg_max_min/arg_max_min_kernel_gpu_ref.cpp @@ -18,6 +18,7 @@ ParamsKey ArgMaxMinKernelGPURef::GetSupportedKey() const { k.EnableDifferentTypes(); k.EnableBatching(); k.EnableTensorPitches(); + k.EnableDynamicShapesSupport(); return k; } diff --git a/src/plugins/intel_gpu/tests/test_cases/arg_max_gpu_test.cpp b/src/plugins/intel_gpu/tests/test_cases/arg_max_gpu_test.cpp index 752ab7270cb..1bf43241c15 100644 --- a/src/plugins/intel_gpu/tests/test_cases/arg_max_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/test_cases/arg_max_gpu_test.cpp @@ -9,6 +9,8 @@ #include #include +#include + #include "test_utils.h" using namespace cldnn; @@ -870,3 +872,52 @@ TEST(top_k_layer_tests, md_sync) { TEST(export_import_top_k_layer_tests, md_sync) { test_top_k_layer_md_sync(true); } + +TEST(arg_max_min_gpu, dynamic) { + static const int32_t x_size = 2, y_size = 2, feature_num = 4, batch_num = 2; + auto& engine = get_test_engine(); + const int top_k = 2; + auto input_layout_dynamic = layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx}; + auto input_layout_static = layout{ov::PartialShape{batch_num, feature_num, y_size, x_size}, data_types::f32, format::bfyx}; + auto input = engine.allocate_memory(input_layout_static); + + topology topology; + topology.add(input_layout("input", input_layout_dynamic)); + topology.add(arg_max_min("arg_max", { input_info("input") }, ov::op::TopKMode::MIN, top_k, 0)); + + std::vector<float> input_vec = {// y0x0 y0x1 y1x0 y1x1 /*b0f0*/ 0.1f, -0.1f, 0.9f, 1.5f, /*b0f1*/ 0.2f, 0.2f, -10.f, 5.2f, /*b0f2*/ 0.2f, 0.2f, -10.f, 5.2f, /*b0f3*/ 0.2f, 0.2f, -10.f, 4.2f, /*b1f0*/ 3.f, 0.5f, 7.f, 10.f, /*b1f1*/ 4.f, 0.5f, 8.f, 8.2f, /*b1f2*/ 0.2f, 0.2f, -10.f, 5.2f, /*b1f3*/ 4.f, 0.5f, 8.f, 8.2f}; + set_values(input, input_vec); + + ExecutionConfig config; + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + network network(engine, topology, config); + network.set_input_data("input", input); + + auto inst = network.get_primitive("arg_max"); + auto impl = inst->get_impl(); + ASSERT_TRUE(impl != nullptr); + ASSERT_TRUE(impl->is_dynamic()); + + auto outputs = network.execute(); 
ASSERT_EQ(outputs.size(), size_t(1)); + ASSERT_EQ(outputs.begin()->first, "arg_max"); + + const int out_size = y_size * feature_num * x_size * top_k; + auto output = outputs.at("arg_max").get_memory(); + cldnn::mem_lock<float> output_ptr(output, get_test_stream()); + + ASSERT_EQ(output_ptr.size(), out_size); + for (uint32_t i = 0; i < out_size; i++) { + ASSERT_FLOAT_EQ(output_ptr[i], i < (out_size / 2) ? 0 : 1); + } +}
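
Note on the dynamic path's scratch memory: because VALUES_NUM is unknown at compile time, the kernel replaces the stack arrays result[] and temp_buf[] with per-work-item regions carved out of global internal buffers (via OFFSET_GLOBAL_PTR and output_idx), and the host must allocate those buffers in update_dispatch_data_func. A minimal standalone sketch of that sizing logic follows; the helper name and signature are illustrative only, not part of the kernel selector API, but the formulas mirror the lambda in ArgMaxMinKernelAxis::GetKernelsData above.

    #include <cstddef>
    #include <vector>

    // Sketch: reproduces the three internal buffer sizes pushed into
    // kd.internalBufferSizes by update_dispatch_data_func in the diff.
    std::vector<std::size_t> topk_internal_buffer_sizes(std::size_t elem_size,  // bytes per input element
                                                        std::size_t sort_size,  // VALUES_NUM: extent of the sorted axis
                                                        std::size_t ops_size,   // OPERATION_NUM: independent sorts
                                                        std::size_t top_k) {
        const std::size_t iav_type_size = elem_size + 4;        // iav_type pairs a value with a 4-byte index
        const std::size_t group_size = top_k >= 8 ? top_k : 8;  // partial-sort group width
        const std::size_t group_num = ((sort_size - 1) / group_size) + 1;
        return {
            iav_type_size * sort_size * ops_size * 2,  // tmp_buffer0: result[] and temp_buf[] regions
            4 * group_num * ops_size * 2,              // tmp_buffer1: merge_counter[] and max_merge_counter[]
            ops_size * elem_size,                      // tmp_buffer2: subgroup_done flags
        };
    }

For example, an f32 input with VALUES_NUM = 1024, OPERATION_NUM = 64 and top_k = 16 would yield buffers of 1 MiB, 32 KiB and 256 B respectively.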