[GPU] Added shape agnostic TopK kernel (#16161)

* [GPU] Added shape agnostic TopK kernel implementation

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Update kernel to use internal buffers for shape agnostic kernel

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Add WA to compile_graph for shape agnostic arg_max_min_axis with non-const k input

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Fix is_dynamic parameter for FillCLKernelData for the case where the output has a static shape

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Fix corner case where the internal buffer size becomes 0 when ops_size is 1

Signed-off-by: Andrew Park <andrew.park@intel.com>

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
Andrew Kwangwoong Park 2023-03-26 14:32:17 +09:00 committed by GitHub
parent 60ab7490bf
commit 2956717118
9 changed files with 441 additions and 241 deletions


@@ -7,6 +7,7 @@
#include "mutable_data_inst.h"
#include "reshape_inst.h"
#include "quantize_inst.h"
#include "arg_max_min_inst.h"
#include "program_node.h"
#include "intel_gpu/runtime/engine.hpp"
#include "intel_gpu/runtime/itt.hpp"
@@ -51,6 +52,11 @@ void compile_graph::run(program& p) {
if (node->is_type<fully_connected>() && node->is_dynamic() && node->get_output_layout().get_partial_shape().size() > 3)
can_select_impl = false;
// TODO: Remove this WA once we have a shape agnostic arg_max_min_axis kernel with non-const k input
if (node->is_type<arg_max_min>() && node->is_dynamic() && node->as<arg_max_min>().get_primitive()->top_k == 0) {
can_select_impl = false;
}
bool is_planar = node->get_output_layout().format == format::bfyx ||
node->get_output_layout().format == format::bfzyx ||
node->get_output_layout().format == format::bfwzyx;
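
Note on the WA above: a top_k of 0 on the arg_max_min primitive is how cldnn marks a k that arrives as a runtime input rather than a compile-time constant, so a fixed TOP_K cannot yet be baked into the shape agnostic kernel's JIT constants. An annotated restatement of the added check (identifiers exactly as in the hunk above):

    // k is non-const (top_k == 0): the shape agnostic impl cannot JIT a fixed
    // TOP_K, so defer impl selection until shapes and k are known at runtime.
    if (node->is_type<arg_max_min>() && node->is_dynamic() &&
        node->as<arg_max_min>().get_primitive()->top_k == 0) {
        can_select_impl = false;
    }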


@@ -59,7 +59,7 @@ protected:
}
public:
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) {
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<arg_max_min>();
const auto& axis = primitive->axis;
const auto& top_k = primitive->top_k;
@@ -68,7 +68,7 @@ public:
const auto& values_first = primitive->values_first;
const auto& outputs_num = primitive->input_size() == 3 ? 2 : static_cast<uint32_t>(primitive->output_size());
auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(impl_param);
auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(impl_param, is_shape_agnostic);
auto argm_optional_params =
get_default_optional_params<kernel_selector::arg_max_min_optional_params>(impl_param.get_program());
@@ -76,7 +76,7 @@ public:
argm_params.argMaxMinAxis = GetArgMaxMinAxis(axis, impl_param.get_output_layout().get_rank());
auto& constant_mem = impl_param.memory_deps;
if (constant_mem.count(1)) {
if (constant_mem.count(1) && !argm_params.has_dynamic_outputs()) {
// The topK could be obtained by reading impl_param.memory_deps.at(1).
// However, here we utilize output_layout and axis information to minimize mem_lock.
auto output_layout = impl_param.get_output_layout(0);
@@ -110,26 +110,45 @@ public:
return {argm_params, argm_optional_params};
}
void update_dispatch_data(const kernel_impl_params& impl_param) override {
auto kernel_params = get_kernel_params(impl_param, true);
(_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data);
update_kernels_list_to_skip();
}
};
namespace detail {
attach_arg_max_min_impl::attach_arg_max_min_impl() {
auto types = {data_types::f16, data_types::f32, data_types::i8, data_types::i32};
auto formats = {format::bfyx,
format::yxfb,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32,
format::bfzyx};
auto formats = {
format::bfyx,
format::yxfb,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32,
format::bfzyx
};
implementation_map<arg_max_min>::add(impl_types::ocl,
shape_types::static_shape,
typed_primitive_impl_ocl<arg_max_min>::create<arg_max_min_impl>,
types,
formats);
auto dyn_formats = {
format::bfyx,
format::bfzyx
};
implementation_map<arg_max_min>::add(impl_types::ocl,
shape_types::dynamic_shape,
typed_primitive_impl_ocl<arg_max_min>::create<arg_max_min_impl>,
types,
dyn_formats);
}
} // namespace detail
} // namespace ocl
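
How the shape agnostic path is driven at runtime: update_dispatch_data above re-derives the kernel parameters with is_shape_agnostic = true and lets the stored update_dispatch_data_func refresh work-group sizes and internal buffer sizes for the now-known shapes, without recompiling. A minimal host-side sketch of that sequence (names as in this file, surrounding plumbing omitted):

    // Called on every shape change of a dynamic arg_max_min node.
    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto kernel_params = get_kernel_params(impl_param, /*is_shape_agnostic=*/true);
        // Recomputes GWS/LWS and buffer sizes; the compiled binary is reused.
        (_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data);
        update_kernels_list_to_skip();
    }

Note that dynamic-shape support is registered only for the planar bfyx and bfzyx formats, matching the index math the kernel handles.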


@@ -43,48 +43,22 @@
#define MINIMUM_NUMBER_FOR_PARTIAL_SORTING 100
KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
,__global OUTPUT_TYPE* output
#ifdef SECOND_OUTPUT_EXIST
#ifdef MULTIPLE_OUTPUTS
,__global OUTPUT1_TYPE* second_output
#else
,__global INPUT1_TYPE* second_output
#endif
#endif
)
inline void FUNC(get_indices_from_dims)(OPTIONAL_SHAPE_INFO_ARG
const uint output_idx,
uint* indices)
{
#include "include/arg_max_min_common.cl"
#if SORT_BY_VALUE
const uint sort_idx = (uint)get_global_id(1);
#elif TOP_K == 1
iav_type result[TOP_K];
#else
iav_type result[VALUES_NUM], temp_buf[VALUES_NUM];
const uint group_size = TOP_K >= 8 ? TOP_K : 8;
const uint group_num = ((VALUES_NUM - 1) / group_size) + 1;
const uint last_group_size = (VALUES_NUM % group_size > 0) ? (VALUES_NUM % group_size) : group_size;
const uint last_group_offset = (group_num - 1) * group_size;
#endif // SORT_BY_VALUE
#if OPERATION_NUM > 1
const uint output_idx = (uint)get_global_id(0);
if (output_idx >= OPERATION_NUM)
return;
#ifdef BATCH_AXIS
#ifdef OUTPUT_LAYOUT_YXFB
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_FEATURE_NUM); // Y
const uint out_second_dim = output_idx / INPUT0_FEATURE_NUM % INPUT0_SIZE_X; // X
const uint out_fourth_dim = output_idx % INPUT0_FEATURE_NUM; // F
uint indices[] = { 0, out_fourth_dim, 0, out_first_dim, out_second_dim }; // BFZYX
indices[1] = out_fourth_dim; indices[3] = out_first_dim; indices[4] = out_second_dim; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // F
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = { 0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim };
indices[1] = out_first_dim; indices[2] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim;
#endif
#endif
#ifdef FEATURE_AXIS
@@ -92,13 +66,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_BATCH_NUM); // Y
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_SIZE_X; // X
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
uint indices[] = { out_fourth_dim, 0, 0, out_first_dim, out_second_dim }; // BFZYX
indices[0] = out_fourth_dim; indices[3] = out_first_dim; indices[4] = out_second_dim; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // B
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = { out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim };
indices[0] = out_first_dim; indices[2] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim;
#endif
#endif
#ifdef Z_AXIS
@@ -106,20 +80,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = { out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim };
indices[0] = out_first_dim; indices[1] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim;
#endif
#ifdef Y_AXIS
#ifdef OUTPUT_LAYOUT_YXFB
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // X
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
uint indices[] = { out_fourth_dim, out_second_dim, 0, 0, out_first_dim }; // BFZYX
indices[0] = out_fourth_dim; indices[1] = out_second_dim; indices[4] = out_first_dim; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_X); // B
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Z; // Z
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim };
indices[0] = out_first_dim; indices[1] = out_second_dim; indices[2] = out_third_dim; indices[4] = out_fourth_dim;
#endif
#endif
#ifdef X_AXIS
@@ -127,19 +101,63 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // Y
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
uint indices[] = { out_fourth_dim, out_second_dim, 0, out_first_dim, 0 }; // BFZYX
indices[0] = out_fourth_dim; indices[1] = out_second_dim; indices[3] = out_first_dim; // BFZYX
#else
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_Y); // B
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y) % INPUT0_FEATURE_NUM; // F
const uint out_third_dim = output_idx / INPUT0_SIZE_Y % INPUT0_SIZE_Z; // Z
const uint out_fourth_dim = output_idx % INPUT0_SIZE_Y; // Y
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0 };
indices[0] = out_first_dim; indices[1] = out_second_dim; indices[2] = out_third_dim; indices[3] = out_fourth_dim;
#endif
#endif
}
#else // OPERATION_NUM > 1
KERNEL(arg_max_min_modified)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* input
,__global OUTPUT_TYPE* output
#ifdef SECOND_OUTPUT_EXIST
#ifdef MULTIPLE_OUTPUTS
,__global OUTPUT1_TYPE* second_output
#else
,__global INPUT1_TYPE* second_output
#endif
#endif
#ifdef IS_DYNAMIC
,__global INPUT0_TYPE* tmp_buffer0
,__global INPUT0_TYPE* tmp_buffer1
,__global INPUT0_TYPE* tmp_buffer2
#endif
)
{
#include "include/arg_max_min_common.cl"
const uint output_idx = (uint)get_global_id(0);
#if SORT_BY_VALUE
const uint sort_idx = (uint)get_global_id(1);
#elif TOP_K == 1
iav_type result[TOP_K];
#else
#ifdef IS_DYNAMIC
const uint iav_type_size = INPUT0_TYPE_SIZE + 4;
const uint buffer_size = iav_type_size * VALUES_NUM;
const uint buffer_offset = buffer_size * OPERATION_NUM;
__global iav_type *result = OFFSET_GLOBAL_PTR(iav_type, tmp_buffer0, output_idx * buffer_size);
__global iav_type *temp_buf = OFFSET_GLOBAL_PTR(iav_type, tmp_buffer0, buffer_offset + output_idx * buffer_size);
#else
iav_type result[VALUES_NUM], temp_buf[VALUES_NUM];
#endif
const uint group_size = TOP_K >= 8 ? TOP_K : 8;
const uint group_num = ((VALUES_NUM - 1) / group_size) + 1;
const uint last_group_size = (VALUES_NUM % group_size > 0) ? (VALUES_NUM % group_size) : group_size;
const uint last_group_offset = (group_num - 1) * group_size;
#endif // SORT_BY_VALUE
uint indices[] = { 0, 0, 0, 0, 0 };
#endif // OPERATION_NUM > 1
if (OPERATION_NUM > 1) {
if (output_idx >= OPERATION_NUM)
return;
FUNC_CALL(get_indices_from_dims)(OPTIONAL_SHAPE_INFO_TENSOR output_idx, indices);
}
// Using parallel sorting for sorting by values
#if SORT_BY_VALUE
@@ -147,41 +165,41 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
indices[AXIS] = sort_idx;
iav_type result;
result.value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
result.value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
result.index = sort_idx;
for (uint i = 0; i < sort_idx / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = index_offset;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 1;
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 2;
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 3;
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 4;
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 5;
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 6;
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
indices[AXIS] = index_offset + 7;
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
if (sort_position >= TOP_K)
@@ -190,7 +208,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
}
@@ -200,7 +218,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_2 test_value)
sort_position++;
if (sort_position >= TOP_K)
@@ -209,7 +227,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
// Using simple sorting for sorting by indices and when TOP_K == 1
#elif TOP_K == 1
INPUT0_TYPE val = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
INPUT0_TYPE val = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
result[0].index = 0;
result[0].value = val;
bool already_exist = false;
@@ -228,7 +246,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
}
indices[AXIS] = i;
INPUT0_TYPE in_data = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
INPUT0_TYPE in_data = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
if (val COMPARE_SIGN in_data) {
result[top_k].index = i;
result[top_k].value = in_data;
@@ -237,197 +255,206 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
}
val = INPUT0_FILL_VAL;
}
// Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
#elif ((TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING))
for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = result[index_offset].index = index_offset;
result[index_offset].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 1].index = index_offset + 1;
result[index_offset + 1].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 2].index = index_offset + 2;
result[index_offset + 2].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 3].index = index_offset + 3;
result[index_offset + 3].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 4].index = index_offset + 4;
result[index_offset + 4].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 5].index = index_offset + 5;
result[index_offset + 5].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 6].index = index_offset + 6;
result[index_offset + 6].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 7].index = index_offset + 7;
result[index_offset + 7].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
indices[AXIS] = result[i].index = i;
result[i].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
for (uint k = 1; k < VALUES_NUM; k *= 2) {
for (uint left = 0; left + k < VALUES_NUM; left += k * 2) {
uint i, j, m;
uint right = left + k;
uint right_end = right + k;
if (right_end > VALUES_NUM) right_end = VALUES_NUM;
m = i = left; j = right;
while ((i < right) && (j < right_end)) {
if (result[i].value COMPARE_PARTIAL_SIGN result[j].value) {
temp_buf[m++] = result[i++];
} else {
temp_buf[m++] = result[j++];
}
}
while (i < right)
temp_buf[m++] = result[i++];
while (j < right_end)
temp_buf[m++] = result[j++];
for (m = left; m < right_end; m++)
result[m] = temp_buf[m];
}
}
// In other cases for sorting by indices using mixed partial/merge sorting
#else // SORT_BY_VALUE
for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = temp_buf[index_offset].index = index_offset;
temp_buf[index_offset].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 1].index = index_offset + 1;
temp_buf[index_offset + 1].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 2].index = index_offset + 2;
temp_buf[index_offset + 2].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 3].index = index_offset + 3;
temp_buf[index_offset + 3].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 4].index = index_offset + 4;
temp_buf[index_offset + 4].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 5].index = index_offset + 5;
temp_buf[index_offset + 5].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 6].index = index_offset + 6;
temp_buf[index_offset + 6].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 7].index = index_offset + 7;
temp_buf[index_offset + 7].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
// Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
bool use_merge_sorting = (TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING);
if (use_merge_sorting) {
for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = result[index_offset].index = index_offset;
result[index_offset].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 1].index = index_offset + 1;
result[index_offset + 1].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 2].index = index_offset + 2;
result[index_offset + 2].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 3].index = index_offset + 3;
result[index_offset + 3].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 4].index = index_offset + 4;
result[index_offset + 4].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 5].index = index_offset + 5;
result[index_offset + 5].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 6].index = index_offset + 6;
result[index_offset + 6].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = result[index_offset + 7].index = index_offset + 7;
result[index_offset + 7].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
indices[AXIS] = temp_buf[i].index = i;
temp_buf[i].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
indices[AXIS] = result[i].index = i;
result[i].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
for (uint group = 0; group < group_num - 1; group++) {
uint group_offset = group * group_size;
for (uint k = 1; k < group_size; k *= 2) {
for (uint left = 0; left + k < group_size; left += k * 2) {
for (uint k = 1; k < VALUES_NUM; k *= 2) {
for (uint left = 0; left + k < VALUES_NUM; left += k * 2) {
uint i, j, m;
uint right = left + k;
uint right_end = right + k;
if (right_end > group_size) right_end = group_size;
if (right_end > VALUES_NUM) right_end = VALUES_NUM;
m = i = left; j = right;
while ((i < right) && (j < right_end)) {
if (temp_buf[group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[group_offset + j].value) {
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
if (result[i].value COMPARE_PARTIAL_SIGN result[j].value) {
temp_buf[m++] = result[i++];
} else {
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
temp_buf[m++] = result[j++];
}
}
while (i < right)
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
temp_buf[m++] = result[i++];
while (j < right_end)
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
temp_buf[m++] = result[j++];
for (m = left; m < right_end; m++)
temp_buf[group_offset + m] = result[group_offset + m];
result[m] = temp_buf[m];
}
}
}
} else {
// In other cases for sorting by indices using mixed partial/merge sorting
for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = temp_buf[index_offset].index = index_offset;
temp_buf[index_offset].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 1].index = index_offset + 1;
temp_buf[index_offset + 1].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 2].index = index_offset + 2;
temp_buf[index_offset + 2].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 3].index = index_offset + 3;
temp_buf[index_offset + 3].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 4].index = index_offset + 4;
temp_buf[index_offset + 4].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 5].index = index_offset + 5;
temp_buf[index_offset + 5].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 6].index = index_offset + 6;
temp_buf[index_offset + 6].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
indices[AXIS] = temp_buf[index_offset + 7].index = index_offset + 7;
temp_buf[index_offset + 7].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
for (uint k = 1; k < last_group_size; k *= 2) {
for (uint left = 0; left + k < last_group_size; left += k * 2) {
uint i, j, m;
uint right = left + k;
uint right_end = right + k;
if (right_end > last_group_size) right_end = last_group_size;
m = i = left; j = right;
while ((i < right) && (j < right_end)) {
if (temp_buf[last_group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[last_group_offset + j].value) {
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
} else {
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
indices[AXIS] = temp_buf[i].index = i;
temp_buf[i].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
}
for (uint group = 0; group < group_num - 1; group++) {
uint group_offset = group * group_size;
for (uint k = 1; k < group_size; k *= 2) {
for (uint left = 0; left + k < group_size; left += k * 2) {
uint i, j, m;
uint right = left + k;
uint right_end = right + k;
if (right_end > group_size) right_end = group_size;
m = i = left; j = right;
while ((i < right) && (j < right_end)) {
if (temp_buf[group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[group_offset + j].value) {
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
} else {
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
}
}
while (i < right)
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
while (j < right_end)
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
for (m = left; m < right_end; m++)
temp_buf[group_offset + m] = result[group_offset + m];
}
}
while (i < right)
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
while (j < right_end)
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
for (m = left; m < right_end; m++)
temp_buf[last_group_offset + m] = result[last_group_offset + m];
}
}
uint merge_counter[group_num];
uint max_merge_counter[group_num];
iav_type merge_buf;
bool subgroup_done[group_num];
unroll_for (uint i = 0; i < group_num - 1; i++) {
merge_counter[i] = 0;
max_merge_counter[i] = group_size;
subgroup_done[i] = false;
}
merge_counter[group_num - 1] = 0;
max_merge_counter[group_num - 1] = last_group_size;
subgroup_done[group_num - 1] = false;
for (uint i = 0; i < TOP_K; i++) {
bool merge_buf_done = false;
uint merge_buf_index = 0;
for (uint j = 0; j < group_num; j++) {
if (subgroup_done[j])
continue;
uint test_index = j * group_size + merge_counter[j];
if (!merge_buf_done) {
merge_buf = temp_buf[test_index];
merge_buf_done = true;
merge_buf_index = j;
continue;
}
if (temp_buf[test_index].value COMPARE_MERGE_SIGN merge_buf.value) {
merge_buf = temp_buf[test_index];
merge_buf_index = j;
for (uint k = 1; k < last_group_size; k *= 2) {
for (uint left = 0; left + k < last_group_size; left += k * 2) {
uint i, j, m;
uint right = left + k;
uint right_end = right + k;
if (right_end > last_group_size) right_end = last_group_size;
m = i = left; j = right;
while ((i < right) && (j < right_end)) {
if (temp_buf[last_group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[last_group_offset + j].value) {
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
} else {
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
}
}
while (i < right)
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
while (j < right_end)
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
for (m = left; m < right_end; m++)
temp_buf[last_group_offset + m] = result[last_group_offset + m];
}
}
merge_counter[merge_buf_index]++;
if (merge_counter[merge_buf_index] == max_merge_counter[merge_buf_index])
subgroup_done[merge_buf_index] = true;
#ifdef IS_DYNAMIC
const uint counter_size = group_num * 4;
const uint counter_offset = counter_size * OPERATION_NUM;
__global uint* merge_counter = OFFSET_GLOBAL_PTR(uint, tmp_buffer1, output_idx * counter_size);
__global uint* max_merge_counter = OFFSET_GLOBAL_PTR(uint, tmp_buffer1, counter_offset + output_idx * counter_size);
__global bool* subgroup_done = OFFSET_GLOBAL_PTR(bool, tmp_buffer2, output_idx);
#else
uint merge_counter[group_num];
uint max_merge_counter[group_num];
bool subgroup_done[group_num];
#endif
iav_type merge_buf;
result[i] = merge_buf;
unroll_for (uint i = 0; i < group_num - 1; i++) {
merge_counter[i] = 0;
max_merge_counter[i] = group_size;
subgroup_done[i] = false;
}
merge_counter[group_num - 1] = 0;
max_merge_counter[group_num - 1] = last_group_size;
subgroup_done[group_num - 1] = false;
for (uint i = 0; i < TOP_K; i++) {
bool merge_buf_done = false;
uint merge_buf_index = 0;
for (uint j = 0; j < group_num; j++) {
if (subgroup_done[j])
continue;
uint test_index = j * group_size + merge_counter[j];
if (!merge_buf_done) {
merge_buf = temp_buf[test_index];
merge_buf_done = true;
merge_buf_index = j;
continue;
}
if (temp_buf[test_index].value COMPARE_MERGE_SIGN merge_buf.value) {
merge_buf = temp_buf[test_index];
merge_buf_index = j;
}
}
merge_counter[merge_buf_index]++;
if (merge_counter[merge_buf_index] == max_merge_counter[merge_buf_index])
subgroup_done[merge_buf_index] = true;
result[i] = merge_buf;
}
}
#endif // SORT_BY_VALUE
#if SORT_BY_VALUE
indices[AXIS] = sort_position;
#ifdef TOP_K_ORDER
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value);
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value);
#else
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index);
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index);
#endif
#ifdef SECOND_OUTPUT_EXIST
#ifdef MULTIPLE_OUTPUTS
#ifdef TOP_K_ORDER
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.index);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.index);
#else
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.value);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.value);
#endif
#else
#ifdef TOP_K_ORDER
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index);
#else
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value);
#endif
#endif
#endif
@@ -445,22 +472,22 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
indices[AXIS] = out_position;
#ifdef TOP_K_ORDER
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
#else
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
#endif
#ifdef SECOND_OUTPUT_EXIST
#ifdef MULTIPLE_OUTPUTS
#ifdef TOP_K_ORDER
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].index);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].index);
#else
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].value);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].value);
#endif
#else
#ifdef TOP_K_ORDER
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
#else
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
#endif
#endif
#endif
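
Worked example of the dynamic scratch layout above (a sketch assuming an f32 input, VALUES_NUM = 100 along the sorted axis, and OPERATION_NUM = 32): iav_type packs one INPUT0_TYPE value plus a 4-byte index, so

    iav_type_size = 4 + 4;      // 8 bytes
    buffer_size   = 8 * 100;    // 800 bytes of scratch per work item
    buffer_offset = 800 * 32;   // 25600 bytes: start of the temp_buf half

tmp_buffer0 must therefore hold 2 * 25600 = 51200 bytes; the first half stores every work item's result array, the second half its temp_buf, each indexed by output_idx. This matches the host-side sizing iav_type_size * sort_size * ops_size * 2 pushed by update_dispatch_data_func further below.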


@@ -16,7 +16,10 @@
#endif
__attribute__((reqd_work_group_size(LOCAL_SIZE, 1, 1)))
KERNEL(arg_max_gpu_top_k)(const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output)
KERNEL(arg_max_gpu_top_k)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* input,
__global OUTPUT_TYPE* output)
{
#include "include/arg_max_min_common.cl"
uint results[TOP_K];


@@ -20,6 +20,25 @@ size_t getOperationNumber(const arg_max_min_params& params) {
}
}
std::string getOperationNumberString(const arg_max_min_params& params) {
const auto& output = params.outputs[0];
auto x = toCodeString(output.X(), 11);
auto y = toCodeString(output.Y(), 10);
auto z = toCodeString(output.Z(), 9);
auto w = toCodeString(output.W(), 8);
auto f = toCodeString(output.Feature(), 7);
auto b = toCodeString(output.Batch(), 6);
switch (params.argMaxMinAxis) {
case ArgMaxMinAxis::BATCH: return toVectorMulString({x, y, z, f});
case ArgMaxMinAxis::FEATURE: return toVectorMulString({x, y, z, b});
case ArgMaxMinAxis::Z: return toVectorMulString({x, y, f, b});
case ArgMaxMinAxis::Y: return toVectorMulString({x, z, f, b});
case ArgMaxMinAxis::X: return toVectorMulString({y, z, f, b});
default:
throw std::invalid_argument("Unsupported axis");
}
}
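
For dynamic shapes the switch above returns an expression string rather than a number. Assuming toCodeString(dim, offset) emits the literal extent for a static dimension and a shape_info[offset] lookup for a dynamic one (the offsets 6..11 passed above would then be the batch..X slots of the runtime shape-info buffer), a fully dynamic 4D input reduced over FEATURE would expand to roughly:

    #define OPERATION_NUM ((shape_info[11])*(shape_info[10])*(shape_info[9])*(shape_info[6]))

so the kernel evaluates its operation count against the actual runtime shape instead of a compile-time constant.
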
size_t getSortSize(const arg_max_min_params& params) {
switch (params.argMaxMinAxis) {
case ArgMaxMinAxis::BATCH: return params.inputs[0].Batch().v;
@@ -65,6 +84,7 @@ ParamsKey ArgMaxMinKernelAxis::GetSupportedKey() const {
k.EnableBatching();
k.EnableTensorPitches();
k.EnableTensorOffset();
k.EnableDynamicShapesSupport();
return k;
}
@@ -83,38 +103,83 @@ bool ArgMaxMinKernelAxis::Validate(const Params& p, const optional_params& o) co
return true;
}
ArgMaxMinKernelBase::DispatchData ArgMaxMinKernelAxis::SetDefault(const arg_max_min_params& params) const {
DispatchData dispatchData;
if (!params.has_dynamic_tensors()) {
size_t ops_size = getOperationNumber(params);
ops_size = ops_size > 1 ? Align(ops_size, 32) : 1;
size_t sort_size = params.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(params) : 1;
dispatchData.gws = { ops_size, sort_size, 1 };
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
}
return dispatchData;
}
KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const optional_params& options) const {
if (!Validate(params, options)) {
return {};
}
const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
bool is_dynamic = orgParams.has_dynamic_tensors();
size_t ops_size = getOperationNumber(orgParams);
ops_size = ops_size > 1 ? Align(ops_size, 32) : 1;
size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;
DispatchData dispatchData;
dispatchData.gws = { ops_size, sort_size, 1 };
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
auto dispatchData = SetDefault(orgParams);
KernelData kd = KernelData::Default<arg_max_min_params>(params);
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
const auto& prim_params = static_cast<const arg_max_min_params&>(params);
auto dispatchData = SetDefault(prim_params);
OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
kd.kernels[0].params.workGroups.global = dispatchData.gws;
kd.kernels[0].params.workGroups.local = dispatchData.lws;
const size_t elem_size = prim_params.inputs[0].ElementSize();
const size_t iav_type_size = elem_size + 4;
const size_t sort_size = getSortSize(prim_params);
const size_t ops_size = getOperationNumber(prim_params);
const size_t group_size = prim_params.topK >= 8 ? prim_params.topK : 8;
const size_t group_num = ((sort_size - 1) / group_size) + 1;
kd.internalBufferSizes.clear();
kd.internalBufferSizes.push_back(iav_type_size * sort_size * ops_size * 2);
kd.internalBufferSizes.push_back(4 * group_num * ops_size * 2);
kd.internalBufferSizes.push_back(ops_size * elem_size);
kd.internalBufferDataType = prim_params.inputs[0].GetDType();
};
auto cldnn_jit = GetJitConstants(orgParams);
auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options);
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
auto& kernel = kd.kernels[0];
if (!orgParams.use_multiple_outputs) {
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point);
} else {
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point,
"", false, false, 1, GetFusedPrimitiveInputsCount(params), 2);
}
FillCLKernelData(kernel,
dispatchData,
params.engineInfo,
kernelName,
jit,
entry_point,
EXE_MODE_DEFAULT,
false,
false,
1,
GetFusedPrimitiveInputsCount(params),
orgParams.use_multiple_outputs ? 2 : 1,
is_dynamic);
if (orgParams.has_second_output && !orgParams.use_multiple_outputs)
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1});
if (is_dynamic) {
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes());
kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes());
kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes());
kd.internalBufferDataType = orgParams.inputs[0].GetDType();
}
return {kd};
}
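
Note on the two sizing sites above: at build time the shapes are unknown, so the three internal buffers are provisionally sized to inputs[0].PhysicalSizeInBytes() each, and update_dispatch_data_func replaces them with exact sizes on every shape change. Under the same assumptions as the kernel-side example (f32, sort_size = 100, ops_size = 32, topK = 8):

    iav_type_size = 4 + 4;             // 8
    group_size    = max(topK, 8);      // 8
    group_num     = (100 - 1) / 8 + 1; // 13
    // buffer 0: 8 * 100 * 32 * 2 = 51200 bytes (result + temp_buf)
    // buffer 1: 4 * 13 * 32 * 2  = 3328 bytes (merge_counter + max_merge_counter)
    // buffer 2: 32 * 4           = 128 bytes (subgroup_done flags)

The buffer roles follow the tmp_buffer0/1/2 names in the kernel.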
@@ -125,7 +190,13 @@ KernelsPriority ArgMaxMinKernelAxis::GetKernelsPriority(const Params& /*params*/
JitConstants ArgMaxMinKernelAxis::GetJitConstants(const arg_max_min_params& params) const {
auto jit = ArgMaxMinKernelBase::GetJitConstants(params);
jit.AddConstant(MakeJitConstant("OPERATION_NUM", getOperationNumber(params)));
if (params.has_dynamic_tensors()) {
const std::string operation_num = getOperationNumberString(params);
jit.AddConstant(MakeJitConstant("OPERATION_NUM", operation_num));
} else {
const size_t operation_num = getOperationNumber(params);
jit.AddConstant(MakeJitConstant("OPERATION_NUM", operation_num));
}
if (params.argMaxMinSortType == ArgMaxMinSortType::VALUE)
jit.AddConstant(MakeJitConstant("SORT_BY_VALUE", 1));
else


@ -13,6 +13,7 @@ public:
virtual ~ArgMaxMinKernelAxis() {}
JitConstants GetJitConstants(const arg_max_min_params& params) const override;
DispatchData SetDefault(const arg_max_min_params& params) const override;
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;


@@ -27,8 +27,10 @@ JitConstants ArgMaxMinKernelBase::GetJitConstants(const arg_max_min_params& para
ArgMaxMinKernelBase::DispatchData ArgMaxMinKernelBase::SetDefault(const arg_max_min_params& params) const {
DispatchData dispatchData;
dispatchData.gws = { 128, params.inputs[0].Batch().v, 1 };
dispatchData.lws = { 128, 1, 1 };
if (!params.has_dynamic_inputs()) {
dispatchData.gws = { 128, params.inputs[0].Batch().v, 1 };
dispatchData.lws = { 128, 1, 1 };
}
return dispatchData;
}
@@ -43,13 +45,32 @@ KernelsData ArgMaxMinKernelBase::GetCommonKernelsData(const Params& params, cons
DispatchData dispatchData = SetDefault(orgParams);
KernelData kd = KernelData::Default<arg_max_min_params>(params);
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
const auto& prim_params = static_cast<const arg_max_min_params&>(params);
auto dispatchData = SetDefault(prim_params);
OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
kd.kernels[0].params.workGroups.global = dispatchData.gws;
kd.kernels[0].params.workGroups.local = dispatchData.lws;
};
auto cldnn_jit = GetJitConstants(orgParams);
auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options);
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
auto& kernel = kd.kernels[0];
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point);
FillCLKernelData(kernel,
dispatchData,
params.engineInfo,
kernelName,
jit,
entry_point,
EXE_MODE_DEFAULT,
false,
false,
(uint32_t)orgParams.inputs.size(),
GetFusedPrimitiveInputsCount(params),
(uint32_t)orgParams.outputs.size(),
orgParams.has_dynamic_tensors());
return {kd};
}
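
This hunk also carries the commit's is_dynamic fix: the flag passed to FillCLKernelData is orgParams.has_dynamic_tensors(), i.e. whether any of this kernel's tensors is actually dynamic, rather than whether the node as a whole is. A node can be dynamic while this particular kernel's tensors have all been resolved to static shapes; keying the flag off the tensors (presumably what gates the extra shape_info kernel argument) keeps the argument list consistent in that case:

    // Hedged sketch: drive the flag from the tensors, not the node.
    bool is_dynamic = orgParams.has_dynamic_tensors();
    FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit,
                     entry_point, EXE_MODE_DEFAULT, false, false,
                     (uint32_t)orgParams.inputs.size(),
                     GetFusedPrimitiveInputsCount(params),
                     (uint32_t)orgParams.outputs.size(),
                     is_dynamic);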


@@ -18,6 +18,7 @@ ParamsKey ArgMaxMinKernelGPURef::GetSupportedKey() const {
k.EnableDifferentTypes();
k.EnableBatching();
k.EnableTensorPitches();
k.EnableDynamicShapesSupport();
return k;
}


@@ -9,6 +9,8 @@
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/mutable_data.hpp>
#include <arg_max_min_inst.h>
#include "test_utils.h"
using namespace cldnn;
@@ -870,3 +872,52 @@ TEST(top_k_layer_tests, md_sync) {
TEST(export_import_top_k_layer_tests, md_sync) {
test_top_k_layer_md_sync<int>(true);
}
TEST(arg_max_min_gpu, dynamic) {
static const int32_t x_size = 2, y_size = 2, feature_num = 4, batch_num = 2;
auto& engine = get_test_engine();
const int top_k = 2;
auto input_layout_dynamic = layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx};
auto input_layout_static = layout{ov::PartialShape{batch_num, feature_num, y_size, x_size}, data_types::f32, format::bfyx};
auto input = engine.allocate_memory(input_layout_static);
topology topology;
topology.add(input_layout("input", input_layout_dynamic));
topology.add(arg_max_min("arg_max", { input_info("input") }, ov::op::TopKMode::MIN, top_k, 0));
std::vector<float> input_vec = {// y0x0 y0x1 y1x0 y1x1
/*b0f0*/ 0.1f, -0.1f, 0.9f, 1.5f,
/*b0f1*/ 0.2f, 0.2f, -10.f, 5.2f,
/*b0f2*/ 0.2f, 0.2f, -10.f, 5.2f,
/*b0f3*/ 0.2f, 0.2f, -10.f, 4.2f,
/*b1f0*/ 3.f, 0.5f, 7.f, 10.f,
/*b1f1*/ 4.f, 0.5f, 8.f, 8.2f,
/*b1f2*/ 0.2f, 0.2f, -10.f, 5.2f,
/*b1f3*/ 4.f, 0.5f, 8.f, 8.2f};
set_values(input, input_vec);
ExecutionConfig config;
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
network network(engine, topology, config);
network.set_input_data("input", input);
auto inst = network.get_primitive("arg_max");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());
auto outputs = network.execute();
ASSERT_EQ(outputs.size(), size_t(1));
ASSERT_EQ(outputs.begin()->first, "arg_max");
const int out_size = y_size * feature_num * x_size * top_k;
auto output = outputs.at("arg_max").get_memory();
cldnn::mem_lock<float> output_ptr(output, get_test_stream());
ASSERT_EQ(output_ptr.size(), out_size);
for (uint32_t i = 0; i < out_size; i++) {
ASSERT_FLOAT_EQ(output_ptr[i], i < (out_size / 2) ? 0 : 1);
}
}
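
Why the expected values are all 0s then all 1s: with TopKMode::MIN, top_k = 2, and axis 0 (batch), the output stores batch indices. For every (f, y, x) position in input_vec the batch-0 element is less than or equal to its batch-1 counterpart, so the minimum always comes from batch 0 and the runner-up from batch 1; the first out_size / 2 = 16 outputs are therefore 0 and the remaining 16 are 1 (the b0f2/b1f2 rows are identical, and those ties resolve in index order here).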