Update scatter nd update kernel to support blocked_formats (#11533)

* draft pr for planar and fsv16

* draft pr for general test

* update fusion test (failing)

* update fusing test (pass)

* update fusing test (include exception)

* clean gpu unit test

* review comment applied

* unit test cases added & cpplint applied

* cpplint error fixed

* change gpu test cases for fp16

* fusing test: fix generate_unique_indices

* fix typo

* revise cl kernel for cases where the updates shape is altered
Sieun Kim 2022-06-13 13:25:24 +09:00 committed by GitHub
parent ca7ddae9ba
commit 6e3dd4adce
8 changed files with 892 additions and 105 deletions
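For orientation, ScatterNDUpdate takes a data tensor, an indices tensor whose last dimension K addresses the first K dimensions of data, and an updates tensor; each K-tuple selects a slice of data to overwrite. A minimal host-side sketch of these semantics, mirroring the blockND stride-table approach the reference kernel uses (names are ours, for illustration only):

```cpp
#include <cstddef>
#include <vector>

// data:    dense tensor, data_shape gives its dimensions
// indices: N coordinate tuples of K values each (K <= data_shape.size())
// updates: N slices, each covering all trailing dimensions beyond the first K
void scatter_nd_update_sketch(std::vector<float>& data,
                              const std::vector<int>& indices,
                              const std::vector<float>& updates,
                              const std::vector<size_t>& data_shape,
                              size_t N, size_t K) {
    // block[i] = number of elements spanned by dimensions i..rank-1 (the blockND table)
    std::vector<size_t> block(data_shape.size() + 1, 1);
    for (int i = static_cast<int>(data_shape.size()) - 1; i >= 0; --i)
        block[i] = data_shape[i] * block[i + 1];
    const size_t slice = block[K];  // elements written per index tuple
    for (size_t n = 0; n < N; ++n) {
        size_t dst = 0;
        for (size_t k = 0; k < K; ++k)
            dst += static_cast<size_t>(indices[n * K + k]) * block[k + 1];
        for (size_t e = 0; e < slice; ++e)
            data[dst + e] = updates[n * slice + e];
    }
}
```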


@@ -55,12 +55,92 @@ attach_scatter_nd_update_impl::attach_scatter_nd_update_impl() {
    std::make_tuple(data_types::f32, format::bfyx),
    std::make_tuple(data_types::f16, format::bfyx),
    std::make_tuple(data_types::i32, format::bfyx),
std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
    std::make_tuple(data_types::f32, format::bfzyx),
    std::make_tuple(data_types::f16, format::bfzyx),
    std::make_tuple(data_types::i32, format::bfzyx),
std::make_tuple(data_types::i8, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
    std::make_tuple(data_types::f32, format::bfwzyx),
    std::make_tuple(data_types::f16, format::bfwzyx),
    std::make_tuple(data_types::i32, format::bfwzyx),
std::make_tuple(data_types::i8, format::bfwzyx),
std::make_tuple(data_types::u8, format::bfwzyx),
std::make_tuple(data_types::f32, format::b_fs_yx_fsv4),
std::make_tuple(data_types::f16, format::b_fs_yx_fsv4),
std::make_tuple(data_types::i32, format::b_fs_yx_fsv4),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv4),
std::make_tuple(data_types::u8, format::b_fs_yx_fsv4),
std::make_tuple(data_types::f32, format::b_fs_yx_fsv16),
std::make_tuple(data_types::f16, format::b_fs_yx_fsv16),
std::make_tuple(data_types::i32, format::b_fs_yx_fsv16),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),
std::make_tuple(data_types::f32, format::b_fs_yx_fsv32),
std::make_tuple(data_types::f16, format::b_fs_yx_fsv32),
std::make_tuple(data_types::i32, format::b_fs_yx_fsv32),
std::make_tuple(data_types::i8, format::b_fs_yx_fsv32),
std::make_tuple(data_types::u8, format::b_fs_yx_fsv32),
std::make_tuple(data_types::f32, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::f16, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::i32, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::i8, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::u8, format::b_fs_zyx_fsv16),
std::make_tuple(data_types::f32, format::b_fs_zyx_fsv32),
std::make_tuple(data_types::f16, format::b_fs_zyx_fsv32),
std::make_tuple(data_types::i32, format::b_fs_zyx_fsv32),
std::make_tuple(data_types::i8, format::b_fs_zyx_fsv32),
std::make_tuple(data_types::u8, format::b_fs_zyx_fsv32),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv2),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv8_fsv2),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv2),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv2),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32),
    });
}
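The new entries cover blocked layouts, where the feature (and optionally batch) dimension is tiled into fixed-size inner blocks. As a sketch of what, say, b_fs_yx_fsv16 implies for addressing (our helper, assuming the feature dimension is padded to a multiple of 16; the real pitches come from the cldnn layout descriptors):

```cpp
#include <cstddef>

// b_fs_yx_fsv16: memory order is b, f/16, y, x, f%16 (innermost block of 16 features)
size_t offset_b_fs_yx_fsv16(size_t b, size_t f, size_t y, size_t x,
                            size_t F, size_t Y, size_t X) {
    const size_t fsv = 16;
    const size_t f_blocks = (F + fsv - 1) / fsv;  // feature dim padded up to 16
    return (((b * f_blocks + f / fsv) * Y + y) * X + x) * fsv + f % fsv;
}
```

The same pattern generalizes to the bsv/fsv combinations such as bs_fs_yx_bsv32_fsv16, with an extra batch-block level.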


@@ -1397,7 +1397,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
                prim.type() != cldnn::region_yolo::type_id() &&
                prim.type() != cldnn::normalize::type_id() &&
                prim.type() != cldnn::mvn::type_id() &&
-               prim.type() != cldnn::gather::type_id()) {
                prim.type() != cldnn::gather::type_id() &&
                prim.type() != cldnn::scatter_nd_update::type_id()) {
                can_use_fsv16 = false;
            }
@@ -1423,6 +1424,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
                prim.type() != cldnn::softmax::type_id() &&
                prim.type() != cldnn::fully_connected::type_id() &&
                prim.type() != cldnn::generic_layer::type_id() &&
                prim.type() != cldnn::scatter_nd_update::type_id() &&
                prim.type() != cldnn::quantize::type_id())
                can_use_bs_fs_yx_bsv16_fsv16 = false;
        }


@@ -14,17 +14,15 @@ ParamsKey ScatterNDUpdateKernelRef::GetSupportedKey() const {
    k.EnableInputDataType(Datatype::F16);
    k.EnableInputDataType(Datatype::F32);
    k.EnableInputDataType(Datatype::INT32);
    k.EnableInputDataType(Datatype::INT8);
    k.EnableInputDataType(Datatype::UINT8);
    k.EnableOutputDataType(Datatype::F16);
    k.EnableOutputDataType(Datatype::F32);
    k.EnableOutputDataType(Datatype::INT32);
    k.EnableOutputDataType(Datatype::INT8);
    k.EnableOutputDataType(Datatype::UINT8);
-   k.EnableInputLayout(DataLayout::bfyx);
-   k.EnableOutputLayout(DataLayout::bfyx);
-   k.EnableInputLayout(DataLayout::bfzyx);
-   k.EnableOutputLayout(DataLayout::bfzyx);
-   k.EnableInputLayout(DataLayout::bfwzyx);
-   k.EnableOutputLayout(DataLayout::bfwzyx);
    k.EnableAllInputLayout();
    k.EnableAllOutputLayout();
    k.EnableTensorOffset();
    k.EnableTensorPitches();
    k.EnableBatching();
@@ -115,14 +113,10 @@ bool ScatterNDUpdateKernelRef::Validate(const Params& p, const optional_params&
    return true;
}

-static std::string GetInputBlockND(const scatter_nd_update_params& params) {
-    const auto& input = params.inputs[0];
static std::string GetInputBlockND(const scatter_nd_update_params& params, int num, const int rank) {
    const auto& input = params.inputs[num];
    auto input_dims = input.LogicalDims();
    std::reverse(input_dims.begin(), input_dims.end());
-    while (!input_dims.empty() && input_dims.back() == 1) {
-        input_dims.pop_back();
-    }
-    const int rank = static_cast<int>(input_dims.size());
    std::vector<size_t> block_nd(rank + 1);
    block_nd[rank] = 1;
    for (int idx = (rank - 1); idx >= 0; idx--) {
@@ -157,9 +151,14 @@ KernelsData ScatterNDUpdateKernelRef::GetKernelsData(const Params& params, const
        auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options, i);

        if (i == 1) {
            int input0_rank = static_cast<int>(newParams.inputs[0].LogicalDims().size());
            int input2_rank = static_cast<int>(newParams.inputs[2].LogicalDims().size());
            cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true"));
            cldnn_jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", dispatchData.indicesLastDim));
-           cldnn_jit.AddConstant(MakeJitConstant("INPUT_BLOCK_ND", GetInputBlockND(newParams)));
            cldnn_jit.AddConstant(MakeJitConstant("INPUT0_BLOCK_ND", GetInputBlockND(newParams, 0, input0_rank)));
            cldnn_jit.AddConstant(MakeJitConstant("INPUT1_BLOCK_ND", GetInputBlockND(newParams, 1, newParams.indices_rank - 1)));
            cldnn_jit.AddConstant(MakeJitConstant("INPUT2_BLOCK_ND", GetInputBlockND(newParams, 2, input2_rank)));
            cldnn_jit.AddConstant(MakeJitConstant("INDICES_RANK", newParams.indices_rank));
        }
        std::pair<std::string, std::string> jit = CreateJit(kernelName, cldnn_jit, entry_point);
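GetInputBlockND now emits one stride table per input, truncated to the requested rank, instead of a single table for input 0 with trailing ones stripped. A rough standalone reimplementation to show the emitted jit constant (ours; the real helper works on the reversed LogicalDims of params.inputs[num]):

```cpp
#include <string>
#include <vector>

std::string block_nd_string(const std::vector<size_t>& dims) {
    std::vector<size_t> block(dims.size() + 1, 1);
    for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i)
        block[i] = dims[i] * block[i + 1];  // block[i] spans dims i..rank-1
    std::string s;
    for (size_t i = 0; i < block.size(); ++i)
        s += std::to_string(block[i]) + (i + 1 < block.size() ? "," : "");
    return s;  // e.g. {2, 3, 4} -> "24,12,4,1"
}
```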


@@ -1,3 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@@ -16,6 +17,24 @@
#define ORDER b,f,w,z,y,x
#endif
#if INPUT2_DIMS == 4
#define UPD_ORDER upd_b,upd_f,upd_y,upd_x
#elif INPUT2_DIMS == 5
#define UPD_ORDER upd_b,upd_f,upd_z,upd_y,upd_x
#elif INPUT2_DIMS == 6
#define UPD_ORDER upd_b,upd_f,upd_w,upd_z,upd_y,upd_x
#endif
#if INPUT1_DIMS == 4
#define IDX_ORDER idx_b,idx_f,idx_y,idx_x
#elif INPUT1_DIMS == 5
#define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x
#elif INPUT1_DIMS == 6
#define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x
#endif
#define INDICES_MAX_DIM 6
KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
                              const __global INPUT1_TYPE* indices,
                              const __global INPUT2_TYPE* updates,
@@ -49,81 +68,122 @@ KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
#else // Second kernel
-    const uint blockND[] = {INPUT_BLOCK_ND};
-    const uint k = INDICES_LAST_DIM;
-    const uint size_to_update = blockND[INDICES_LAST_DIM];
-    const uint indices_idx = dim2;
-    const uint indices_offset = indices_idx * k;
-    uint dst_offset = 0;
-    for (uint i = 0; i < k; i++) {
-        INPUT1_TYPE idxValue = indices[indices_offset + i];
-        dst_offset += idxValue * blockND[i + 1];
-    }
-    uint update_offset = indices_idx * size_to_update;
-    for (int i = 0; i < size_to_update; i++) {
-        uint dst_idx = dst_offset + i;
-        uint up_idx = update_offset + i;
-        INPUT2_TYPE val = updates[up_idx];
-        #if HAS_FUSED_OPS
-            #if OUTPUT_DIMS == 4
-                const uint y_pitch = OUTPUT_SIZE_X;
-                const uint f_pitch = y_pitch * OUTPUT_SIZE_Y;
-                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
-                const uint b_remain = dst_idx % b_pitch;
-                const uint f_remain = b_remain % f_pitch;
-                const uint y_remain = f_remain % y_pitch;
-                const uint b = dst_idx / b_pitch;
-                const uint f = b_remain / f_pitch;
-                const uint y = f_remain / y_pitch;
-                const uint x = y_remain;
-            #elif OUTPUT_DIMS == 5
-                const uint y_pitch = OUTPUT_SIZE_X;
-                const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
-                const uint f_pitch = z_pitch * OUTPUT_SIZE_Z;
-                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
-                const uint b_remain = dst_idx % b_pitch;
-                const uint f_remain = b_remain % f_pitch;
-                const uint z_remain = f_remain % z_pitch;
-                const uint y_remain = z_remain % y_pitch;
-                const uint b = dst_idx / b_pitch;
-                const uint f = b_remain / f_pitch;
-                const uint z = f_remain / z_pitch;
-                const uint y = z_remain / y_pitch;
-                const uint x = y_remain;
-            #elif OUTPUT_DIMS == 6
-                const uint y_pitch = OUTPUT_SIZE_X;
-                const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
-                const uint w_pitch = z_pitch * OUTPUT_SIZE_Z;
-                const uint f_pitch = w_pitch * OUTPUT_SIZE_W;
-                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
-                const uint b_remain = dst_idx % b_pitch;
-                const uint f_remain = b_remain % f_pitch;
-                const uint w_remain = f_remain % w_pitch;
-                const uint z_remain = w_remain % z_pitch;
-                const uint y_remain = z_remain % y_pitch;
-                const uint b = dst_idx / b_pitch;
-                const uint f = b_remain / f_pitch;
-                const uint w = f_remain / w_pitch;
-                const uint z = w_remain / z_pitch;
-                const uint y = z_remain / y_pitch;
-                const uint x = y_remain;
-            #endif
-            FUSED_OPS_SECOND_KERNEL;
-            output[dst_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
-        #else
-            output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
-        #endif
    const uint dataND[] = {INPUT0_BLOCK_ND};
    const uint updatesND[] = {INPUT2_BLOCK_ND};
    const uint indicesND[] = {INPUT1_BLOCK_ND};
    const uint size_to_update = dataND[INDICES_LAST_DIM];

#if INPUT1_DIMS == 4
    const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#elif INPUT1_DIMS == 5
    const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#elif INPUT1_DIMS == 6
    const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#endif
#if INPUT0_DIMS == 4
const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#elif INPUT0_DIMS == 5
const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#elif INPUT0_DIMS == 6
const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#endif
// Get indices index
uint idx[INDICES_MAX_DIM] = {0};
uint rmd_idx = dim2;
for (int i = 0; i < INDICES_RANK - 1; ++i) {
idx[i] = rmd_idx / indicesND[1 + i];
rmd_idx %= indicesND[1 + i];
}
uint out[INDICES_MAX_DIM] = {0};
for (int i = 0; i < indices_dim[INDICES_RANK - 1]; ++i) {
idx[INDICES_RANK - 1] = i;
const uint idx_b = idx[0];
const uint idx_f = idx[1];
#if INPUT1_DIMS == 4
const uint idx_y = idx[2];
const uint idx_x = idx[3];
#elif INPUT1_DIMS == 5
const uint idx_z = idx[2];
const uint idx_y = idx[3];
const uint idx_x = idx[4];
#elif INPUT1_DIMS == 6
const uint idx_w = idx[2];
const uint idx_z = idx[3];
const uint idx_y = idx[4];
const uint idx_x = idx[5];
#endif
uint index = GET_UPDATES_INDEX(INPUT1, IDX_ORDER);
out[i] = indices[index];
// Check if tensor size is valid
// ex) when data format = bfyx and data shape = { 3, 3, 4, 1 }, indices shape is { 2, 1 } with rank = 2, indices values are { 1.0, 4.0 },
// the second indices value is invalid as data shape has 'b' of size 3, and therefore 4 cannot be a correct index of data
// If indices value is invalid, saturate value to max valid value (ex. 4.0 -> 2.0)
if(out[i] >= data_dim[i])
out[i] = data_dim[i] - 1;
}
for (int i = 0; i < size_to_update; ++i) {
// Define updates index
uint upd[INDICES_MAX_DIM] = {0};
for (int j = 0; j < INDICES_RANK - 1; ++j) {
upd[j] = idx[j];
}
uint data_rmd = i, updates_rmd = i;
for (int j = indices_dim[INDICES_RANK - 1]; j < INPUT0_DIMS; ++j) {
out[j] = data_rmd / dataND[j + 1];
data_rmd %= dataND[j + 1];
}
for (int k = INDICES_RANK - 1; k < INPUT2_DIMS; ++k) {
upd[k] = updates_rmd / updatesND[k + 1];
updates_rmd %= updatesND[k + 1];
}
// Get update index
const uint upd_b = upd[0];
const uint upd_f = upd[1];
#if INPUT2_DIMS == 4
const uint upd_y = upd[2];
const uint upd_x = upd[3];
#elif INPUT2_DIMS == 5
const uint upd_z = upd[2];
const uint upd_y = upd[3];
const uint upd_x = upd[4];
#elif INPUT2_DIMS == 6
const uint upd_w = upd[2];
const uint upd_z = upd[3];
const uint upd_y = upd[4];
const uint upd_x = upd[5];
#endif
uint upd_idx = GET_UPDATES_INDEX(INPUT2, UPD_ORDER);
// Get output index
const uint b = out[0];
const uint f = out[1];
#if INPUT0_DIMS == 4
const uint y = out[2];
const uint x = out[3];
#elif INPUT0_DIMS == 5
const uint z = out[2];
const uint y = out[3];
const uint x = out[4];
#elif INPUT0_DIMS == 6
const uint w = out[2];
const uint z = out[3];
const uint y = out[4];
const uint x = out[5];
#endif
uint out_idx = GET_OUTPUT_INDEX(ORDER);
INPUT2_TYPE val = updates[upd_idx];
#if HAS_FUSED_OPS
FUSED_OPS_SECOND_KERNEL;
output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
#else
output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
    }
#endif
@@ -140,3 +200,15 @@ KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
#ifdef ORDER
#undef ORDER
#endif
#ifdef UPD_ORDER
#undef UPD_ORDER
#endif
#ifdef IDX_ORDER
#undef IDX_ORDER
#endif
#ifdef INDICES_MAX_DIM
#undef INDICES_MAX_DIM
#endif
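The rewritten second kernel no longer writes through flat offsets: it recovers explicit b/f/(w/z/)y/x coordinates for the indices, updates, and output tensors and lets the GET_UPDATES_INDEX/GET_OUTPUT_INDEX macros translate them into physical offsets, which is what makes blocked layouts work. The recurring building block is a divide/modulo decomposition over a blockND stride table; a host-side sketch (ours):

```cpp
#include <cstddef>
#include <vector>

// Recover the coordinate of `flat` from a blockND stride table, exactly like
// the kernel's "idx[i] = rmd_idx / indicesND[1 + i]; rmd_idx %= ..." loops.
std::vector<size_t> decompose(size_t flat, const std::vector<size_t>& block_nd, size_t rank) {
    std::vector<size_t> coord(rank, 0);
    for (size_t i = 0; i < rank; ++i) {
        coord[i] = flat / block_nd[i + 1];  // block_nd[i + 1] is the stride of dim i
        flat %= block_nd[i + 1];
    }
    return coord;  // e.g. block_nd = {24, 12, 4, 1}, flat = 7 -> {0, 1, 3}
}
```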


@@ -20,6 +20,21 @@ static void CreateScatterNDUpdateOp(Program& p, const std::shared_ptr<ngraph::op
    std::string layerName = layer_type_name_ID(op);
    auto indices_rank = op->get_input_shape(1).size();
auto indices_constant = std::dynamic_pointer_cast<ngraph::op::Constant>(op->get_input_node_shared_ptr(1));
if (indices_constant) {
auto indices = indices_constant->cast_vector<int32_t>();
auto indices_last_dim = op->get_input_shape(1)[indices_rank - 1];
auto data_shape = op->get_input_shape(0);
bool valid = true;
for (int i = 0; i < indices.size(); ++i) {
if (indices[i] >= data_shape[i % indices_last_dim])
valid = false;
}
if (!valid)
IE_THROW() << "Invalid indices values";
}
    auto primitive = cldnn::scatter_nd_update(layerName,
                                              inputPrimitives[0],
                                              inputPrimitives[1],
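The added block rejects out-of-range constant indices at build time. A standalone sketch of the same check (ours, not the plugin API): with data shape { 3, 3, 4, 1 } and indices_last_dim = 1, an index value of 4 addresses dimension b of size 3 and must be rejected.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of the constant-indices validation above; names are ours.
void validate_indices(const std::vector<int32_t>& indices,
                      const std::vector<size_t>& data_shape,
                      size_t indices_last_dim) {
    for (size_t i = 0; i < indices.size(); ++i) {
        // indices are flattened tuples; i % indices_last_dim selects the
        // data dimension this coordinate addresses
        if (indices[i] < 0 || static_cast<size_t>(indices[i]) >= data_shape[i % indices_last_dim])
            throw std::runtime_error("Invalid indices values");
    }
}
// validate_indices({1, 4}, {3, 3, 4, 1}, 1) throws: 4 >= data_shape[0] == 3
```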


@@ -47,7 +47,8 @@ target_link_libraries(${TARGET_NAME} PRIVATE openvino_intel_gpu_graph
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
                                                  ${CMAKE_CURRENT_SOURCE_DIR}/test_utils/
                                                  $<TARGET_PROPERTY:openvino_intel_gpu_kernels,INTERFACE_INCLUDE_DIRECTORIES>
-                                                 $<TARGET_PROPERTY:openvino_intel_gpu_runtime,INTERFACE_INCLUDE_DIRECTORIES>)
                                                  $<TARGET_PROPERTY:openvino_intel_gpu_runtime,INTERFACE_INCLUDE_DIRECTORIES>
                                                  ${CMAKE_HOME_DIRECTORY}/src/core/reference/include/)
if(WIN32)
  target_link_libraries(${TARGET_NAME} PRIVATE setupapi)
elseif((NOT ANDROID) AND (UNIX))


@@ -12,6 +12,9 @@
#include <intel_gpu/primitives/scatter_nd_update.hpp>
#include <cmath>
#include <stdlib.h>
#include <time.h>
#include <algorithm>
using namespace cldnn;
using namespace ::tests;
@@ -82,16 +85,16 @@ public:
        std::set<std::vector<T>> unique_indices;
        std::vector<T> result;
        auto indices_shape = p.indices_shape.sizes(get_default_format(p.indices_rank));
-       auto last_indices_dim = indices_shape.back();
        auto data_shape = p.input_shape.sizes(p.input_format);
        auto last_indices_dim = indices_shape.at(p.indices_rank - 1);
-       auto count = 1;
-       for (size_t i = 0; i < indices_shape.size() - 1; i++)
-           count *= indices_shape[i];
        auto count = p.indices_shape.count() / last_indices_dim;
        while (unique_indices.size() != count) {
            std::vector<T> indices;
-           for (size_t i = 0; i < last_indices_dim; i++)
-               indices.push_back(generate_random_val<T>(0, indices_shape[i]));
            for (size_t i = 0; i < last_indices_dim; i++) {
                indices.push_back(static_cast<T>(generate_random_val<int>(0, data_shape[i] - 1)));
            }
            unique_indices.insert(indices);
        }
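The regenerated indices are now unique tuples bounded by the data shape (the old code bounded them by the indices shape). Uniqueness matters because ScatterNDUpdate with duplicate tuples is order-dependent, which would make the fused-vs-reference comparison flaky; a tiny illustration (ours):

```cpp
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> data = {0.f, 0.f, 0.f};
    // two updates target the same location 1 -> the result is order-dependent
    const int   indices[] = {1, 1};
    const float updates[] = {10.f, 20.f};
    for (int n = 0; n < 2; ++n)
        data[indices[n]] = updates[n];
    std::printf("%f\n", data[1]);  // 20 here, but 10 if the loop ran backwards
    return 0;
}
```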
@@ -173,6 +176,55 @@
#define CASE_SCATTER_ND_UPDATE_FP32_6D_5 { 6, 7, 8, 9, 2, 2 }, { 5, 5, 1, 1 }, { 5, 8, 1, 1, 1, 1 }, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_6 { 6, 7, 8, 9, 2, 2 }, { 5, 6, 1, 1 }, { 5, 1, 1, 1, 1, 1 }, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1 { 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 }, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 8, 9 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3 { 6, 7, 8, 9, 10 }, { 5, 3, 1, 1 }, { 5, 9, 1, 1, 8 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4 { 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5 { 6, 7, 8, 9, 10 }, { 5, 5, 1, 1 }, { 5, 1, 1, 1, 1 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, 3, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 }, 3, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8 { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, 4, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9 { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 }, 4, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1 { 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 }, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 8, 9 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3 { 6, 7, 8, 9, 10 }, { 5, 3, 1, 1 }, { 5, 9, 1, 1, 8 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4 { 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5 { 6, 7, 8, 9, 10 }, { 5, 5, 1, 1 }, { 5, 1, 1, 1, 1 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, 3, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 }, 3, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8 { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, 4, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9 { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 }, 4, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx
class scatter_nd_update_quantize : public ScatterNDUpdatePrimitiveFusingTest {};
TEST_P(scatter_nd_update_quantize, basic) {
    auto p = GetParam();
@@ -235,6 +287,54 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, scatter_nd_update_quantize, ::testing::Val
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 3 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1, 2, 3 }, // FP16
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1, 2, 3 }, // FP32
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1, 2, 3 }, // FP16
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1, 2, 3 }, // FP32
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6, 2, 3 },
}));

class scatter_nd_update_scale_activation_eltwise : public ScatterNDUpdatePrimitiveFusingTest {};
@@ -298,4 +398,52 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 },
    scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1, 2, 5 }, // FP16
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1, 2, 5 }, // FP32
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1, 2, 5 }, // FP16
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1, 2, 5 }, // FP32
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6, 2, 5 },
}));


@@ -3,6 +3,7 @@
//
#include "test_utils.h"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/scatter_update.hpp>
@@ -12,17 +13,489 @@
#include <intel_gpu/graph/network.hpp>
#include <cstddef>
#include <cstring>
#include <numeric>
#include <stdlib.h>
#include <algorithm>
using namespace cldnn;
using namespace ::tests;
struct scatter_nd_update_basic_test_params
{
data_types input_type;
data_types indices_type;
data_types updates_type;
format input_format;
format indices_format;
format updates_format;
format input_result_format;
format indices_result_format;
format updates_result_format;
tensor input_size;
tensor indices_size;
tensor updates_size;
int indices_rank;
};
struct scatter_nd_update_random_test : testing::TestWithParam<scatter_nd_update_basic_test_params>
{
format get_default_format(int rank = 4) {
if (rank <= 4)
return cldnn::format::bfyx;
else if (rank == 5)
return cldnn::format::bfzyx;
else
return cldnn::format::bfwzyx;
}
template<typename T>
T generate_random_val(int min, int max, int k = 8) {
static std::default_random_engine generator(random_seed);
// 1/k is the resolution of the floating point numbers
std::uniform_int_distribution<int> distribution(k * min, k * max);
T val = (T)distribution(generator);
val /= k;
return val;
}
template <typename T>
std::vector<T> generate_unique_indices(const scatter_nd_update_basic_test_params& p) {
std::set<std::vector<T>> unique_indices;
std::vector<T> result;
auto indices_shape = p.indices_size.sizes(get_default_format(p.indices_rank));
auto data_shape = p.input_size.sizes(p.input_format);
auto last_indices_dim = indices_shape.at(p.indices_rank - 1);
auto count = p.indices_size.count() / last_indices_dim;
while (unique_indices.size() != count) {
std::vector<T> indices;
for (size_t i = 0; i < last_indices_dim; i++) {
indices.push_back(static_cast<T>(generate_random_val<int>(0, data_shape[i] - 1)));
}
unique_indices.insert(indices);
}
std::for_each(unique_indices.begin(),
unique_indices.end(),
[&](const std::vector<T>& indices) {
result.insert(result.end(), indices.begin(), indices.end());
});
return result;
}
template<typename T, typename T_size>
void execute_fp16(const scatter_nd_update_basic_test_params& params)
{
auto& engine = get_test_engine();
auto input1 = engine.allocate_memory({ params.input_type, params.input_format, params.input_size });
auto input2 = engine.allocate_memory({ params.indices_type, params.indices_format, params.indices_size });
auto input3 = engine.allocate_memory({ params.updates_type, params.updates_format, params.updates_size });
std::vector<int> input_vec(static_cast<int>(cldnn::format::dimension(params.input_format)));
for (int i = 0; i < input_vec.size(); ++i)
input_vec[i] = static_cast<int>(params.input_size.sizes()[i]);
std::reverse(input_vec.begin() + 2, input_vec.end());
std::vector<int> updates_vec(static_cast<int>(cldnn::format::dimension(params.updates_format)));
for (int i = 0; i < updates_vec.size(); ++i)
updates_vec[i] = static_cast<int>(params.updates_size.sizes()[i]);
std::reverse(updates_vec.begin() + 2, updates_vec.end());
std::vector<int> indices_vec(static_cast<int>(cldnn::format::dimension(params.indices_format)));
for (size_t i = 0; i < indices_vec.size(); ++i)
indices_vec[i] = static_cast<int>(params.indices_size.sizes()[i]);
std::reverse(indices_vec.begin() + 2, indices_vec.end());
indices_vec.resize(params.indices_rank);
auto input_data_fp16 = generate_random_1d<T>(params.input_size.count(), -127, 127);
auto indices_data_fp16 = generate_unique_indices<T>(params);
auto updates_data_fp16 = generate_random_1d<T>(params.updates_size.count(), -127, 127);
std::vector<float> input_data(params.input_size.count());
for (int i = 0; i < params.input_size.count(); ++i)
input_data[i] = static_cast<float>(input_data_fp16[i]);
std::vector<float> indices_data(params.indices_size.count());
for (int i = 0; i < params.indices_size.count(); ++i)
indices_data[i] = static_cast<float>(indices_data_fp16[i]);
std::vector<float> updates_data(params.updates_size.count());
for (int i = 0; i < params.updates_size.count(); ++i)
updates_data[i] = static_cast<float>(updates_data_fp16[i]);
set_values(input1, input_data_fp16);
set_values(input2, indices_data_fp16);
set_values(input3, updates_data_fp16);
// execute scatter_nd_update
topology topology(
input_layout("InputData", input1->get_layout()),
input_layout("InputIndices", input2->get_layout()),
input_layout("InputUpdates", input3->get_layout()),
reorder("reorder1", "InputData", params.input_result_format, params.input_type),
reorder("reorder2", "InputIndices", params.indices_result_format, params.indices_type),
reorder("reorder3", "InputUpdates", params.updates_result_format, params.updates_type),
scatter_nd_update("scatter_nd_update", "reorder1", "reorder2", "reorder3", params.indices_rank),
reorder("out", "scatter_nd_update", params.input_format, params.input_type)
);
network network(engine, topology);
network.set_input_data("InputData", input1);
network.set_input_data("InputIndices", input2);
network.set_input_data("InputUpdates", input3);
auto outputs = network.execute();
auto output = outputs.at("out").get_memory();
cldnn::mem_lock<T_size> outputs_ptr(output, get_test_stream());
auto outputs_ref = std::vector<float>(params.input_size.count());
ngraph::runtime::reference::scatterNdUpdate<float, float>(input_data.data(),
indices_data.data(),
updates_data.data(),
outputs_ref.data(),
ov::Shape(input_vec.begin(), input_vec.end()),
ov::Shape(indices_vec.begin(), indices_vec.end()),
ov::Shape(updates_vec.begin(), updates_vec.end()));
for (size_t i = 0; i < outputs_ref.size(); ++i) {
EXPECT_EQ(outputs_ref[i], float16_to_float32(outputs_ptr[i]));
}
}
template<typename T>
void execute(const scatter_nd_update_basic_test_params& params)
{
// create input, indices, updates using params
auto& engine = get_test_engine();
auto input1 = engine.allocate_memory({ params.input_type, params.input_format, params.input_size });
auto input2 = engine.allocate_memory({ params.indices_type, params.indices_format, params.indices_size });
auto input3 = engine.allocate_memory({ params.updates_type, params.updates_format, params.updates_size });
std::vector<int> input_vec(static_cast<int>(cldnn::format::dimension(params.input_format)));
for (int i = 0; i < input_vec.size(); ++i)
input_vec[i] = static_cast<int>(params.input_size.sizes()[i]);
std::reverse(input_vec.begin() + 2, input_vec.end());
std::vector<int> updates_vec(static_cast<int>(cldnn::format::dimension(params.updates_format)));
for (int i = 0; i < updates_vec.size(); ++i)
updates_vec[i] = static_cast<int>(params.updates_size.sizes()[i]);
std::reverse(updates_vec.begin() + 2, updates_vec.end());
std::vector<int> indices_vec(static_cast<int>(cldnn::format::dimension(params.indices_format)));
for (size_t i = 0; i < indices_vec.size(); ++i)
indices_vec[i] = static_cast<int>(params.indices_size.sizes()[i]);
std::reverse(indices_vec.begin() + 2, indices_vec.end());
indices_vec.resize(params.indices_rank);
auto input_data = generate_random_1d<T>(params.input_size.count(), -127, 127);
auto indices_data = generate_unique_indices<T>(params);
auto updates_data = generate_random_1d<T>(params.updates_size.count(), -127, 127);
set_values(input1, input_data);
set_values(input2, indices_data);
set_values(input3, updates_data);
// execute scatter_nd_update
topology topology(
input_layout("InputData", input1->get_layout()),
input_layout("InputIndices", input2->get_layout()),
input_layout("InputUpdates", input3->get_layout()),
reorder("reorder1", "InputData", params.input_result_format, params.input_type),
reorder("reorder2", "InputIndices", params.indices_result_format, params.indices_type),
reorder("reorder3", "InputUpdates", params.updates_result_format, params.updates_type),
scatter_nd_update("scatter_nd_update", "reorder1", "reorder2", "reorder3", params.indices_rank),
reorder("out", "scatter_nd_update", params.input_format, params.input_type)
);
network network(engine, topology);
network.set_input_data("InputData", input1);
network.set_input_data("InputIndices", input2);
network.set_input_data("InputUpdates", input3);
auto outputs = network.execute();
auto output = outputs.at("out").get_memory();
cldnn::mem_lock<T> outputs_ptr(output, get_test_stream());
auto outputs_ref = std::vector<T>(params.input_size.count());
ngraph::runtime::reference::scatterNdUpdate<T, T>(input_data.data(),
indices_data.data(),
updates_data.data(),
outputs_ref.data(),
ov::Shape(input_vec.begin(), input_vec.end()),
ov::Shape(indices_vec.begin(), indices_vec.end()),
ov::Shape(updates_vec.begin(), updates_vec.end()));
for (size_t i = 0; i < outputs_ref.size(); ++i) {
EXPECT_EQ(outputs_ref[i], outputs_ptr[i]);
}
}
};
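generate_random_val above draws integers in [k*min, k*max] and divides by k, so every value lands on a 1/k grid and survives fp16 round-tripping exactly; a quick standalone check (ours):

```cpp
#include <cstdio>
#include <random>

int main() {
    std::default_random_engine gen(42);
    const int k = 8, lo = -2, hi = 2;
    std::uniform_int_distribution<int> dist(k * lo, k * hi);
    for (int i = 0; i < 4; ++i) {
        float v = static_cast<float>(dist(gen)) / k;  // always a multiple of 1/8
        std::printf("%f\n", v);                        // exactly representable in fp16
    }
    return 0;
}
```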
TEST_P(scatter_nd_update_random_test, random)
{
auto param = GetParam();
if (param.input_type == data_types::u8)
this->execute<uint8_t>(param);
else if (param.input_type == data_types::i8)
this->execute<int8_t>(param);
else if (param.input_type == data_types::i32)
this->execute<int32_t>(param);
else if (param.input_type == data_types::i64)
this->execute<int64_t>(param);
else if (param.input_type == data_types::f16)
this->execute_fp16<FLOAT16, uint16_t>(param);
else if (param.input_type == data_types::f32)
this->execute<float>(param);
else
IE_THROW() << "unidentified data type";
}
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_bsv32_fsv16_4d_rank_1,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f32, data_types::f32, data_types::f32,
format::bfyx, format::bfyx, format::bfyx,
format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
1 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_bsv32_fsv16_4d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f32, data_types::f32, data_types::f32,
format::bfyx, format::bfyx, format::bfyx,
format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
{ 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_4d_rank_1,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f32, data_types::f32, data_types::f32,
format::bfyx, format::bfyx, format::bfyx,
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
1 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_4d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f32, data_types::f32, data_types::f32,
format::bfyx, format::bfyx, format::bfyx,
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
{ 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f32, data_types::f32, data_types::f32,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 3, 3, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 3, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_3,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f32, data_types::f32, data_types::f32,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 },
3 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_4,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f32, data_types::f32, data_types::f32,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 },
4 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_4d_rank_1,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfyx, format::bfyx, format::bfyx,
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
1 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_4d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfyx, format::bfyx, format::bfyx,
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
{ 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_1,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 },
1 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_3,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 },
3 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_4,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 },
4 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_bsv32_fsv16_4d_rank_1,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfyx, format::bfyx, format::bfyx,
format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
1 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_bsv32_fsv16_4d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::f16, data_types::f16, data_types::f16,
format::bfyx, format::bfyx, format::bfyx,
format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
{ 48, 24, 3, 3 }, {3, 2, 1, 1 }, { 3, 3, 1, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_bsv32_fsv16_4d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::i8, data_types::i8, data_types::i8,
format::bfyx, format::bfyx, format::bfyx,
format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16,
{ 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_bsv32_fsv32_4d_rank_1,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::i8, data_types::i8, data_types::i8,
format::bfyx, format::bfyx, format::bfyx,
format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32,
{ 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 },
1 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv32_4d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::i8, data_types::i8, data_types::i8,
format::bfyx, format::bfyx, format::bfyx,
format::b_fs_yx_fsv32, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32,
{ 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv32_5d_rank_3,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::i8, data_types::i8, data_types::i8,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv32, format::b_fs_yx_fsv32, format::b_fs_zyx_fsv32,
{ 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 },
3 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv16_4d_rank_2,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::i8, data_types::i8, data_types::i8,
format::bfyx, format::bfyx, format::bfyx,
format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16,
{ 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 },
2 }
}));
INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv16_5d_rank_4,
scatter_nd_update_random_test,
testing::ValuesIn(
std::vector<scatter_nd_update_basic_test_params>{
{ data_types::i8, data_types::i8, data_types::i8,
format::bfzyx, format::bfyx, format::bfzyx,
format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16,
{ 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 },
4 }
}));
TEST(scatter_nd_update_gpu_fp16_test15, data5_indice3_update5) {
    auto& engine = get_test_engine();

    auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 2, 2, 2, 4, 3 } });   // data
    auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 2, 1, 1 } });       // indices
-   auto input3 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 1, 2, 2, 4, 3, 2 } }); // updates
    auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 4, 3, 2 } }); // updates

    set_values(input1, {
        // 0
@ -158,7 +631,7 @@ TEST(scatter_nd_update_gpu_fp16_test14, data5_indice2_update3) {
        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f),
        FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f),
        FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), FLOAT16(75.0f), FLOAT16(76.0f), FLOAT16(77.0f), FLOAT16(78.0f),
        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f),
        // 1
@@ -166,9 +639,9 @@ TEST(scatter_nd_update_gpu_fp16_test14, data5_indice2_update3) {
        FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
        FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f),
        FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), FLOAT16(67.0f), FLOAT16(68.0f),
        FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
        FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f),
    };
    topology topology;
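    // From here the test typically follows the common pattern of the fp16 tests in this file
    // (a sketch, with illustrative primitive ids; the indices_rank argument is assumed):
    //
    //     topology.add(input_layout("InputData", input1->get_layout()));
    //     topology.add(input_layout("InputIndices", input2->get_layout()));
    //     topology.add(input_layout("InputUpdates", input3->get_layout()));
    //     topology.add(scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2));
    //
    //     network network(engine, topology);
    //     network.set_input_data("InputData", input1);
    //     network.set_input_data("InputIndices", input2);
    //     network.set_input_data("InputUpdates", input3);
    //     auto outputs = network.execute();
    //
    // followed by an element-wise comparison of the "scatter_nd_update" output against
    // expected_results.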
@@ -398,7 +871,6 @@ TEST(scatter_nd_update_gpu_fp16_test11, data6_indice1_update6) {
        FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f),
        FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f),
        FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f),
    });
    std::vector<float> expected_results = {
@@ -511,7 +983,6 @@ TEST(scatter_nd_update_gpu_fp16_test10, data5_indice1_update5) {
        FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f),
        FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f),
        FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f),
    });
    std::vector<float> expected_results = {
@@ -595,7 +1066,6 @@ TEST(scatter_nd_update_gpu_fp16_test9, data4_indice1_update4) {
        FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f),
        FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f),
        FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f),
    });
    std::vector<float> expected_results = {
@@ -641,9 +1111,9 @@ TEST(scatter_nd_update_gpu_fp16_test9, data4_indice1_update4) {
TEST(scatter_nd_update_gpu_fp16_test8, data6_indice2_update5) {
    auto& engine = get_test_engine();
-    auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 3, 4, 2 } }); // data
+    auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 4, 3, 2 } }); // data
    auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // indices
-    auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 1, 3, 4, 2 } }); // updates
+    auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 1, 2, 4, 3 } }); // updates
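    // The "-"/"+" pairs above are shape fixes from this commit: the data and updates extents
    // were reordered so that `updates` matches the slice selected by each rank-2 index tuple.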
    set_values(input1, {
        //0,0
@@ -808,7 +1278,7 @@ TEST(scatter_nd_update_gpu_fp16_test7, data5_indice2_update4) {
TEST(scatter_nd_update_gpu_fp16_test6, data4_indice2_update3) {
    auto& engine = get_test_engine();
-    auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 3, 4, 2 } }); // data
+    auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 3, 2, 4 } }); // data
    auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices
    auto input3 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 4, 1, 2 } }); // updates
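    // A minimal worked example of the update rule these tests verify (illustrative values
    // only, not this test's data):
    //     data    = [10, 20, 30, 40]    // shape {4}
    //     indices = [[3], [1]]          // two rank-1 index tuples
    //     updates = [99, 88]
    //     output  = [10, 88, 30, 99]    // rows 3 and 1 overwritten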