[GPU] Added shape agnostic TopK kernel (#16161)
* [GPU] Added shape agnostic TopK kernel implementation Signed-off-by: Andrew Park <andrew.park@intel.com> * Update kernel to use internal buffers for shape agnostic kernel Signed-off-by: Andrew Park <andrew.park@intel.com> * Add WA to compile_graph for shape agnostic arg_max_min_axis with non-const k input Signed-off-by: Andrew Park <andrew.park@intel.com> * Fix is_dynamic parameter for FillCLKernelData with the case where the output is static shape Signed-off-by: Andrew Park <andrew.park@intel.com> * Fix corner case where inbuf size becomes 0 when ops_size is 1 Signed-off-by: Andrew Park <andrew.park@intel.com> --------- Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
parent
60ab7490bf
commit
2956717118
@ -7,6 +7,7 @@
|
|||||||
#include "mutable_data_inst.h"
|
#include "mutable_data_inst.h"
|
||||||
#include "reshape_inst.h"
|
#include "reshape_inst.h"
|
||||||
#include "quantize_inst.h"
|
#include "quantize_inst.h"
|
||||||
|
#include "arg_max_min_inst.h"
|
||||||
#include "program_node.h"
|
#include "program_node.h"
|
||||||
#include "intel_gpu/runtime/engine.hpp"
|
#include "intel_gpu/runtime/engine.hpp"
|
||||||
#include "intel_gpu/runtime/itt.hpp"
|
#include "intel_gpu/runtime/itt.hpp"
|
||||||
@ -51,6 +52,11 @@ void compile_graph::run(program& p) {
|
|||||||
if (node->is_type<fully_connected>() && node->is_dynamic() && node->get_output_layout().get_partial_shape().size() > 3)
|
if (node->is_type<fully_connected>() && node->is_dynamic() && node->get_output_layout().get_partial_shape().size() > 3)
|
||||||
can_select_impl = false;
|
can_select_impl = false;
|
||||||
|
|
||||||
|
// TODO: Remove this WA once we have shape agnostic arg_max_min_axis kernel with non-const k input
|
||||||
|
if (node->is_type<arg_max_min>() && node->is_dynamic() && node->as<arg_max_min>().get_primitive()->top_k == 0) {
|
||||||
|
can_select_impl = false;
|
||||||
|
}
|
||||||
|
|
||||||
bool is_planar = node->get_output_layout().format == format::bfyx ||
|
bool is_planar = node->get_output_layout().format == format::bfyx ||
|
||||||
node->get_output_layout().format == format::bfzyx ||
|
node->get_output_layout().format == format::bfzyx ||
|
||||||
node->get_output_layout().format == format::bfwzyx;
|
node->get_output_layout().format == format::bfwzyx;
|
||||||
|
@ -59,7 +59,7 @@ protected:
|
|||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) {
|
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
|
||||||
const auto& primitive = impl_param.typed_desc<arg_max_min>();
|
const auto& primitive = impl_param.typed_desc<arg_max_min>();
|
||||||
const auto& axis = primitive->axis;
|
const auto& axis = primitive->axis;
|
||||||
const auto& top_k = primitive->top_k;
|
const auto& top_k = primitive->top_k;
|
||||||
@ -68,7 +68,7 @@ public:
|
|||||||
const auto& values_first = primitive->values_first;
|
const auto& values_first = primitive->values_first;
|
||||||
const auto& outputs_num = primitive->input_size() == 3 ? 2 : static_cast<uint32_t>(primitive->output_size());
|
const auto& outputs_num = primitive->input_size() == 3 ? 2 : static_cast<uint32_t>(primitive->output_size());
|
||||||
|
|
||||||
auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(impl_param);
|
auto argm_params = get_default_params<kernel_selector::arg_max_min_params>(impl_param, is_shape_agnostic);
|
||||||
auto argm_optional_params =
|
auto argm_optional_params =
|
||||||
get_default_optional_params<kernel_selector::arg_max_min_optional_params>(impl_param.get_program());
|
get_default_optional_params<kernel_selector::arg_max_min_optional_params>(impl_param.get_program());
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ public:
|
|||||||
argm_params.argMaxMinAxis = GetArgMaxMinAxis(axis, impl_param.get_output_layout().get_rank());
|
argm_params.argMaxMinAxis = GetArgMaxMinAxis(axis, impl_param.get_output_layout().get_rank());
|
||||||
|
|
||||||
auto& constant_mem = impl_param.memory_deps;
|
auto& constant_mem = impl_param.memory_deps;
|
||||||
if (constant_mem.count(1)) {
|
if (constant_mem.count(1) && !argm_params.has_dynamic_outputs()) {
|
||||||
// The topK could be got by reading impl_param.memory_deps.at(1).
|
// The topK could be got by reading impl_param.memory_deps.at(1).
|
||||||
// However, here we utilize output_layout and axis information to minimize mem_lock.
|
// However, here we utilize output_layout and axis information to minimize mem_lock.
|
||||||
auto output_layout = impl_param.get_output_layout(0);
|
auto output_layout = impl_param.get_output_layout(0);
|
||||||
@ -110,26 +110,45 @@ public:
|
|||||||
|
|
||||||
return {argm_params, argm_optional_params};
|
return {argm_params, argm_optional_params};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void update_dispatch_data(const kernel_impl_params& impl_param) override {
|
||||||
|
auto kernel_params = get_kernel_params(impl_param, true);
|
||||||
|
(_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data);
|
||||||
|
update_kernels_list_to_skip();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
attach_arg_max_min_impl::attach_arg_max_min_impl() {
|
attach_arg_max_min_impl::attach_arg_max_min_impl() {
|
||||||
auto types = {data_types::f16, data_types::f32, data_types::i8, data_types::i32};
|
auto types = {data_types::f16, data_types::f32, data_types::i8, data_types::i32};
|
||||||
|
|
||||||
auto formats = {format::bfyx,
|
auto formats = {
|
||||||
format::yxfb,
|
format::bfyx,
|
||||||
format::b_fs_yx_fsv16,
|
format::yxfb,
|
||||||
format::b_fs_yx_fsv32,
|
format::b_fs_yx_fsv16,
|
||||||
format::bs_fs_yx_bsv16_fsv16,
|
format::b_fs_yx_fsv32,
|
||||||
format::bs_fs_yx_bsv32_fsv16,
|
format::bs_fs_yx_bsv16_fsv16,
|
||||||
format::bs_fs_yx_bsv32_fsv32,
|
format::bs_fs_yx_bsv32_fsv16,
|
||||||
|
format::bs_fs_yx_bsv32_fsv32,
|
||||||
format::bfzyx};
|
format::bfzyx
|
||||||
|
};
|
||||||
|
|
||||||
implementation_map<arg_max_min>::add(impl_types::ocl,
|
implementation_map<arg_max_min>::add(impl_types::ocl,
|
||||||
|
shape_types::static_shape,
|
||||||
typed_primitive_impl_ocl<arg_max_min>::create<arg_max_min_impl>,
|
typed_primitive_impl_ocl<arg_max_min>::create<arg_max_min_impl>,
|
||||||
types,
|
types,
|
||||||
formats);
|
formats);
|
||||||
|
|
||||||
|
auto dyn_formats = {
|
||||||
|
format::bfyx,
|
||||||
|
format::bfzyx
|
||||||
|
};
|
||||||
|
|
||||||
|
implementation_map<arg_max_min>::add(impl_types::ocl,
|
||||||
|
shape_types::dynamic_shape,
|
||||||
|
typed_primitive_impl_ocl<arg_max_min>::create<arg_max_min_impl>,
|
||||||
|
types,
|
||||||
|
dyn_formats);
|
||||||
}
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
} // namespace ocl
|
} // namespace ocl
|
||||||
|
@ -43,48 +43,22 @@
|
|||||||
|
|
||||||
#define MINIMUM_NUMBER_FOR_PARTIAL_SORTING 100
|
#define MINIMUM_NUMBER_FOR_PARTIAL_SORTING 100
|
||||||
|
|
||||||
KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
inline void FUNC(get_indices_from_dims)(OPTIONAL_SHAPE_INFO_ARG
|
||||||
,__global OUTPUT_TYPE* output
|
const uint output_idx,
|
||||||
#ifdef SECOND_OUTPUT_EXIST
|
uint* indices)
|
||||||
#ifdef MULTIPLE_OUTPUTS
|
|
||||||
,__global OUTPUT1_TYPE* second_output
|
|
||||||
#else
|
|
||||||
,__global INPUT1_TYPE* second_output
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
)
|
|
||||||
{
|
{
|
||||||
#include "include/arg_max_min_common.cl"
|
|
||||||
#if SORT_BY_VALUE
|
|
||||||
const uint sort_idx = (uint)get_global_id(1);
|
|
||||||
#elif TOP_K == 1
|
|
||||||
iav_type result[TOP_K];
|
|
||||||
#else
|
|
||||||
iav_type result[VALUES_NUM], temp_buf[VALUES_NUM];
|
|
||||||
const uint group_size = TOP_K >= 8 ? TOP_K : 8;
|
|
||||||
const uint group_num = ((VALUES_NUM - 1) / group_size) + 1;
|
|
||||||
const uint last_group_size = (VALUES_NUM % group_size > 0) ? (VALUES_NUM % group_size) : group_size;
|
|
||||||
const uint last_group_offset = (group_num - 1) * group_size;
|
|
||||||
#endif // SORT_BY_VALUE
|
|
||||||
|
|
||||||
#if OPERATION_NUM > 1
|
|
||||||
const uint output_idx = (uint)get_global_id(0);
|
|
||||||
|
|
||||||
if (output_idx >= OPERATION_NUM)
|
|
||||||
return;
|
|
||||||
|
|
||||||
#ifdef BATCH_AXIS
|
#ifdef BATCH_AXIS
|
||||||
#ifdef OUTPUT_LAYOUT_YXFB
|
#ifdef OUTPUT_LAYOUT_YXFB
|
||||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_FEATURE_NUM); // Y
|
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_FEATURE_NUM); // Y
|
||||||
const uint out_second_dim = output_idx / INPUT0_FEATURE_NUM % INPUT0_SIZE_X; // X
|
const uint out_second_dim = output_idx / INPUT0_FEATURE_NUM % INPUT0_SIZE_X; // X
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_FEATURE_NUM; // F
|
const uint out_fourth_dim = output_idx % INPUT0_FEATURE_NUM; // F
|
||||||
uint indices[] = { 0, out_fourth_dim, 0, out_first_dim, out_second_dim }; // BFZYX
|
indices[1] = out_fourth_dim; indices[3] = out_first_dim; indices[4] = out_second_dim; // BFZYX
|
||||||
#else
|
#else
|
||||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // F
|
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // F
|
||||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
|
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
|
||||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||||
uint indices[] = { 0, out_first_dim, out_second_dim, out_third_dim, out_fourth_dim };
|
indices[1] = out_first_dim; indices[2] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim;
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#ifdef FEATURE_AXIS
|
#ifdef FEATURE_AXIS
|
||||||
@ -92,13 +66,13 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_BATCH_NUM); // Y
|
const uint out_first_dim = output_idx / (INPUT0_SIZE_X * INPUT0_BATCH_NUM); // Y
|
||||||
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_SIZE_X; // X
|
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_SIZE_X; // X
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
||||||
uint indices[] = { out_fourth_dim, 0, 0, out_first_dim, out_second_dim }; // BFZYX
|
indices[0] = out_fourth_dim; indices[3] = out_first_dim; indices[4] = out_second_dim; // BFZYX
|
||||||
#else
|
#else
|
||||||
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // B
|
const uint out_first_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y * INPUT0_SIZE_X); // B
|
||||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
|
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_SIZE_Z; // Z
|
||||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||||
uint indices[] = { out_first_dim, 0, out_second_dim, out_third_dim, out_fourth_dim };
|
indices[0] = out_first_dim; indices[2] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim;
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#ifdef Z_AXIS
|
#ifdef Z_AXIS
|
||||||
@ -106,20 +80,20 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
|
const uint out_second_dim = output_idx / (INPUT0_SIZE_Y * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
|
||||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Y; // Y
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||||
uint indices[] = { out_first_dim, out_second_dim, 0, out_third_dim, out_fourth_dim };
|
indices[0] = out_first_dim; indices[1] = out_second_dim; indices[3] = out_third_dim; indices[4] = out_fourth_dim;
|
||||||
#endif
|
#endif
|
||||||
#ifdef Y_AXIS
|
#ifdef Y_AXIS
|
||||||
#ifdef OUTPUT_LAYOUT_YXFB
|
#ifdef OUTPUT_LAYOUT_YXFB
|
||||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // X
|
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // X
|
||||||
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
|
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
||||||
uint indices[] = { out_fourth_dim, out_second_dim, 0, 0, out_first_dim }; // BFZYX
|
indices[0] = out_fourth_dim; indices[1] = out_second_dim; indices[4] = out_first_dim; // BFZYX
|
||||||
#else
|
#else
|
||||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_X); // B
|
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_X); // B
|
||||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
|
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_X) % INPUT0_FEATURE_NUM; // F
|
||||||
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Z; // Z
|
const uint out_third_dim = output_idx / INPUT0_SIZE_X % INPUT0_SIZE_Z; // Z
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
const uint out_fourth_dim = output_idx % INPUT0_SIZE_X; // X
|
||||||
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, 0, out_fourth_dim };
|
indices[0] = out_first_dim; indices[1] = out_second_dim; indices[2] = out_third_dim; indices[4] = out_fourth_dim;
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#ifdef X_AXIS
|
#ifdef X_AXIS
|
||||||
@ -127,19 +101,63 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // Y
|
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_BATCH_NUM); // Y
|
||||||
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
|
const uint out_second_dim = output_idx / INPUT0_BATCH_NUM % INPUT0_FEATURE_NUM; // F
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
const uint out_fourth_dim = output_idx % INPUT0_BATCH_NUM; // B
|
||||||
uint indices[] = { out_fourth_dim, out_second_dim, 0, out_first_dim, 0 }; // BFZYX
|
indices[0] = out_fourth_dim; indices[1] = out_second_dim; indices[3] = out_first_dim; // BFZYX
|
||||||
#else
|
#else
|
||||||
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_Y); // B
|
const uint out_first_dim = output_idx / (INPUT0_FEATURE_NUM * INPUT0_SIZE_Z * INPUT0_SIZE_Y); // B
|
||||||
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y) % INPUT0_FEATURE_NUM; // F
|
const uint out_second_dim = output_idx / (INPUT0_SIZE_Z * INPUT0_SIZE_Y) % INPUT0_FEATURE_NUM; // F
|
||||||
const uint out_third_dim = output_idx / INPUT0_SIZE_Y % INPUT0_SIZE_Z; // Z
|
const uint out_third_dim = output_idx / INPUT0_SIZE_Y % INPUT0_SIZE_Z; // Z
|
||||||
const uint out_fourth_dim = output_idx % INPUT0_SIZE_Y; // Y
|
const uint out_fourth_dim = output_idx % INPUT0_SIZE_Y; // Y
|
||||||
uint indices[] = { out_first_dim, out_second_dim, out_third_dim, out_fourth_dim, 0 };
|
indices[0] = out_first_dim; indices[1] = out_second_dim; indices[2] = out_third_dim; indices[3] = out_fourth_dim;
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
#else // OPERATION_NUM > 1
|
KERNEL(arg_max_min_modified)(
|
||||||
|
OPTIONAL_SHAPE_INFO_ARG
|
||||||
|
const __global INPUT0_TYPE* input
|
||||||
|
,__global OUTPUT_TYPE* output
|
||||||
|
#ifdef SECOND_OUTPUT_EXIST
|
||||||
|
#ifdef MULTIPLE_OUTPUTS
|
||||||
|
,__global OUTPUT1_TYPE* second_output
|
||||||
|
#else
|
||||||
|
,__global INPUT1_TYPE* second_output
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#ifdef IS_DYNAMIC
|
||||||
|
,__global INPUT0_TYPE* tmp_buffer0
|
||||||
|
,__global INPUT0_TYPE* tmp_buffer1
|
||||||
|
,__global INPUT0_TYPE* tmp_buffer2
|
||||||
|
#endif
|
||||||
|
)
|
||||||
|
{
|
||||||
|
#include "include/arg_max_min_common.cl"
|
||||||
|
const uint output_idx = (uint)get_global_id(0);
|
||||||
|
#if SORT_BY_VALUE
|
||||||
|
const uint sort_idx = (uint)get_global_id(1);
|
||||||
|
#elif TOP_K == 1
|
||||||
|
iav_type result[TOP_K];
|
||||||
|
#else
|
||||||
|
#ifdef IS_DYNAMIC
|
||||||
|
const uint iav_type_size = INPUT0_TYPE_SIZE + 4;
|
||||||
|
const uint buffer_size = iav_type_size * VALUES_NUM;
|
||||||
|
const uint buffer_offset = buffer_size * OPERATION_NUM;
|
||||||
|
__global iav_type *result = OFFSET_GLOBAL_PTR(iav_type, tmp_buffer0, output_idx * buffer_size);
|
||||||
|
__global iav_type *temp_buf = OFFSET_GLOBAL_PTR(iav_type, tmp_buffer0, buffer_offset + output_idx * buffer_size);
|
||||||
|
#else
|
||||||
|
iav_type result[VALUES_NUM], temp_buf[VALUES_NUM];
|
||||||
|
#endif
|
||||||
|
const uint group_size = TOP_K >= 8 ? TOP_K : 8;
|
||||||
|
const uint group_num = ((VALUES_NUM - 1) / group_size) + 1;
|
||||||
|
const uint last_group_size = (VALUES_NUM % group_size > 0) ? (VALUES_NUM % group_size) : group_size;
|
||||||
|
const uint last_group_offset = (group_num - 1) * group_size;
|
||||||
|
#endif // SORT_BY_VALUE
|
||||||
uint indices[] = { 0, 0, 0, 0, 0 };
|
uint indices[] = { 0, 0, 0, 0, 0 };
|
||||||
#endif // OPERATION_NUM > 1
|
|
||||||
|
if (OPERATION_NUM > 1) {
|
||||||
|
if (output_idx >= OPERATION_NUM)
|
||||||
|
return;
|
||||||
|
FUNC_CALL(get_indices_from_dims)(OPTIONAL_SHAPE_INFO_TENSOR output_idx, indices);
|
||||||
|
}
|
||||||
|
|
||||||
// Using parallel sorting for sorting by values
|
// Using parallel sorting for sorting by values
|
||||||
#if SORT_BY_VALUE
|
#if SORT_BY_VALUE
|
||||||
@ -147,41 +165,41 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
indices[AXIS] = sort_idx;
|
indices[AXIS] = sort_idx;
|
||||||
|
|
||||||
iav_type result;
|
iav_type result;
|
||||||
result.value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
result.value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
result.index = sort_idx;
|
result.index = sort_idx;
|
||||||
|
|
||||||
for (uint i = 0; i < sort_idx / 8; i++) {
|
for (uint i = 0; i < sort_idx / 8; i++) {
|
||||||
uint index_offset = i * 8;
|
uint index_offset = i * 8;
|
||||||
indices[AXIS] = index_offset;
|
indices[AXIS] = index_offset;
|
||||||
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
indices[AXIS] = index_offset + 1;
|
indices[AXIS] = index_offset + 1;
|
||||||
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
indices[AXIS] = index_offset + 2;
|
indices[AXIS] = index_offset + 2;
|
||||||
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
indices[AXIS] = index_offset + 3;
|
indices[AXIS] = index_offset + 3;
|
||||||
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
indices[AXIS] = index_offset + 4;
|
indices[AXIS] = index_offset + 4;
|
||||||
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
indices[AXIS] = index_offset + 5;
|
indices[AXIS] = index_offset + 5;
|
||||||
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
indices[AXIS] = index_offset + 6;
|
indices[AXIS] = index_offset + 6;
|
||||||
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
indices[AXIS] = index_offset + 7;
|
indices[AXIS] = index_offset + 7;
|
||||||
test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
if (sort_position >= TOP_K)
|
if (sort_position >= TOP_K)
|
||||||
@ -190,7 +208,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
|
|
||||||
for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) {
|
for (uint i = (sort_idx / 8) * 8; i < sort_idx; i++) {
|
||||||
indices[AXIS] = i;
|
indices[AXIS] = i;
|
||||||
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
}
|
}
|
||||||
@ -200,7 +218,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
|
|
||||||
for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
|
for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
|
||||||
indices[AXIS] = i;
|
indices[AXIS] = i;
|
||||||
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (result.value COMPARE_PARALLEL_SIGN_2 test_value)
|
if (result.value COMPARE_PARALLEL_SIGN_2 test_value)
|
||||||
sort_position++;
|
sort_position++;
|
||||||
if (sort_position >= TOP_K)
|
if (sort_position >= TOP_K)
|
||||||
@ -209,7 +227,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
|
|
||||||
// Using simple sorting for sorting by indices and when TOP_K == 1
|
// Using simple sorting for sorting by indices and when TOP_K == 1
|
||||||
#elif TOP_K == 1
|
#elif TOP_K == 1
|
||||||
INPUT0_TYPE val = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
INPUT0_TYPE val = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
result[0].index = 0;
|
result[0].index = 0;
|
||||||
result[0].value = val;
|
result[0].value = val;
|
||||||
bool already_exist = false;
|
bool already_exist = false;
|
||||||
@ -228,7 +246,7 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
}
|
}
|
||||||
|
|
||||||
indices[AXIS] = i;
|
indices[AXIS] = i;
|
||||||
INPUT0_TYPE in_data = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
INPUT0_TYPE in_data = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
if (val COMPARE_SIGN in_data) {
|
if (val COMPARE_SIGN in_data) {
|
||||||
result[top_k].index = i;
|
result[top_k].index = i;
|
||||||
result[top_k].value = in_data;
|
result[top_k].value = in_data;
|
||||||
@ -237,197 +255,206 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
}
|
}
|
||||||
val = INPUT0_FILL_VAL;
|
val = INPUT0_FILL_VAL;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
|
|
||||||
#elif ((TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING))
|
|
||||||
for (uint i = 0; i < VALUES_NUM / 8; i++) {
|
|
||||||
uint index_offset = i * 8;
|
|
||||||
indices[AXIS] = result[index_offset].index = index_offset;
|
|
||||||
result[index_offset].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
indices[AXIS] = result[index_offset + 1].index = index_offset + 1;
|
|
||||||
result[index_offset + 1].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
indices[AXIS] = result[index_offset + 2].index = index_offset + 2;
|
|
||||||
result[index_offset + 2].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
indices[AXIS] = result[index_offset + 3].index = index_offset + 3;
|
|
||||||
result[index_offset + 3].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
indices[AXIS] = result[index_offset + 4].index = index_offset + 4;
|
|
||||||
result[index_offset + 4].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
indices[AXIS] = result[index_offset + 5].index = index_offset + 5;
|
|
||||||
result[index_offset + 5].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
indices[AXIS] = result[index_offset + 6].index = index_offset + 6;
|
|
||||||
result[index_offset + 6].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
indices[AXIS] = result[index_offset + 7].index = index_offset + 7;
|
|
||||||
result[index_offset + 7].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
|
|
||||||
indices[AXIS] = result[i].index = i;
|
|
||||||
result[i].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint k = 1; k < VALUES_NUM; k *= 2) {
|
|
||||||
for (uint left = 0; left + k < VALUES_NUM; left += k * 2) {
|
|
||||||
uint i, j, m;
|
|
||||||
uint right = left + k;
|
|
||||||
uint right_end = right + k;
|
|
||||||
if (right_end > VALUES_NUM) right_end = VALUES_NUM;
|
|
||||||
m = i = left; j = right;
|
|
||||||
while ((i < right) && (j < right_end)) {
|
|
||||||
if (result[i].value COMPARE_PARTIAL_SIGN result[j].value) {
|
|
||||||
temp_buf[m++] = result[i++];
|
|
||||||
} else {
|
|
||||||
temp_buf[m++] = result[j++];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
while (i < right)
|
|
||||||
temp_buf[m++] = result[i++];
|
|
||||||
while (j < right_end)
|
|
||||||
temp_buf[m++] = result[j++];
|
|
||||||
for (m = left; m < right_end; m++)
|
|
||||||
result[m] = temp_buf[m];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// In other cases for sorting by indices using mixed partial/merge sorting
|
|
||||||
#else // SORT_BY_VALUE
|
#else // SORT_BY_VALUE
|
||||||
for (uint i = 0; i < VALUES_NUM / 8; i++) {
|
// Using merge sorting for sorting by indices and when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
|
||||||
uint index_offset = i * 8;
|
bool use_merge_sorting = (TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING);
|
||||||
indices[AXIS] = temp_buf[index_offset].index = index_offset;
|
if (use_merge_sorting) {
|
||||||
temp_buf[index_offset].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
for (uint i = 0; i < VALUES_NUM / 8; i++) {
|
||||||
indices[AXIS] = temp_buf[index_offset + 1].index = index_offset + 1;
|
uint index_offset = i * 8;
|
||||||
temp_buf[index_offset + 1].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
indices[AXIS] = result[index_offset].index = index_offset;
|
||||||
indices[AXIS] = temp_buf[index_offset + 2].index = index_offset + 2;
|
result[index_offset].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
temp_buf[index_offset + 2].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
indices[AXIS] = result[index_offset + 1].index = index_offset + 1;
|
||||||
indices[AXIS] = temp_buf[index_offset + 3].index = index_offset + 3;
|
result[index_offset + 1].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
temp_buf[index_offset + 3].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
indices[AXIS] = result[index_offset + 2].index = index_offset + 2;
|
||||||
indices[AXIS] = temp_buf[index_offset + 4].index = index_offset + 4;
|
result[index_offset + 2].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
temp_buf[index_offset + 4].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
indices[AXIS] = result[index_offset + 3].index = index_offset + 3;
|
||||||
indices[AXIS] = temp_buf[index_offset + 5].index = index_offset + 5;
|
result[index_offset + 3].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
temp_buf[index_offset + 5].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
indices[AXIS] = result[index_offset + 4].index = index_offset + 4;
|
||||||
indices[AXIS] = temp_buf[index_offset + 6].index = index_offset + 6;
|
result[index_offset + 4].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
temp_buf[index_offset + 6].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
indices[AXIS] = result[index_offset + 5].index = index_offset + 5;
|
||||||
indices[AXIS] = temp_buf[index_offset + 7].index = index_offset + 7;
|
result[index_offset + 5].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
temp_buf[index_offset + 7].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
indices[AXIS] = result[index_offset + 6].index = index_offset + 6;
|
||||||
}
|
result[index_offset + 6].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = result[index_offset + 7].index = index_offset + 7;
|
||||||
|
result[index_offset + 7].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
}
|
||||||
|
|
||||||
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
|
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
|
||||||
indices[AXIS] = temp_buf[i].index = i;
|
indices[AXIS] = result[i].index = i;
|
||||||
temp_buf[i].value = input[FUNC_CALL(get_input_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
result[i].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint group = 0; group < group_num - 1; group++) {
|
for (uint k = 1; k < VALUES_NUM; k *= 2) {
|
||||||
uint group_offset = group * group_size;
|
for (uint left = 0; left + k < VALUES_NUM; left += k * 2) {
|
||||||
for (uint k = 1; k < group_size; k *= 2) {
|
|
||||||
for (uint left = 0; left + k < group_size; left += k * 2) {
|
|
||||||
uint i, j, m;
|
uint i, j, m;
|
||||||
uint right = left + k;
|
uint right = left + k;
|
||||||
uint right_end = right + k;
|
uint right_end = right + k;
|
||||||
if (right_end > group_size) right_end = group_size;
|
if (right_end > VALUES_NUM) right_end = VALUES_NUM;
|
||||||
m = i = left; j = right;
|
m = i = left; j = right;
|
||||||
while ((i < right) && (j < right_end)) {
|
while ((i < right) && (j < right_end)) {
|
||||||
if (temp_buf[group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[group_offset + j].value) {
|
if (result[i].value COMPARE_PARTIAL_SIGN result[j].value) {
|
||||||
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
|
temp_buf[m++] = result[i++];
|
||||||
} else {
|
} else {
|
||||||
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
|
temp_buf[m++] = result[j++];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
while (i < right)
|
while (i < right)
|
||||||
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
|
temp_buf[m++] = result[i++];
|
||||||
while (j < right_end)
|
while (j < right_end)
|
||||||
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
|
temp_buf[m++] = result[j++];
|
||||||
for (m = left; m < right_end; m++)
|
for (m = left; m < right_end; m++)
|
||||||
temp_buf[group_offset + m] = result[group_offset + m];
|
result[m] = temp_buf[m];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
// In other cases for sorting by indices using mixed partial/merge sorting
|
||||||
|
for (uint i = 0; i < VALUES_NUM / 8; i++) {
|
||||||
|
uint index_offset = i * 8;
|
||||||
|
indices[AXIS] = temp_buf[index_offset].index = index_offset;
|
||||||
|
temp_buf[index_offset].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = temp_buf[index_offset + 1].index = index_offset + 1;
|
||||||
|
temp_buf[index_offset + 1].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = temp_buf[index_offset + 2].index = index_offset + 2;
|
||||||
|
temp_buf[index_offset + 2].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = temp_buf[index_offset + 3].index = index_offset + 3;
|
||||||
|
temp_buf[index_offset + 3].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = temp_buf[index_offset + 4].index = index_offset + 4;
|
||||||
|
temp_buf[index_offset + 4].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = temp_buf[index_offset + 5].index = index_offset + 5;
|
||||||
|
temp_buf[index_offset + 5].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = temp_buf[index_offset + 6].index = index_offset + 6;
|
||||||
|
temp_buf[index_offset + 6].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
indices[AXIS] = temp_buf[index_offset + 7].index = index_offset + 7;
|
||||||
|
temp_buf[index_offset + 7].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
|
}
|
||||||
|
|
||||||
for (uint k = 1; k < last_group_size; k *= 2) {
|
for (uint i = (VALUES_NUM / 8) * 8; i < VALUES_NUM; i++) {
|
||||||
for (uint left = 0; left + k < last_group_size; left += k * 2) {
|
indices[AXIS] = temp_buf[i].index = i;
|
||||||
uint i, j, m;
|
temp_buf[i].value = input[FUNC_CALL(get_input_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])];
|
||||||
uint right = left + k;
|
}
|
||||||
uint right_end = right + k;
|
|
||||||
if (right_end > last_group_size) right_end = last_group_size;
|
for (uint group = 0; group < group_num - 1; group++) {
|
||||||
m = i = left; j = right;
|
uint group_offset = group * group_size;
|
||||||
while ((i < right) && (j < right_end)) {
|
for (uint k = 1; k < group_size; k *= 2) {
|
||||||
if (temp_buf[last_group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[last_group_offset + j].value) {
|
for (uint left = 0; left + k < group_size; left += k * 2) {
|
||||||
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
|
uint i, j, m;
|
||||||
} else {
|
uint right = left + k;
|
||||||
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
|
uint right_end = right + k;
|
||||||
|
if (right_end > group_size) right_end = group_size;
|
||||||
|
m = i = left; j = right;
|
||||||
|
while ((i < right) && (j < right_end)) {
|
||||||
|
if (temp_buf[group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[group_offset + j].value) {
|
||||||
|
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
|
||||||
|
} else {
|
||||||
|
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (i < right)
|
||||||
|
result[group_offset + (m++)] = temp_buf[group_offset + (i++)];
|
||||||
|
while (j < right_end)
|
||||||
|
result[group_offset + (m++)] = temp_buf[group_offset + (j++)];
|
||||||
|
for (m = left; m < right_end; m++)
|
||||||
|
temp_buf[group_offset + m] = result[group_offset + m];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
while (i < right)
|
|
||||||
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
|
|
||||||
while (j < right_end)
|
|
||||||
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
|
|
||||||
for (m = left; m < right_end; m++)
|
|
||||||
temp_buf[last_group_offset + m] = result[last_group_offset + m];
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
uint merge_counter[group_num];
|
for (uint k = 1; k < last_group_size; k *= 2) {
|
||||||
uint max_merge_counter[group_num];
|
for (uint left = 0; left + k < last_group_size; left += k * 2) {
|
||||||
iav_type merge_buf;
|
uint i, j, m;
|
||||||
bool subgroup_done[group_num];
|
uint right = left + k;
|
||||||
|
uint right_end = right + k;
|
||||||
unroll_for (uint i = 0; i < group_num - 1; i++) {
|
if (right_end > last_group_size) right_end = last_group_size;
|
||||||
merge_counter[i] = 0;
|
m = i = left; j = right;
|
||||||
max_merge_counter[i] = group_size;
|
while ((i < right) && (j < right_end)) {
|
||||||
subgroup_done[i] = false;
|
if (temp_buf[last_group_offset + i].value COMPARE_PARTIAL_SIGN temp_buf[last_group_offset + j].value) {
|
||||||
}
|
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
|
||||||
|
} else {
|
||||||
merge_counter[group_num - 1] = 0;
|
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
|
||||||
max_merge_counter[group_num - 1] = last_group_size;
|
}
|
||||||
subgroup_done[group_num - 1] = false;
|
}
|
||||||
|
while (i < right)
|
||||||
for (uint i = 0; i < TOP_K; i++) {
|
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (i++)];
|
||||||
bool merge_buf_done = false;
|
while (j < right_end)
|
||||||
uint merge_buf_index = 0;
|
result[last_group_offset + (m++)] = temp_buf[last_group_offset + (j++)];
|
||||||
for (uint j = 0; j < group_num; j++) {
|
for (m = left; m < right_end; m++)
|
||||||
if (subgroup_done[j])
|
temp_buf[last_group_offset + m] = result[last_group_offset + m];
|
||||||
continue;
|
|
||||||
|
|
||||||
uint test_index = j * group_size + merge_counter[j];
|
|
||||||
|
|
||||||
if (!merge_buf_done) {
|
|
||||||
merge_buf = temp_buf[test_index];
|
|
||||||
merge_buf_done = true;
|
|
||||||
merge_buf_index = j;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (temp_buf[test_index].value COMPARE_MERGE_SIGN merge_buf.value) {
|
|
||||||
merge_buf = temp_buf[test_index];
|
|
||||||
merge_buf_index = j;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
merge_counter[merge_buf_index]++;
|
#ifdef IS_DYNAMIC
|
||||||
if (merge_counter[merge_buf_index] == max_merge_counter[merge_buf_index])
|
const uint counter_size = group_num * 4;
|
||||||
subgroup_done[merge_buf_index] = true;
|
const uint counter_offset = counter_size * OPERATION_NUM;
|
||||||
|
__global uint* merge_counter = OFFSET_GLOBAL_PTR(uint, tmp_buffer1, output_idx * counter_size);
|
||||||
|
__global uint* max_merge_counter = OFFSET_GLOBAL_PTR(uint, tmp_buffer1, counter_offset + output_idx * counter_size);
|
||||||
|
__global bool* subgroup_done = OFFSET_GLOBAL_PTR(bool, tmp_buffer2, output_idx);
|
||||||
|
#else
|
||||||
|
uint merge_counter[group_num];
|
||||||
|
uint max_merge_counter[group_num];
|
||||||
|
bool subgroup_done[group_num];
|
||||||
|
#endif
|
||||||
|
iav_type merge_buf;
|
||||||
|
|
||||||
result[i] = merge_buf;
|
unroll_for (uint i = 0; i < group_num - 1; i++) {
|
||||||
|
merge_counter[i] = 0;
|
||||||
|
max_merge_counter[i] = group_size;
|
||||||
|
subgroup_done[i] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
merge_counter[group_num - 1] = 0;
|
||||||
|
max_merge_counter[group_num - 1] = last_group_size;
|
||||||
|
subgroup_done[group_num - 1] = false;
|
||||||
|
|
||||||
|
for (uint i = 0; i < TOP_K; i++) {
|
||||||
|
bool merge_buf_done = false;
|
||||||
|
uint merge_buf_index = 0;
|
||||||
|
for (uint j = 0; j < group_num; j++) {
|
||||||
|
if (subgroup_done[j])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
uint test_index = j * group_size + merge_counter[j];
|
||||||
|
|
||||||
|
if (!merge_buf_done) {
|
||||||
|
merge_buf = temp_buf[test_index];
|
||||||
|
merge_buf_done = true;
|
||||||
|
merge_buf_index = j;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (temp_buf[test_index].value COMPARE_MERGE_SIGN merge_buf.value) {
|
||||||
|
merge_buf = temp_buf[test_index];
|
||||||
|
merge_buf_index = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
merge_counter[merge_buf_index]++;
|
||||||
|
if (merge_counter[merge_buf_index] == max_merge_counter[merge_buf_index])
|
||||||
|
subgroup_done[merge_buf_index] = true;
|
||||||
|
|
||||||
|
result[i] = merge_buf;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif // SORT_BY_VALUE
|
#endif // SORT_BY_VALUE
|
||||||
|
|
||||||
#if SORT_BY_VALUE
|
#if SORT_BY_VALUE
|
||||||
indices[AXIS] = sort_position;
|
indices[AXIS] = sort_position;
|
||||||
#ifdef TOP_K_ORDER
|
#ifdef TOP_K_ORDER
|
||||||
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value);
|
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value);
|
||||||
#else
|
#else
|
||||||
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index);
|
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index);
|
||||||
#endif
|
#endif
|
||||||
#ifdef SECOND_OUTPUT_EXIST
|
#ifdef SECOND_OUTPUT_EXIST
|
||||||
#ifdef MULTIPLE_OUTPUTS
|
#ifdef MULTIPLE_OUTPUTS
|
||||||
#ifdef TOP_K_ORDER
|
#ifdef TOP_K_ORDER
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.index);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.index);
|
||||||
#else
|
#else
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.value);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result.value);
|
||||||
#endif
|
#endif
|
||||||
#else
|
#else
|
||||||
#ifdef TOP_K_ORDER
|
#ifdef TOP_K_ORDER
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index);
|
||||||
#else
|
#else
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value);
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
@ -445,22 +472,22 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
|
|||||||
|
|
||||||
indices[AXIS] = out_position;
|
indices[AXIS] = out_position;
|
||||||
#ifdef TOP_K_ORDER
|
#ifdef TOP_K_ORDER
|
||||||
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
|
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
|
||||||
#else
|
#else
|
||||||
output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
|
output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
|
||||||
#endif
|
#endif
|
||||||
#ifdef SECOND_OUTPUT_EXIST
|
#ifdef SECOND_OUTPUT_EXIST
|
||||||
#ifdef MULTIPLE_OUTPUTS
|
#ifdef MULTIPLE_OUTPUTS
|
||||||
#ifdef TOP_K_ORDER
|
#ifdef TOP_K_ORDER
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].index);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].index);
|
||||||
#else
|
#else
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].value);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_OUTPUT1_TYPE(result[top_k].value);
|
||||||
#endif
|
#endif
|
||||||
#else
|
#else
|
||||||
#ifdef TOP_K_ORDER
|
#ifdef TOP_K_ORDER
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
|
||||||
#else
|
#else
|
||||||
second_output[FUNC_CALL(get_output_index)(indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
|
second_output[FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[0], indices[1], 0, indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
@ -16,7 +16,10 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
__attribute__((reqd_work_group_size(LOCAL_SIZE, 1, 1)))
|
__attribute__((reqd_work_group_size(LOCAL_SIZE, 1, 1)))
|
||||||
KERNEL(arg_max_gpu_top_k)(const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output)
|
KERNEL(arg_max_gpu_top_k)(
|
||||||
|
OPTIONAL_SHAPE_INFO_ARG
|
||||||
|
const __global INPUT0_TYPE* input,
|
||||||
|
__global OUTPUT_TYPE* output)
|
||||||
{
|
{
|
||||||
#include "include/arg_max_min_common.cl"
|
#include "include/arg_max_min_common.cl"
|
||||||
uint results[TOP_K];
|
uint results[TOP_K];
|
||||||
|
@ -20,6 +20,25 @@ size_t getOperationNumber(const arg_max_min_params& params) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string getOperationNumberString(const arg_max_min_params& params) {
|
||||||
|
const auto& output = params.outputs[0];
|
||||||
|
auto x = toCodeString(output.X(), 11);
|
||||||
|
auto y = toCodeString(output.Y(), 10);
|
||||||
|
auto z = toCodeString(output.Z(), 9);
|
||||||
|
auto w = toCodeString(output.W(), 8);
|
||||||
|
auto f = toCodeString(output.Feature(), 7);
|
||||||
|
auto b = toCodeString(output.Batch(), 6);
|
||||||
|
switch (params.argMaxMinAxis) {
|
||||||
|
case ArgMaxMinAxis::BATCH: return toVectorMulString({x, y, z, f});
|
||||||
|
case ArgMaxMinAxis::FEATURE: return toVectorMulString({x, y, z, b});
|
||||||
|
case ArgMaxMinAxis::Z: return toVectorMulString({y, z, f, b});
|
||||||
|
case ArgMaxMinAxis::Y: return toVectorMulString({x, z, f, b});
|
||||||
|
case ArgMaxMinAxis::X: return toVectorMulString({y, z, f, b});
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Unsupported axis");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t getSortSize(const arg_max_min_params& params) {
|
size_t getSortSize(const arg_max_min_params& params) {
|
||||||
switch (params.argMaxMinAxis) {
|
switch (params.argMaxMinAxis) {
|
||||||
case ArgMaxMinAxis::BATCH: return params.inputs[0].Batch().v;
|
case ArgMaxMinAxis::BATCH: return params.inputs[0].Batch().v;
|
||||||
@ -65,6 +84,7 @@ ParamsKey ArgMaxMinKernelAxis::GetSupportedKey() const {
|
|||||||
k.EnableBatching();
|
k.EnableBatching();
|
||||||
k.EnableTensorPitches();
|
k.EnableTensorPitches();
|
||||||
k.EnableTensorOffset();
|
k.EnableTensorOffset();
|
||||||
|
k.EnableDynamicShapesSupport();
|
||||||
return k;
|
return k;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,38 +103,83 @@ bool ArgMaxMinKernelAxis::Validate(const Params& p, const optional_params& o) co
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ArgMaxMinKernelBase::DispatchData ArgMaxMinKernelAxis::SetDefault(const arg_max_min_params& params) const {
|
||||||
|
DispatchData dispatchData;
|
||||||
|
|
||||||
|
if (!params.has_dynamic_tensors()) {
|
||||||
|
size_t ops_size = getOperationNumber(params);
|
||||||
|
ops_size = ops_size > 1 ? Align(ops_size, 32) : 1;
|
||||||
|
size_t sort_size = params.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(params) : 1;
|
||||||
|
|
||||||
|
dispatchData.gws = { ops_size, sort_size, 1 };
|
||||||
|
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
return dispatchData;
|
||||||
|
}
|
||||||
|
|
||||||
KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const optional_params& options) const {
|
KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const optional_params& options) const {
|
||||||
if (!Validate(params, options)) {
|
if (!Validate(params, options)) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
|
const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);
|
||||||
|
bool is_dynamic = orgParams.has_dynamic_tensors();
|
||||||
|
|
||||||
size_t ops_size = getOperationNumber(orgParams);
|
auto dispatchData = SetDefault(orgParams);
|
||||||
ops_size = ops_size > 1 ? Align(ops_size, 32) : 1;
|
|
||||||
size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;
|
|
||||||
|
|
||||||
DispatchData dispatchData;
|
|
||||||
|
|
||||||
dispatchData.gws = { ops_size, sort_size, 1 };
|
|
||||||
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
|
|
||||||
|
|
||||||
KernelData kd = KernelData::Default<arg_max_min_params>(params);
|
KernelData kd = KernelData::Default<arg_max_min_params>(params);
|
||||||
|
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
|
||||||
|
const auto& prim_params = static_cast<const arg_max_min_params&>(params);
|
||||||
|
auto dispatchData = SetDefault(prim_params);
|
||||||
|
OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
|
||||||
|
kd.kernels[0].params.workGroups.global = dispatchData.gws;
|
||||||
|
kd.kernels[0].params.workGroups.local = dispatchData.lws;
|
||||||
|
|
||||||
|
const size_t elem_size = prim_params.inputs[0].ElementSize();
|
||||||
|
const size_t iav_type_size = elem_size + 4;
|
||||||
|
const size_t sort_size = getSortSize(prim_params);
|
||||||
|
const size_t ops_size = getOperationNumber(prim_params);
|
||||||
|
const size_t group_size = prim_params.topK >= 8 ? prim_params.topK : 8;
|
||||||
|
const size_t group_num = ((sort_size - 1) / group_size) + 1;
|
||||||
|
|
||||||
|
kd.internalBufferSizes.clear();
|
||||||
|
kd.internalBufferSizes.push_back(iav_type_size * sort_size * ops_size * 2);
|
||||||
|
kd.internalBufferSizes.push_back(4 * group_num * ops_size * 2);
|
||||||
|
kd.internalBufferSizes.push_back(ops_size * elem_size);
|
||||||
|
kd.internalBufferDataType = prim_params.inputs[0].GetDType();
|
||||||
|
};
|
||||||
|
|
||||||
auto cldnn_jit = GetJitConstants(orgParams);
|
auto cldnn_jit = GetJitConstants(orgParams);
|
||||||
auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options);
|
auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options);
|
||||||
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||||
|
|
||||||
auto& kernel = kd.kernels[0];
|
auto& kernel = kd.kernels[0];
|
||||||
if (!orgParams.use_multiple_outputs) {
|
FillCLKernelData(kernel,
|
||||||
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point);
|
dispatchData,
|
||||||
} else {
|
params.engineInfo,
|
||||||
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point,
|
kernelName,
|
||||||
"", false, false, 1, GetFusedPrimitiveInputsCount(params), 2);
|
jit,
|
||||||
}
|
entry_point,
|
||||||
|
EXE_MODE_DEFAULT,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
1,
|
||||||
|
GetFusedPrimitiveInputsCount(params),
|
||||||
|
orgParams.use_multiple_outputs ? 2 : 1,
|
||||||
|
is_dynamic);
|
||||||
|
|
||||||
if (orgParams.has_second_output && !orgParams.use_multiple_outputs)
|
if (orgParams.has_second_output && !orgParams.use_multiple_outputs)
|
||||||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1});
|
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 1});
|
||||||
|
|
||||||
|
if (is_dynamic) {
|
||||||
|
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
|
||||||
|
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
|
||||||
|
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
|
||||||
|
kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes());
|
||||||
|
kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes());
|
||||||
|
kd.internalBufferSizes.push_back(orgParams.inputs[0].PhysicalSizeInBytes());
|
||||||
|
kd.internalBufferDataType = orgParams.inputs[0].GetDType();
|
||||||
|
}
|
||||||
|
|
||||||
return {kd};
|
return {kd};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,7 +190,13 @@ KernelsPriority ArgMaxMinKernelAxis::GetKernelsPriority(const Params& /*params*/
|
|||||||
JitConstants ArgMaxMinKernelAxis::GetJitConstants(const arg_max_min_params& params) const {
|
JitConstants ArgMaxMinKernelAxis::GetJitConstants(const arg_max_min_params& params) const {
|
||||||
auto jit = ArgMaxMinKernelBase::GetJitConstants(params);
|
auto jit = ArgMaxMinKernelBase::GetJitConstants(params);
|
||||||
|
|
||||||
jit.AddConstant(MakeJitConstant("OPERATION_NUM", getOperationNumber(params)));
|
if (params.has_dynamic_tensors()) {
|
||||||
|
const std::string operation_num = getOperationNumberString(params);
|
||||||
|
jit.AddConstant(MakeJitConstant("OPERATION_NUM", operation_num));
|
||||||
|
} else {
|
||||||
|
const size_t operation_num = getOperationNumber(params);
|
||||||
|
jit.AddConstant(MakeJitConstant("OPERATION_NUM", operation_num));
|
||||||
|
}
|
||||||
if (params.argMaxMinSortType == ArgMaxMinSortType::VALUE)
|
if (params.argMaxMinSortType == ArgMaxMinSortType::VALUE)
|
||||||
jit.AddConstant(MakeJitConstant("SORT_BY_VALUE", 1));
|
jit.AddConstant(MakeJitConstant("SORT_BY_VALUE", 1));
|
||||||
else
|
else
|
||||||
|
@ -13,6 +13,7 @@ public:
|
|||||||
virtual ~ArgMaxMinKernelAxis() {}
|
virtual ~ArgMaxMinKernelAxis() {}
|
||||||
|
|
||||||
JitConstants GetJitConstants(const arg_max_min_params& params) const override;
|
JitConstants GetJitConstants(const arg_max_min_params& params) const override;
|
||||||
|
DispatchData SetDefault(const arg_max_min_params& params) const override;
|
||||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||||
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
|
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
|
||||||
ParamsKey GetSupportedKey() const override;
|
ParamsKey GetSupportedKey() const override;
|
||||||
|
@ -27,8 +27,10 @@ JitConstants ArgMaxMinKernelBase::GetJitConstants(const arg_max_min_params& para
|
|||||||
ArgMaxMinKernelBase::DispatchData ArgMaxMinKernelBase::SetDefault(const arg_max_min_params& params) const {
|
ArgMaxMinKernelBase::DispatchData ArgMaxMinKernelBase::SetDefault(const arg_max_min_params& params) const {
|
||||||
DispatchData dispatchData;
|
DispatchData dispatchData;
|
||||||
|
|
||||||
dispatchData.gws = { 128, params.inputs[0].Batch().v, 1 };
|
if (!params.has_dynamic_inputs()) {
|
||||||
dispatchData.lws = { 128, 1, 1 };
|
dispatchData.gws = { 128, params.inputs[0].Batch().v, 1 };
|
||||||
|
dispatchData.lws = { 128, 1, 1 };
|
||||||
|
}
|
||||||
|
|
||||||
return dispatchData;
|
return dispatchData;
|
||||||
}
|
}
|
||||||
@ -43,13 +45,32 @@ KernelsData ArgMaxMinKernelBase::GetCommonKernelsData(const Params& params, cons
|
|||||||
DispatchData dispatchData = SetDefault(orgParams);
|
DispatchData dispatchData = SetDefault(orgParams);
|
||||||
|
|
||||||
KernelData kd = KernelData::Default<arg_max_min_params>(params);
|
KernelData kd = KernelData::Default<arg_max_min_params>(params);
|
||||||
|
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
|
||||||
|
const auto& prim_params = static_cast<const arg_max_min_params&>(params);
|
||||||
|
auto dispatchData = SetDefault(prim_params);
|
||||||
|
OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
|
||||||
|
kd.kernels[0].params.workGroups.global = dispatchData.gws;
|
||||||
|
kd.kernels[0].params.workGroups.local = dispatchData.lws;
|
||||||
|
};
|
||||||
|
|
||||||
auto cldnn_jit = GetJitConstants(orgParams);
|
auto cldnn_jit = GetJitConstants(orgParams);
|
||||||
auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options);
|
auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, params, options);
|
||||||
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||||
|
|
||||||
auto& kernel = kd.kernels[0];
|
auto& kernel = kd.kernels[0];
|
||||||
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point);
|
FillCLKernelData(kernel,
|
||||||
|
dispatchData,
|
||||||
|
params.engineInfo,
|
||||||
|
kernelName,
|
||||||
|
jit,
|
||||||
|
entry_point,
|
||||||
|
EXE_MODE_DEFAULT,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
(uint32_t)orgParams.inputs.size(),
|
||||||
|
GetFusedPrimitiveInputsCount(params),
|
||||||
|
(uint32_t)orgParams.outputs.size(),
|
||||||
|
orgParams.has_dynamic_tensors());
|
||||||
|
|
||||||
return {kd};
|
return {kd};
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@ ParamsKey ArgMaxMinKernelGPURef::GetSupportedKey() const {
|
|||||||
k.EnableDifferentTypes();
|
k.EnableDifferentTypes();
|
||||||
k.EnableBatching();
|
k.EnableBatching();
|
||||||
k.EnableTensorPitches();
|
k.EnableTensorPitches();
|
||||||
|
k.EnableDynamicShapesSupport();
|
||||||
return k;
|
return k;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@
|
|||||||
#include <intel_gpu/primitives/input_layout.hpp>
|
#include <intel_gpu/primitives/input_layout.hpp>
|
||||||
#include <intel_gpu/primitives/mutable_data.hpp>
|
#include <intel_gpu/primitives/mutable_data.hpp>
|
||||||
|
|
||||||
|
#include <arg_max_min_inst.h>
|
||||||
|
|
||||||
#include "test_utils.h"
|
#include "test_utils.h"
|
||||||
|
|
||||||
using namespace cldnn;
|
using namespace cldnn;
|
||||||
@ -870,3 +872,52 @@ TEST(top_k_layer_tests, md_sync) {
|
|||||||
TEST(export_import_top_k_layer_tests, md_sync) {
|
TEST(export_import_top_k_layer_tests, md_sync) {
|
||||||
test_top_k_layer_md_sync<int>(true);
|
test_top_k_layer_md_sync<int>(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies that arg_max_min built from a fully dynamic rank-4 input layout is assigned a
// dynamic implementation, and that executing it on a concrete 2x4x2x2 input produces the
// expected result.
//
// The primitive is created with TopKMode::MIN, top_k = 2 and a trailing argument of 0
// (presumably the axis, i.e. batch — confirm against the arg_max_min constructor).
// Every b0 value in the input is less than or equal to the matching b1 value, so the
// first half of the output is expected to be index 0 and the second half index 1.
TEST(arg_max_min_gpu, dynamic) {
    static const int32_t x_size = 2, y_size = 2, feature_num = 4, batch_num = 2;
    auto& engine = get_test_engine();
    const int top_k = 2;
    auto input_layout_dynamic = layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx};
    auto input_layout_static = layout{ov::PartialShape{batch_num, feature_num, y_size, x_size}, data_types::f32, format::bfyx};
    auto input = engine.allocate_memory(input_layout_static);

    topology topology;
    topology.add(input_layout("input", input_layout_dynamic));
    topology.add(arg_max_min("arg_max", { input_info("input") }, ov::op::TopKMode::MIN, top_k, 0));

    std::vector<float> input_vec = {// y0x0 y0x1 y1x0 y1x1
                                    /*b0f0*/ 0.1f, -0.1f, 0.9f, 1.5f,
                                    /*b0f1*/ 0.2f, 0.2f, -10.f, 5.2f,
                                    /*b0f2*/ 0.2f, 0.2f, -10.f, 5.2f,
                                    /*b0f3*/ 0.2f, 0.2f, -10.f, 4.2f,

                                    /*b1f0*/ 3.f, 0.5f, 7.f, 10.f,
                                    /*b1f1*/ 4.f, 0.5f, 8.f, 8.2f,
                                    /*b1f2*/ 0.2f, 0.2f, -10.f, 5.2f,
                                    /*b1f3*/ 4.f, 0.5f, 8.f, 8.2f};

    set_values(input, input_vec);

    ExecutionConfig config;
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
    network network(engine, topology, config);
    network.set_input_data("input", input);

    // The implementation must already exist before execution and must be the dynamic one.
    auto inst = network.get_primitive("arg_max");
    auto impl = inst->get_impl();
    ASSERT_TRUE(impl != nullptr);
    ASSERT_TRUE(impl->is_dynamic());

    auto outputs = network.execute();
    ASSERT_EQ(outputs.size(), size_t(1));
    ASSERT_EQ(outputs.begin()->first, "arg_max");

    // Kept unsigned so the comparisons with output_ptr.size() and the loop index
    // do not mix signed and unsigned operands.
    const size_t out_size = static_cast<size_t>(y_size) * feature_num * x_size * top_k;
    auto output = outputs.at("arg_max").get_memory();
    cldnn::mem_lock<float> output_ptr(output, get_test_stream());

    ASSERT_EQ(output_ptr.size(), out_size);
    for (size_t i = 0; i < out_size; i++) {
        const float expected = i < (out_size / 2) ? 0.f : 1.f;
        ASSERT_FLOAT_EQ(output_ptr[i], expected);
    }
}
|
||||||
|
Loading…
Reference in New Issue
Block a user