Update scatter nd update kernel to support blocked_formats (#11533)
* draft pr for planar and fsv16
* draft pr for general test
* update fusion test (failing)
* update fusing test (pass)
* update fusing test (include exception)
* clean gpu unit test
* review comment applied
* unit test cases added & cpplint applied
* cpplint error fixed
* change gpu test cases for fp16
* fusing test fix generate_unique_indices
* fix typo
* revise cl kernel for occasions when updates shape is altered
parent ca7ddae9ba
commit 6e3dd4adce
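For context, ScatterNDUpdate writes slices of the updates tensor into data at the positions addressed by the last dimension of indices. The sketch below is illustrative only (it is not part of this commit) and assumes a plain row-major layout; the point of the change is that this flat-offset copy is no longer enough once data is stored in a blocked format such as b_fs_yx_fsv16, so the kernel has to go through per-dimension coordinates and layout-aware index macros instead.

    // Illustrative only: reference semantics of ScatterNDUpdate over flat,
    // row-major buffers (names and shapes here are assumptions, not the plugin API).
    #include <cstddef>
    #include <vector>

    void scatter_nd_update_reference(std::vector<float>& data,
                                     const std::vector<size_t>& data_shape,
                                     const std::vector<int>& indices,
                                     const std::vector<size_t>& indices_shape,
                                     const std::vector<float>& updates) {
        const size_t k = indices_shape.back();  // coordinates per index tuple
        // block[i] = number of elements covered by data dims i..rank-1 (suffix products)
        std::vector<size_t> block(data_shape.size() + 1, 1);
        for (int i = static_cast<int>(data_shape.size()) - 1; i >= 0; --i)
            block[i] = block[i + 1] * data_shape[i];
        const size_t slice_size = block[k];     // elements written per index tuple
        size_t num_tuples = 1;
        for (size_t i = 0; i + 1 < indices_shape.size(); ++i)
            num_tuples *= indices_shape[i];
        for (size_t t = 0; t < num_tuples; ++t) {
            size_t dst = 0;
            for (size_t i = 0; i < k; ++i)
                dst += static_cast<size_t>(indices[t * k + i]) * block[i + 1];
            for (size_t e = 0; e < slice_size; ++e)
                data[dst + e] = updates[t * slice_size + e];
        }
    }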
@@ -55,12 +55,92 @@ attach_scatter_nd_update_impl::attach_scatter_nd_update_impl() {
        std::make_tuple(data_types::f32, format::bfyx),
        std::make_tuple(data_types::f16, format::bfyx),
        std::make_tuple(data_types::i32, format::bfyx),
        std::make_tuple(data_types::i8, format::bfyx),
        std::make_tuple(data_types::u8, format::bfyx),

        std::make_tuple(data_types::f32, format::bfzyx),
        std::make_tuple(data_types::f16, format::bfzyx),
        std::make_tuple(data_types::i32, format::bfzyx),
        std::make_tuple(data_types::i8, format::bfzyx),
        std::make_tuple(data_types::u8, format::bfzyx),

        std::make_tuple(data_types::f32, format::bfwzyx),
        std::make_tuple(data_types::f16, format::bfwzyx),
        std::make_tuple(data_types::i32, format::bfwzyx),
        std::make_tuple(data_types::i8, format::bfwzyx),
        std::make_tuple(data_types::u8, format::bfwzyx),

        std::make_tuple(data_types::f32, format::b_fs_yx_fsv4),
        std::make_tuple(data_types::f16, format::b_fs_yx_fsv4),
        std::make_tuple(data_types::i32, format::b_fs_yx_fsv4),
        std::make_tuple(data_types::i8, format::b_fs_yx_fsv4),
        std::make_tuple(data_types::u8, format::b_fs_yx_fsv4),

        std::make_tuple(data_types::f32, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::f16, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::i32, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),

        std::make_tuple(data_types::f32, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::f16, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::i32, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::i8, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::u8, format::b_fs_yx_fsv32),

        std::make_tuple(data_types::f32, format::b_fs_zyx_fsv16),
        std::make_tuple(data_types::f16, format::b_fs_zyx_fsv16),
        std::make_tuple(data_types::i32, format::b_fs_zyx_fsv16),
        std::make_tuple(data_types::i8, format::b_fs_zyx_fsv16),
        std::make_tuple(data_types::u8, format::b_fs_zyx_fsv16),

        std::make_tuple(data_types::f32, format::b_fs_zyx_fsv32),
        std::make_tuple(data_types::f16, format::b_fs_zyx_fsv32),
        std::make_tuple(data_types::i32, format::b_fs_zyx_fsv32),
        std::make_tuple(data_types::i8, format::b_fs_zyx_fsv32),
        std::make_tuple(data_types::u8, format::b_fs_zyx_fsv32),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
        std::make_tuple(data_types::i32, format::bs_fs_yx_bsv4_fsv2),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv2),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::i32, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv2),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv2),
        std::make_tuple(data_types::i32, format::bs_fs_yx_bsv8_fsv2),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv2),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv2),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4),
        std::make_tuple(data_types::i32, format::bs_fs_yx_bsv8_fsv4),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::i32, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::i32, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::i32, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32),
    });
}
@@ -1397,7 +1397,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
                prim.type() != cldnn::region_yolo::type_id() &&
                prim.type() != cldnn::normalize::type_id() &&
                prim.type() != cldnn::mvn::type_id() &&
                prim.type() != cldnn::gather::type_id()) {
                prim.type() != cldnn::gather::type_id() &&
                prim.type() != cldnn::scatter_nd_update::type_id()) {
                can_use_fsv16 = false;
            }
@@ -1423,6 +1424,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
                prim.type() != cldnn::softmax::type_id() &&
                prim.type() != cldnn::fully_connected::type_id() &&
                prim.type() != cldnn::generic_layer::type_id() &&
                prim.type() != cldnn::scatter_nd_update::type_id() &&
                prim.type() != cldnn::quantize::type_id())
                can_use_bs_fs_yx_bsv16_fsv16 = false;
            }
@@ -14,17 +14,15 @@ ParamsKey ScatterNDUpdateKernelRef::GetSupportedKey() const {
    k.EnableInputDataType(Datatype::F16);
    k.EnableInputDataType(Datatype::F32);
    k.EnableInputDataType(Datatype::INT32);
    k.EnableInputDataType(Datatype::INT8);
    k.EnableInputDataType(Datatype::UINT8);
    k.EnableOutputDataType(Datatype::F16);
    k.EnableOutputDataType(Datatype::F32);
    k.EnableOutputDataType(Datatype::INT32);
    k.EnableOutputDataType(Datatype::INT8);
    k.EnableOutputDataType(Datatype::UINT8);
    k.EnableInputLayout(DataLayout::bfyx);
    k.EnableOutputLayout(DataLayout::bfyx);
    k.EnableInputLayout(DataLayout::bfzyx);
    k.EnableOutputLayout(DataLayout::bfzyx);
    k.EnableInputLayout(DataLayout::bfwzyx);
    k.EnableOutputLayout(DataLayout::bfwzyx);
    k.EnableAllInputLayout();
    k.EnableAllOutputLayout();
    k.EnableTensorOffset();
    k.EnableTensorPitches();
    k.EnableBatching();
@@ -115,14 +113,10 @@ bool ScatterNDUpdateKernelRef::Validate(const Params& p, const optional_params&
    return true;
}

static std::string GetInputBlockND(const scatter_nd_update_params& params) {
    const auto& input = params.inputs[0];
static std::string GetInputBlockND(const scatter_nd_update_params& params, int num, const int rank) {
    const auto& input = params.inputs[num];
    auto input_dims = input.LogicalDims();
    std::reverse(input_dims.begin(), input_dims.end());
    while (!input_dims.empty() && input_dims.back() == 1) {
        input_dims.pop_back();
    }
    const int rank = static_cast<int>(input_dims.size());
    std::vector<size_t> block_nd(rank + 1);
    block_nd[rank] = 1;
    for (int idx = (rank - 1); idx >= 0; idx--) {
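For reference, the *_BLOCK_ND JIT constants this helper emits are the suffix products of the logical dimensions, now computed separately for data, indices and updates because their ranks can differ. A hypothetical standalone sketch of that computation (names are illustrative, not the kernel-selector API):

    #include <sstream>
    #include <string>
    #include <vector>

    // e.g. dims {2, 3, 4} with rank 3 -> "24,12,4,1": block_nd[i] is the product of dims[i..rank-1].
    std::string make_block_nd_string(const std::vector<size_t>& dims, int rank) {
        std::vector<size_t> block_nd(rank + 1, 1);
        for (int idx = rank - 1; idx >= 0; --idx)
            block_nd[idx] = block_nd[idx + 1] * dims[idx];
        std::ostringstream s;
        for (int i = 0; i <= rank; ++i) {
            s << block_nd[i];
            if (i < rank)
                s << ",";
        }
        return s.str();
    }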
@@ -157,9 +151,14 @@ KernelsData ScatterNDUpdateKernelRef::GetKernelsData(const Params& params, const
        auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options, i);

        if (i == 1) {
            int input0_rank = static_cast<int>(newParams.inputs[0].LogicalDims().size());
            int input2_rank = static_cast<int>(newParams.inputs[2].LogicalDims().size());
            cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true"));
            cldnn_jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", dispatchData.indicesLastDim));
            cldnn_jit.AddConstant(MakeJitConstant("INPUT_BLOCK_ND", GetInputBlockND(newParams)));
            cldnn_jit.AddConstant(MakeJitConstant("INPUT0_BLOCK_ND", GetInputBlockND(newParams, 0, input0_rank)));
            cldnn_jit.AddConstant(MakeJitConstant("INPUT1_BLOCK_ND", GetInputBlockND(newParams, 1, newParams.indices_rank - 1)));
            cldnn_jit.AddConstant(MakeJitConstant("INPUT2_BLOCK_ND", GetInputBlockND(newParams, 2, input2_rank)));
            cldnn_jit.AddConstant(MakeJitConstant("INDICES_RANK", newParams.indices_rank));
        }
        std::pair<std::string, std::string> jit = CreateJit(kernelName, cldnn_jit, entry_point);
@@ -1,3 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@@ -16,6 +17,24 @@
#define ORDER b,f,w,z,y,x
#endif

#if INPUT2_DIMS == 4
    #define UPD_ORDER upd_b,upd_f,upd_y,upd_x
#elif INPUT2_DIMS == 5
    #define UPD_ORDER upd_b,upd_f,upd_z,upd_y,upd_x
#elif INPUT2_DIMS == 6
    #define UPD_ORDER upd_b,upd_f,upd_w,upd_z,upd_y,upd_x
#endif

#if INPUT1_DIMS == 4
    #define IDX_ORDER idx_b,idx_f,idx_y,idx_x
#elif INPUT1_DIMS == 5
    #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x
#elif INPUT1_DIMS == 6
    #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x
#endif

#define INDICES_MAX_DIM 6

KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
                              const __global INPUT1_TYPE* indices,
                              const __global INPUT2_TYPE* updates,
@@ -49,81 +68,122 @@ KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,

#else // Second kernel

    const uint blockND[] = {INPUT_BLOCK_ND};
    const uint k = INDICES_LAST_DIM;
    const uint size_to_update = blockND[INDICES_LAST_DIM];
    const uint indices_idx = dim2;
    const uint indices_offset = indices_idx * k;
    uint dst_offset = 0;

    for (uint i = 0; i < k; i++) {
        INPUT1_TYPE idxValue = indices[indices_offset + i];
        dst_offset += idxValue * blockND[i + 1];
    }

    uint update_offset = indices_idx * size_to_update;

    for (int i = 0; i < size_to_update; i++) {
        uint dst_idx = dst_offset + i;
        uint up_idx = update_offset + i;
        INPUT2_TYPE val = updates[up_idx];

        #if HAS_FUSED_OPS
            #if OUTPUT_DIMS == 4
                const uint y_pitch = OUTPUT_SIZE_X;
                const uint f_pitch = y_pitch * OUTPUT_SIZE_Y;
                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;

                const uint b_remain = dst_idx % b_pitch;
                const uint f_remain = b_remain % f_pitch;
                const uint y_remain = f_remain % y_pitch;

                const uint b = dst_idx / b_pitch;
                const uint f = b_remain / f_pitch;
                const uint y = f_remain / y_pitch;
                const uint x = y_remain;
            #elif OUTPUT_DIMS == 5
                const uint y_pitch = OUTPUT_SIZE_X;
                const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
                const uint f_pitch = z_pitch * OUTPUT_SIZE_Z;
                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;

                const uint b_remain = dst_idx % b_pitch;
                const uint f_remain = b_remain % f_pitch;
                const uint z_remain = f_remain % z_pitch;
                const uint y_remain = z_remain % y_pitch;

                const uint b = dst_idx / b_pitch;
                const uint f = b_remain / f_pitch;
                const uint z = f_remain / z_pitch;
                const uint y = z_remain / y_pitch;
                const uint x = y_remain;
            #elif OUTPUT_DIMS == 6
                const uint y_pitch = OUTPUT_SIZE_X;
                const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
                const uint w_pitch = z_pitch * OUTPUT_SIZE_Z;
                const uint f_pitch = w_pitch * OUTPUT_SIZE_W;
                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;

                const uint b_remain = dst_idx % b_pitch;
                const uint f_remain = b_remain % f_pitch;
                const uint w_remain = f_remain % w_pitch;
                const uint z_remain = w_remain % z_pitch;
                const uint y_remain = z_remain % y_pitch;

                const uint b = dst_idx / b_pitch;
                const uint f = b_remain / f_pitch;
                const uint w = f_remain / w_pitch;
                const uint z = w_remain / z_pitch;
                const uint y = z_remain / y_pitch;
                const uint x = y_remain;
            #endif

            FUSED_OPS_SECOND_KERNEL;
            output[dst_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
        #else
            output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
    const uint dataND[] = {INPUT0_BLOCK_ND};
    const uint updatesND[] = {INPUT2_BLOCK_ND};
    const uint indicesND[] = {INPUT1_BLOCK_ND};
    const uint size_to_update = dataND[INDICES_LAST_DIM];

    #if INPUT1_DIMS == 4
        const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
    #elif INPUT1_DIMS == 5
        const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
    #elif INPUT1_DIMS == 6
        const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
    #endif

    #if INPUT0_DIMS == 4
        const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
    #elif INPUT0_DIMS == 5
        const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
    #elif INPUT0_DIMS == 6
        const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
    #endif

    // Get indices index
    uint idx[INDICES_MAX_DIM] = {0};
    uint rmd_idx = dim2;
    for (int i = 0; i < INDICES_RANK - 1; ++i) {
        idx[i] = rmd_idx / indicesND[1 + i];
        rmd_idx %= indicesND[1 + i];
    }

    uint out[INDICES_MAX_DIM] = {0};
    for (int i = 0; i < indices_dim[INDICES_RANK - 1]; ++i) {
        idx[INDICES_RANK - 1] = i;
        const uint idx_b = idx[0];
        const uint idx_f = idx[1];
        #if INPUT1_DIMS == 4
            const uint idx_y = idx[2];
            const uint idx_x = idx[3];
        #elif INPUT1_DIMS == 5
            const uint idx_z = idx[2];
            const uint idx_y = idx[3];
            const uint idx_x = idx[4];
        #elif INPUT1_DIMS == 6
            const uint idx_w = idx[2];
            const uint idx_z = idx[3];
            const uint idx_y = idx[4];
            const uint idx_x = idx[5];
        #endif
        uint index = GET_UPDATES_INDEX(INPUT1, IDX_ORDER);
        out[i] = indices[index];

        // Check if tensor size is valid
        // ex) when data format = bfyx and data shape = { 3, 3, 4, 1 }, indices shape is { 2, 1 } with rank = 2, indices values are { 1.0, 4.0 },
        // the second indices value is invalid as data shape has 'b' of size 3, and therefore 4 cannot be a correct index of data
        // If indices value is invalid, saturate value to max valid value (ex. 4.0 -> 2.0)
        if (out[i] >= data_dim[i])
            out[i] = data_dim[i] - 1;
    }

    for (int i = 0; i < size_to_update; ++i) {
        // Define updates index
        uint upd[INDICES_MAX_DIM] = {0};
        for (int j = 0; j < INDICES_RANK - 1; ++j) {
            upd[j] = idx[j];
        }
        uint data_rmd = i, updates_rmd = i;
        for (int j = indices_dim[INDICES_RANK - 1]; j < INPUT0_DIMS; ++j) {
            out[j] = data_rmd / dataND[j + 1];
            data_rmd %= dataND[j + 1];
        }
        for (int k = INDICES_RANK - 1; k < INPUT2_DIMS; ++k) {
            upd[k] = updates_rmd / updatesND[k + 1];
            updates_rmd %= updatesND[k + 1];
        }
        // Get update index
        const uint upd_b = upd[0];
        const uint upd_f = upd[1];
        #if INPUT2_DIMS == 4
            const uint upd_y = upd[2];
            const uint upd_x = upd[3];
        #elif INPUT2_DIMS == 5
            const uint upd_z = upd[2];
            const uint upd_y = upd[3];
            const uint upd_x = upd[4];
        #elif INPUT2_DIMS == 6
            const uint upd_w = upd[2];
            const uint upd_z = upd[3];
            const uint upd_y = upd[4];
            const uint upd_x = upd[5];
        #endif
        uint upd_idx = GET_UPDATES_INDEX(INPUT2, UPD_ORDER);

        // Get output index
        const uint b = out[0];
        const uint f = out[1];
        #if INPUT0_DIMS == 4
            const uint y = out[2];
            const uint x = out[3];
        #elif INPUT0_DIMS == 5
            const uint z = out[2];
            const uint y = out[3];
            const uint x = out[4];
        #elif INPUT0_DIMS == 6
            const uint w = out[2];
            const uint z = out[3];
            const uint y = out[4];
            const uint x = out[5];
        #endif
        uint out_idx = GET_OUTPUT_INDEX(ORDER);
        INPUT2_TYPE val = updates[upd_idx];

        #if HAS_FUSED_OPS
            FUSED_OPS_SECOND_KERNEL;
            output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
        #else
            output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
        #endif
    }
#endif

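The key change in the second kernel above is that a flat loop counter is no longer used directly as a destination offset; it is decomposed into per-dimension coordinates via the suffix-product tables (dataND/updatesND/indicesND), and the final addresses come from GET_OUTPUT_INDEX / GET_UPDATES_INDEX, which understand blocked layouts. A small host-side sketch of that decomposition (illustrative, assumed names):

    #include <cstddef>
    #include <vector>

    // Split a flat element counter into per-dimension coordinates using suffix
    // products, mirroring "coord = rmd / blockND[i + 1]; rmd %= blockND[i + 1]".
    std::vector<size_t> unflatten(size_t flat, const std::vector<size_t>& block_nd) {
        std::vector<size_t> coord(block_nd.size() - 1, 0);
        for (size_t i = 0; i + 1 < block_nd.size(); ++i) {
            coord[i] = flat / block_nd[i + 1];
            flat %= block_nd[i + 1];
        }
        return coord;
    }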
@@ -140,3 +200,15 @@ KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
#ifdef ORDER
#undef ORDER
#endif

#ifdef UPD_ORDER
#undef UPD_ORDER
#endif

#ifdef IDX_ORDER
#undef IDX_ORDER
#endif

#ifdef INDICES_MAX_DIM
#undef INDICES_MAX_DIM
#endif
@@ -20,6 +20,21 @@ static void CreateScatterNDUpdateOp(Program& p, const std::shared_ptr<ngraph::op
    std::string layerName = layer_type_name_ID(op);
    auto indices_rank = op->get_input_shape(1).size();

    auto indices_constant = std::dynamic_pointer_cast<ngraph::op::Constant>(op->get_input_node_shared_ptr(1));
    if (indices_constant) {
        auto indices = indices_constant->cast_vector<int32_t>();
        auto indices_last_dim = op->get_input_shape(1)[indices_rank - 1];
        auto data_shape = op->get_input_shape(0);
        bool valid = true;
        for (int i = 0; i < indices.size(); ++i) {
            if (indices[i] >= data_shape[i % indices_last_dim])
                valid = false;
        }

        if (!valid)
            IE_THROW() << "Invaild indices values";
    }

    auto primitive = cldnn::scatter_nd_update(layerName,
                                              inputPrimitives[0],
                                              inputPrimitives[1],
@@ -47,7 +47,8 @@ target_link_libraries(${TARGET_NAME} PRIVATE openvino_intel_gpu_graph
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
                                                  ${CMAKE_CURRENT_SOURCE_DIR}/test_utils/
                                                  $<TARGET_PROPERTY:openvino_intel_gpu_kernels,INTERFACE_INCLUDE_DIRECTORIES>
                                                  $<TARGET_PROPERTY:openvino_intel_gpu_runtime,INTERFACE_INCLUDE_DIRECTORIES>)
                                                  $<TARGET_PROPERTY:openvino_intel_gpu_runtime,INTERFACE_INCLUDE_DIRECTORIES>
                                                  ${CMAKE_HOME_DIRECTORY}/src/core/reference/include/)
if(WIN32)
  target_link_libraries(${TARGET_NAME} PRIVATE setupapi)
elseif((NOT ANDROID) AND (UNIX))
@@ -12,6 +12,9 @@
#include <intel_gpu/primitives/scatter_nd_update.hpp>

#include <cmath>
#include <stdlib.h>
#include <time.h>
#include <algorithm>

using namespace cldnn;
using namespace ::tests;
@@ -82,16 +85,16 @@ public:
        std::set<std::vector<T>> unique_indices;
        std::vector<T> result;
        auto indices_shape = p.indices_shape.sizes(get_default_format(p.indices_rank));
        auto last_indices_dim = indices_shape.back();
        auto data_shape = p.input_shape.sizes(p.input_format);
        auto last_indices_dim = indices_shape.at(p.indices_rank - 1);

        auto count = 1;
        for (size_t i = 0; i < indices_shape.size() - 1; i++)
            count *= indices_shape[i];
        auto count = p.indices_shape.count() / last_indices_dim;

        while (unique_indices.size() != count) {
            std::vector<T> indices;
            for (size_t i = 0; i < last_indices_dim; i++)
                indices.push_back(generate_random_val<T>(0, indices_shape[i]));
            for (size_t i = 0; i < last_indices_dim; i++) {
                indices.push_back(static_cast<T>(generate_random_val<int>(0, data_shape[i] - 1)));
            }

            unique_indices.insert(indices);
        }
@@ -173,6 +176,55 @@ public:
#define CASE_SCATTER_ND_UPDATE_FP32_6D_5 { 6, 7, 8, 9, 2, 2 }, { 5, 5, 1, 1 }, { 5, 8, 1, 1, 1, 1 }, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_6 { 6, 7, 8, 9, 2, 2 }, { 5, 6, 1, 1 }, { 5, 1, 1, 1, 1, 1 }, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx

#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx

#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1 { 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 }, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 8, 9 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3 { 6, 7, 8, 9, 10 }, { 5, 3, 1, 1 }, { 5, 9, 1, 1, 8 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4 { 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5 { 6, 7, 8, 9, 10 }, { 5, 5, 1, 1 }, { 5, 1, 1, 1, 1 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, 3, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 }, 3, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8 { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, 4, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9 { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 }, 4, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx

#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx

#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1 { 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 }, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 8, 9 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3 { 6, 7, 8, 9, 10 }, { 5, 3, 1, 1 }, { 5, 9, 1, 1, 8 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4 { 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5 { 6, 7, 8, 9, 10 }, { 5, 5, 1, 1 }, { 5, 1, 1, 1, 1 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, 3, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 }, 3, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8 { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, 4, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9 { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 }, 4, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx

#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx

#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx


class scatter_nd_update_quantize : public ScatterNDUpdatePrimitiveFusingTest {};
TEST_P(scatter_nd_update_quantize, basic) {
    auto p = GetParam();
@@ -235,6 +287,54 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, scatter_nd_update_quantize, ::testing::Val
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 3 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1, 2, 3 }, // FP16
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6, 2, 3 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9, 2, 3 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1, 2, 3 }, // FP32
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6, 2, 3 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9, 2, 3 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1, 2, 3 }, // FP16
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6, 2, 3 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1, 2, 3 }, // FP32
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6, 2, 3 },
}));

class scatter_nd_update_scale_activation_eltwise : public ScatterNDUpdatePrimitiveFusingTest {};
@@ -298,4 +398,52 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1, 2, 5 }, // FP16
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6, 2, 5 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9, 2, 5 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1, 2, 5 }, // FP32
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6, 2, 5 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9, 2, 5 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1, 2, 5 }, // FP16
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6, 2, 5 },

    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1, 2, 5 }, // FP32
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6, 2, 5 },
}));
@@ -3,6 +3,7 @@
//

#include "test_utils.h"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"

#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/scatter_update.hpp>
@ -12,17 +13,489 @@
|
||||
#include <intel_gpu/graph/network.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
#include <stdlib.h>
|
||||
#include <algorithm>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
|
||||
struct scatter_nd_update_basic_test_params
|
||||
{
|
||||
data_types input_type;
|
||||
data_types indices_type;
|
||||
data_types updates_type;
|
||||
format input_format;
|
||||
format indices_format;
|
||||
format updates_format;
|
||||
format input_result_format;
|
||||
format indices_result_format;
|
||||
format updates_result_format;
|
||||
tensor input_size;
|
||||
tensor indices_size;
|
||||
tensor updates_size;
|
||||
int indices_rank;
|
||||
};
|
||||
|
||||
struct scatter_nd_update_random_test : testing::TestWithParam<scatter_nd_update_basic_test_params>
|
||||
{
|
||||
format get_default_format(int rank = 4) {
|
||||
if (rank <= 4)
|
||||
return cldnn::format::bfyx;
|
||||
else if (rank == 5)
|
||||
return cldnn::format::bfzyx;
|
||||
else
|
||||
return cldnn::format::bfwzyx;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T generate_random_val(int min, int max, int k = 8) {
|
||||
static std::default_random_engine generator(random_seed);
|
||||
// 1/k is the resolution of the floating point numbers
|
||||
std::uniform_int_distribution<int> distribution(k * min, k * max);
|
||||
T val = (T)distribution(generator);
|
||||
val /= k;
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> generate_unique_indices(const scatter_nd_update_basic_test_params& p) {
|
||||
std::set<std::vector<T>> unique_indices;
|
||||
std::vector<T> result;
|
||||
auto indices_shape = p.indices_size.sizes(get_default_format(p.indices_rank));
|
||||
auto data_shape = p.input_size.sizes(p.input_format);
|
||||
auto last_indices_dim = indices_shape.at(p.indices_rank - 1);
|
||||
|
||||
auto count = p.indices_size.count() / last_indices_dim;
|
||||
|
||||
while (unique_indices.size() != count) {
|
||||
std::vector<T> indices;
|
||||
for (size_t i = 0; i < last_indices_dim; i++) {
|
||||
indices.push_back(static_cast<T>(generate_random_val<int>(0, data_shape[i] - 1)));
|
||||
}
|
||||
|
||||
unique_indices.insert(indices);
|
||||
}
|
||||
|
||||
std::for_each(unique_indices.begin(),
|
||||
unique_indices.end(),
|
||||
[&](const std::vector<T>& indices) {
|
||||
result.insert(result.end(), indices.begin(), indices.end());
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename T, typename T_size>
|
||||
void execute_fp16(const scatter_nd_update_basic_test_params& params)
|
||||
{
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto input1 = engine.allocate_memory({ params.input_type, params.input_format, params.input_size });
|
||||
auto input2 = engine.allocate_memory({ params.indices_type, params.indices_format, params.indices_size });
|
||||
auto input3 = engine.allocate_memory({ params.updates_type, params.updates_format, params.updates_size });
|
||||
|
||||
std::vector<int> input_vec(static_cast<int>(cldnn::format::dimension(params.input_format)));
|
||||
for (int i = 0; i < input_vec.size(); ++i)
|
||||
input_vec[i] = static_cast<int>(params.input_size.sizes()[i]);
|
||||
std::reverse(input_vec.begin() + 2, input_vec.end());
|
||||
|
||||
std::vector<int> updates_vec(static_cast<int>(cldnn::format::dimension(params.updates_format)));
|
||||
for (int i = 0; i < updates_vec.size(); ++i)
|
||||
updates_vec[i] = static_cast<int>(params.updates_size.sizes()[i]);
|
||||
std::reverse(updates_vec.begin() + 2, updates_vec.end());
|
||||
|
||||
std::vector<int> indices_vec(static_cast<int>(cldnn::format::dimension(params.indices_format)));
|
||||
for (size_t i = 0; i < indices_vec.size(); ++i)
|
||||
indices_vec[i] = static_cast<int>(params.indices_size.sizes()[i]);
|
||||
std::reverse(indices_vec.begin() + 2, indices_vec.end());
|
||||
indices_vec.resize(params.indices_rank);
|
||||
|
||||
auto input_data_fp16 = generate_random_1d<T>(params.input_size.count(), -127, 127);
|
||||
auto indices_data_fp16 = generate_unique_indices<T>(params);
|
||||
auto updates_data_fp16 = generate_random_1d<T>(params.updates_size.count(), -127, 127);
|
||||
|
||||
std::vector<float> input_data(params.input_size.count());
|
||||
for (int i = 0; i < params.input_size.count(); ++i)
|
||||
input_data[i] = static_cast<float>(input_data_fp16[i]);
|
||||
std::vector<float> indices_data(params.indices_size.count());
|
||||
for (int i = 0; i < params.indices_size.count(); ++i)
|
||||
indices_data[i] = static_cast<float>(indices_data_fp16[i]);
|
||||
std::vector<float> updates_data(params.updates_size.count());
|
||||
for (int i = 0; i < params.updates_size.count(); ++i)
|
||||
updates_data[i] = static_cast<float>(updates_data_fp16[i]);
|
||||
|
||||
set_values(input1, input_data_fp16);
|
||||
set_values(input2, indices_data_fp16);
|
||||
set_values(input3, updates_data_fp16);
|
||||
|
||||
// execute scatter_nd_update
|
||||
topology topology(
|
||||
input_layout("InputData", input1->get_layout()),
|
||||
input_layout("InputIndices", input2->get_layout()),
|
||||
input_layout("InputUpdates", input3->get_layout()),
|
||||
reorder("reorder1", "InputData", params.input_result_format, params.input_type),
|
||||
reorder("reorder2", "InputIndices", params.indices_result_format, params.indices_type),
|
||||
reorder("reorder3", "InputUpdates", params.updates_result_format, params.updates_type),
|
||||
scatter_nd_update("scatter_nd_update", "reorder1", "reorder2", "reorder3", params.indices_rank),
|
||||
reorder("out", "scatter_nd_update", params.input_format, params.input_type)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("InputData", input1);
|
||||
network.set_input_data("InputIndices", input2);
|
||||
network.set_input_data("InputUpdates", input3);
|
||||
|
||||
auto outputs = network.execute();
|
||||
auto output = outputs.at("out").get_memory();
|
||||
cldnn::mem_lock<T_size> outputs_ptr(output, get_test_stream());
|
||||
|
||||
auto outputs_ref = std::vector<float>(params.input_size.count());
|
||||
ngraph::runtime::reference::scatterNdUpdate<float, float>(input_data.data(),
|
||||
indices_data.data(),
|
||||
updates_data.data(),
|
||||
outputs_ref.data(),
|
||||
ov::Shape(input_vec.begin(), input_vec.end()),
|
||||
ov::Shape(indices_vec.begin(), indices_vec.end()),
|
||||
ov::Shape(updates_vec.begin(), updates_vec.end()));
|
||||
|
||||
for (size_t i = 0; i < outputs_ref.size(); ++i) {
|
||||
EXPECT_EQ(outputs_ref[i], float16_to_float32(outputs_ptr[i]));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void execute(const scatter_nd_update_basic_test_params& params)
|
||||
{
|
||||
// create input, indices, updates using params
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto input1 = engine.allocate_memory({ params.input_type, params.input_format, params.input_size });
|
||||
auto input2 = engine.allocate_memory({ params.indices_type, params.indices_format, params.indices_size });
|
||||
auto input3 = engine.allocate_memory({ params.updates_type, params.updates_format, params.updates_size });
|
||||
|
||||
std::vector<int> input_vec(static_cast<int>(cldnn::format::dimension(params.input_format)));
|
||||
for (int i = 0; i < input_vec.size(); ++i)
|
||||
input_vec[i] = static_cast<int>(params.input_size.sizes()[i]);
|
||||
std::reverse(input_vec.begin() + 2, input_vec.end());
|
||||
|
||||
std::vector<int> updates_vec(static_cast<int>(cldnn::format::dimension(params.updates_format)));
|
||||
for (int i = 0; i < updates_vec.size(); ++i)
|
||||
updates_vec[i] = static_cast<int>(params.updates_size.sizes()[i]);
|
||||
std::reverse(updates_vec.begin() + 2, updates_vec.end());
|
||||
|
||||
std::vector<int> indices_vec(static_cast<int>(cldnn::format::dimension(params.indices_format)));
|
||||
for (size_t i = 0; i < indices_vec.size(); ++i)
|
||||
indices_vec[i] = static_cast<int>(params.indices_size.sizes()[i]);
|
||||
std::reverse(indices_vec.begin() + 2, indices_vec.end());
|
||||
indices_vec.resize(params.indices_rank);
|
||||
|
||||
auto input_data = generate_random_1d<T>(params.input_size.count(), -127, 127);
|
||||
auto indices_data = generate_unique_indices<T>(params);
|
||||
auto updates_data = generate_random_1d<T>(params.updates_size.count(), -127, 127);
|
||||
|
||||
set_values(input1, input_data);
|
||||
set_values(input2, indices_data);
|
||||
set_values(input3, updates_data);
|
||||
|
||||
// execute scatter_nd_update
|
||||
topology topology(
|
||||
input_layout("InputData", input1->get_layout()),
|
||||
input_layout("InputIndices", input2->get_layout()),
|
||||
input_layout("InputUpdates", input3->get_layout()),
|
||||
reorder("reorder1", "InputData", params.input_result_format, params.input_type),
|
||||
reorder("reorder2", "InputIndices", params.indices_result_format, params.indices_type),
|
||||
reorder("reorder3", "InputUpdates", params.updates_result_format, params.updates_type),
|
||||
scatter_nd_update("scatter_nd_update", "reorder1", "reorder2", "reorder3", params.indices_rank),
|
||||
reorder("out", "scatter_nd_update", params.input_format, params.input_type)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("InputData", input1);
|
||||
network.set_input_data("InputIndices", input2);
|
||||
network.set_input_data("InputUpdates", input3);
|
||||
|
||||
auto outputs = network.execute();
|
||||
auto output = outputs.at("out").get_memory();
|
||||
cldnn::mem_lock<T> outputs_ptr(output, get_test_stream());
|
||||
|
||||
auto outputs_ref = std::vector<T>(params.input_size.count());
|
||||
ngraph::runtime::reference::scatterNdUpdate<T, T>(input_data.data(),
|
||||
indices_data.data(),
|
||||
updates_data.data(),
|
||||
outputs_ref.data(),
|
||||
ov::Shape(input_vec.begin(), input_vec.end()),
|
||||
ov::Shape(indices_vec.begin(), indices_vec.end()),
|
||||
ov::Shape(updates_vec.begin(), updates_vec.end()));
|
||||
|
||||
for (size_t i = 0; i < outputs_ref.size(); ++i) {
|
||||
EXPECT_EQ(outputs_ref[i], outputs_ptr[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(scatter_nd_update_random_test, random)
|
||||
{
|
||||
auto param = GetParam();
|
||||
if (param.input_type == data_types::u8)
|
||||
this->execute<uint8_t>(param);
|
||||
else if (param.input_type == data_types::i8)
|
||||
this->execute<int8_t>(param);
|
||||
else if (param.input_type == data_types::i32)
|
||||
this->execute<int32_t>(param);
|
||||
else if (param.input_type == data_types::i64)
|
||||
this->execute<int64_t>(param);
|
||||
else if (param.input_type == data_types::f16)
|
||||
this->execute_fp16<FLOAT16, uint16_t>(param);
|
||||
else if (param.input_type == data_types::f32)
|
||||
this->execute<float>(param);
|
||||
else
|
||||
IE_THROW() << "unidentified data type";
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_bsv32_fsv16_4d_rank_1,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f32, data_types::f32, data_types::f32,
|
||||
format::bfyx, format::bfyx, format::bfyx,
|
||||
format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
|
||||
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
|
||||
1 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_bsv32_fsv16_4d_rank_2,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f32, data_types::f32, data_types::f32,
|
||||
format::bfyx, format::bfyx, format::bfyx,
|
||||
format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
|
||||
{ 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
|
||||
2 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_4d_rank_1,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f32, data_types::f32, data_types::f32,
|
||||
format::bfyx, format::bfyx, format::bfyx,
|
||||
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
|
||||
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
|
||||
1 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_4d_rank_2,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f32, data_types::f32, data_types::f32,
|
||||
format::bfyx, format::bfyx, format::bfyx,
|
||||
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
|
||||
{ 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
|
||||
2 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_2,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f32, data_types::f32, data_types::f32,
|
||||
format::bfzyx, format::bfyx, format::bfzyx,
|
||||
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
|
||||
{ 6, 7, 3, 3, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 3, 3 },
|
||||
2 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_3,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f32, data_types::f32, data_types::f32,
|
||||
format::bfzyx, format::bfyx, format::bfzyx,
|
||||
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
|
||||
{ 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 },
|
||||
3 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_4,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f32, data_types::f32, data_types::f32,
|
||||
format::bfzyx, format::bfyx, format::bfzyx,
|
||||
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
|
||||
{ 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 },
|
||||
4 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_4d_rank_1,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f16, data_types::f16, data_types::f16,
|
||||
format::bfyx, format::bfyx, format::bfyx,
|
||||
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
|
||||
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
|
||||
1 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_4d_rank_2,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f16, data_types::f16, data_types::f16,
|
||||
format::bfyx, format::bfyx, format::bfyx,
|
||||
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
|
||||
{ 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
|
||||
2 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_1,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f16, data_types::f16, data_types::f16,
|
||||
format::bfzyx, format::bfyx, format::bfzyx,
|
||||
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
|
||||
{ 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 },
|
||||
1 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_2,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f16, data_types::f16, data_types::f16,
|
||||
format::bfzyx, format::bfyx, format::bfzyx,
|
||||
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
|
||||
{ 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 },
|
||||
2 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_3,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f16, data_types::f16, data_types::f16,
|
||||
format::bfzyx, format::bfyx, format::bfzyx,
|
||||
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
|
||||
{ 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 },
|
||||
3 }
|
||||
}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_4,
|
||||
scatter_nd_update_random_test,
|
||||
testing::ValuesIn(
|
||||
std::vector<scatter_nd_update_basic_test_params>{
|
||||
{ data_types::f16, data_types::f16, data_types::f16,
|
||||
format::bfzyx, format::bfyx, format::bfzyx,
|
||||
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
|
||||
{ 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 },
|
||||
4 }
|
||||
}));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_bsv32_fsv16_4d_rank_1,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::f16, data_types::f16, data_types::f16,
              format::bfyx, format::bfyx, format::bfyx,
              format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
              { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
              1 }
        }));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_bsv32_fsv16_4d_rank_2,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::f16, data_types::f16, data_types::f16,
              format::bfyx, format::bfyx, format::bfyx,
              format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
              { 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
              2 }
        }));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_bsv32_fsv16_4d_rank_2,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::i8, data_types::i8, data_types::i8,
              format::bfyx, format::bfyx, format::bfyx,
              format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
              { 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
              2 }
        }));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_bsv32_fsv32_4d_rank_1,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::i8, data_types::i8, data_types::i8,
              format::bfyx, format::bfyx, format::bfyx,
              format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32,
              { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
              1 }
        }));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv32_4d_rank_2,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::i8, data_types::i8, data_types::i8,
              format::bfyx, format::bfyx, format::bfyx,
              format::b_fs_yx_fsv32, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32,
              { 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
              2 }
        }));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv32_5d_rank_3,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::i8, data_types::i8, data_types::i8,
              format::bfzyx, format::bfyx, format::bfzyx,
              format::b_fs_zyx_fsv32, format::b_fs_yx_fsv32, format::b_fs_zyx_fsv32,
              { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 },
              3 }
        }));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv16_4d_rank_2,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::i8, data_types::i8, data_types::i8,
              format::bfyx, format::bfyx, format::bfyx,
              format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
              { 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
              2 }
        }));

INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv16_5d_rank_4,
    scatter_nd_update_random_test,
    testing::ValuesIn(
        std::vector<scatter_nd_update_basic_test_params>{
            { data_types::i8, data_types::i8, data_types::i8,
              format::bfzyx, format::bfyx, format::bfzyx,
              format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
              { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 },
              4 }
        }));
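
Read positionally, every scatter_nd_update_basic_test_params entry above follows the same layout. The annotated copy below restates the last case (i8, fsv16, 5D, rank 4) with comments giving one plausible reading of the fields; the comments are an interpretation of this parameter table, not text taken from the test header:

    // One entry from the instantiations above, annotated (interpretation only):
    { data_types::i8, data_types::i8, data_types::i8,                        // element types: data, indices, updates
      format::bfzyx, format::bfyx, format::bfzyx,                            // plain layouts the inputs start in
      format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, // blocked layouts exercised by this change
      { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 },                 // shapes: data, indices, updates
      4 }                                                                    // indices rank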

TEST(scatter_nd_update_gpu_fp16_test15, data5_indice3_update5) {
    auto& engine = get_test_engine();

    auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 2, 2, 2, 4, 3 } }); // data
    auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 2, 1, 1 } }); // indices
    auto input3 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 2, 4, 3, 2 } }); // updates
    auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 4, 3, 2 } }); // updates

    set_values(input1, {
        // 0
@@ -158,7 +631,7 @@ TEST(scatter_nd_update_gpu_fp16_test14, data5_indice2_update3) {
        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f),

        FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f),
        FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), FLOAT16(75.0f), FLOAT16(76.0f), FLOAT16(77.0f), FLOAT16(78.0f),
        FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), FLOAT16(75.0f), FLOAT16(76.0f), FLOAT16(77.0f), FLOAT16(78.0f),
        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f),

        // 1
@@ -166,9 +639,9 @@ TEST(scatter_nd_update_gpu_fp16_test14, data5_indice2_update3) {
        FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f),

        FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), FLOAT16(67.0f), FLOAT16(68.0f),
        FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), FLOAT16(67.0f), FLOAT16(68.0f),
        FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
        FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f),
        FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f),
    };

    topology topology;
@@ -398,7 +871,6 @@ TEST(scatter_nd_update_gpu_fp16_test11, data6_indice1_update6) {
        FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f),
        FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f),
        FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f),

    });

    std::vector<float> expected_results = {
@@ -511,7 +983,6 @@ TEST(scatter_nd_update_gpu_fp16_test10, data5_indice1_update5) {
        FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f),
        FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f),
        FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f),

    });

    std::vector<float> expected_results = {
@@ -595,7 +1066,6 @@ TEST(scatter_nd_update_gpu_fp16_test9, data4_indice1_update4) {
        FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f),
        FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f),
        FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f),

    });

    std::vector<float> expected_results = {
@@ -641,9 +1111,9 @@ TEST(scatter_nd_update_gpu_fp16_test9, data4_indice1_update4) {
TEST(scatter_nd_update_gpu_fp16_test8, data6_indice2_update5) {
    auto& engine = get_test_engine();

    auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 3, 4, 2 } }); // data
    auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 4, 3, 2 } }); // data
    auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // indices
    auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 1, 3, 4, 2 } }); // updates
    auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 1, 2, 4, 3 } }); // updates

    set_values(input1, {
        //0,0
@@ -808,7 +1278,7 @@ TEST(scatter_nd_update_gpu_fp16_test7, data5_indice2_update4) {
TEST(scatter_nd_update_gpu_fp16_test6, data4_indice2_update3) {
    auto& engine = get_test_engine();

    auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 3, 4, 2 } }); // data
    auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 3, 2, 4 } }); // data
    auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices
    auto input3 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 4, 1, 2 } }); // updates
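
The rendered diff breaks off here, part-way into test6. For orientation only, a minimal sketch of how these fp16 tests continue elsewhere in this file, assuming the clDNN test helpers already visible above (set_values, topology, network, mem_lock) and with the set_values calls for input2/input3 omitted; the primitive ids, the scatter_nd_update constructor arguments, and the comparison loop are illustrative assumptions, not lines from this commit:

    // Assumed continuation (sketch only): build the topology, run it, and
    // compare the fp16 output against hand-computed reference values.
    topology topology;
    topology.add(input_layout("InputData", input1->get_layout()));
    topology.add(input_layout("InputIndices", input2->get_layout()));
    topology.add(input_layout("InputUpdates", input3->get_layout()));
    topology.add(scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2 /* indices rank */));

    network network(engine, topology);
    network.set_input_data("InputData", input1);
    network.set_input_data("InputIndices", input2);
    network.set_input_data("InputUpdates", input3);

    auto outputs = network.execute();
    auto output = outputs.at("scatter_nd_update").get_memory();
    cldnn::mem_lock<uint16_t> output_ptr(output, get_test_stream());

    std::vector<float> expected_results = { /* hand-computed reference values */ };
    for (size_t i = 0; i < expected_results.size(); ++i) {
        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
    }
}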