From 6e3dd4adce2f7a6397371a72c75520af37cd5272 Mon Sep 17 00:00:00 2001 From: Sieun Kim Date: Mon, 13 Jun 2022 13:25:24 +0900 Subject: [PATCH] Update scatter nd update kernel to support blocked_formats (#11533) * draft pr for planar and fsv16 * draft pr for general test * update fusion test (failing) * update fusing test (pass) * update fusing test (include exception) * clean gpu unit test * review comment applied * unit test cases added & cpplint applied * cpplint error fixed * change gpu test cases for fp16 * fusing test fix generate_unique_indices * fix typo * revise cl kernel for occasions when updates shape is altered --- .../src/graph/impls/ocl/scatter_nd_update.cpp | 80 +++ src/plugins/intel_gpu/src/graph/program.cpp | 4 +- .../scatter_nd_update_kernel_ref.cpp | 25 +- .../core/cl_kernels/scatter_nd_update_ref.cl | 220 +++++--- .../src/plugin/ops/scatter_nd_update.cpp | 15 + src/plugins/intel_gpu/tests/CMakeLists.txt | 3 +- .../fusions/scatter_nd_update_fusion_test.cpp | 160 +++++- .../test_cases/scatter_nd_update_gpu_test.cpp | 490 +++++++++++++++++- 8 files changed, 892 insertions(+), 105 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/scatter_nd_update.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/scatter_nd_update.cpp index b5a803ba259..912358211ce 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/scatter_nd_update.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/scatter_nd_update.cpp @@ -55,12 +55,92 @@ attach_scatter_nd_update_impl::attach_scatter_nd_update_impl() { std::make_tuple(data_types::f32, format::bfyx), std::make_tuple(data_types::f16, format::bfyx), std::make_tuple(data_types::i32, format::bfyx), + std::make_tuple(data_types::i8, format::bfyx), + std::make_tuple(data_types::u8, format::bfyx), + std::make_tuple(data_types::f32, format::bfzyx), std::make_tuple(data_types::f16, format::bfzyx), std::make_tuple(data_types::i32, format::bfzyx), + std::make_tuple(data_types::i8, format::bfzyx), + std::make_tuple(data_types::u8, format::bfzyx), + std::make_tuple(data_types::f32, format::bfwzyx), std::make_tuple(data_types::f16, format::bfwzyx), std::make_tuple(data_types::i32, format::bfwzyx), + std::make_tuple(data_types::i8, format::bfwzyx), + std::make_tuple(data_types::u8, format::bfwzyx), + + std::make_tuple(data_types::f32, format::b_fs_yx_fsv4), + std::make_tuple(data_types::f16, format::b_fs_yx_fsv4), + std::make_tuple(data_types::i32, format::b_fs_yx_fsv4), + std::make_tuple(data_types::i8, format::b_fs_yx_fsv4), + std::make_tuple(data_types::u8, format::b_fs_yx_fsv4), + + std::make_tuple(data_types::f32, format::b_fs_yx_fsv16), + std::make_tuple(data_types::f16, format::b_fs_yx_fsv16), + std::make_tuple(data_types::i32, format::b_fs_yx_fsv16), + std::make_tuple(data_types::i8, format::b_fs_yx_fsv16), + std::make_tuple(data_types::u8, format::b_fs_yx_fsv16), + + std::make_tuple(data_types::f32, format::b_fs_yx_fsv32), + std::make_tuple(data_types::f16, format::b_fs_yx_fsv32), + std::make_tuple(data_types::i32, format::b_fs_yx_fsv32), + std::make_tuple(data_types::i8, format::b_fs_yx_fsv32), + std::make_tuple(data_types::u8, format::b_fs_yx_fsv32), + + std::make_tuple(data_types::f32, format::b_fs_zyx_fsv16), + std::make_tuple(data_types::f16, format::b_fs_zyx_fsv16), + std::make_tuple(data_types::i32, format::b_fs_zyx_fsv16), + std::make_tuple(data_types::i8, format::b_fs_zyx_fsv16), + std::make_tuple(data_types::u8, format::b_fs_zyx_fsv16), + + std::make_tuple(data_types::f32, format::b_fs_zyx_fsv32), + std::make_tuple(data_types::f16, 
format::b_fs_zyx_fsv32), + std::make_tuple(data_types::i32, format::b_fs_zyx_fsv32), + std::make_tuple(data_types::i8, format::b_fs_zyx_fsv32), + std::make_tuple(data_types::u8, format::b_fs_zyx_fsv32), + + std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2), + std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2), + std::make_tuple(data_types::i32, format::bs_fs_yx_bsv4_fsv2), + std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv2), + std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2), + + std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4), + std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4), + std::make_tuple(data_types::i32, format::bs_fs_yx_bsv4_fsv4), + std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4), + std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4), + + std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv2), + std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv2), + std::make_tuple(data_types::i32, format::bs_fs_yx_bsv8_fsv2), + std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv2), + std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv2), + + std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4), + std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4), + std::make_tuple(data_types::i32, format::bs_fs_yx_bsv8_fsv4), + std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4), + std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4), + + std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16), + std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16), + std::make_tuple(data_types::i32, format::bs_fs_yx_bsv16_fsv16), + std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16), + std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16), + + std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16), + std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16), + std::make_tuple(data_types::i32, format::bs_fs_yx_bsv32_fsv16), + std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16), + std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16), + + std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32), + std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32), + std::make_tuple(data_types::i32, format::bs_fs_yx_bsv32_fsv32), + std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32), + std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32), }); } diff --git a/src/plugins/intel_gpu/src/graph/program.cpp b/src/plugins/intel_gpu/src/graph/program.cpp index e69e4c3b02e..9bd9075c325 100644 --- a/src/plugins/intel_gpu/src/graph/program.cpp +++ b/src/plugins/intel_gpu/src/graph/program.cpp @@ -1397,7 +1397,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) { prim.type() != cldnn::region_yolo::type_id() && prim.type() != cldnn::normalize::type_id() && prim.type() != cldnn::mvn::type_id() && - prim.type() != cldnn::gather::type_id()) { + prim.type() != cldnn::gather::type_id() && + prim.type() != cldnn::scatter_nd_update::type_id()) { can_use_fsv16 = false; } @@ -1423,6 +1424,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) { prim.type() != cldnn::softmax::type_id() && prim.type() != cldnn::fully_connected::type_id() && prim.type() != cldnn::generic_layer::type_id() && + prim.type() != cldnn::scatter_nd_update::type_id() && prim.type() != cldnn::quantize::type_id()) can_use_bs_fs_yx_bsv16_fsv16 = false; } diff --git 
a/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.cpp index 8f2130bb6dc..88dfa626e5f 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.cpp @@ -14,17 +14,15 @@ ParamsKey ScatterNDUpdateKernelRef::GetSupportedKey() const { k.EnableInputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); k.EnableInputDataType(Datatype::INT32); + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); k.EnableOutputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::INT32); k.EnableOutputDataType(Datatype::INT8); k.EnableOutputDataType(Datatype::UINT8); - k.EnableInputLayout(DataLayout::bfyx); - k.EnableOutputLayout(DataLayout::bfyx); - k.EnableInputLayout(DataLayout::bfzyx); - k.EnableOutputLayout(DataLayout::bfzyx); - k.EnableInputLayout(DataLayout::bfwzyx); - k.EnableOutputLayout(DataLayout::bfwzyx); + k.EnableAllInputLayout(); + k.EnableAllOutputLayout(); k.EnableTensorOffset(); k.EnableTensorPitches(); k.EnableBatching(); @@ -115,14 +113,10 @@ bool ScatterNDUpdateKernelRef::Validate(const Params& p, const optional_params& return true; } -static std::string GetInputBlockND(const scatter_nd_update_params& params) { - const auto& input = params.inputs[0]; +static std::string GetInputBlockND(const scatter_nd_update_params& params, int num, const int rank) { + const auto& input = params.inputs[num]; auto input_dims = input.LogicalDims(); std::reverse(input_dims.begin(), input_dims.end()); - while (!input_dims.empty() && input_dims.back() == 1) { - input_dims.pop_back(); - } - const int rank = static_cast(input_dims.size()); std::vector block_nd(rank + 1); block_nd[rank] = 1; for (int idx = (rank - 1); idx >= 0; idx--) { @@ -157,9 +151,14 @@ KernelsData ScatterNDUpdateKernelRef::GetKernelsData(const Params& params, const auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options, i); if (i == 1) { + int input0_rank = static_cast(newParams.inputs[0].LogicalDims().size()); + int input2_rank = static_cast(newParams.inputs[2].LogicalDims().size()); cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true")); cldnn_jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", dispatchData.indicesLastDim)); - cldnn_jit.AddConstant(MakeJitConstant("INPUT_BLOCK_ND", GetInputBlockND(newParams))); + cldnn_jit.AddConstant(MakeJitConstant("INPUT0_BLOCK_ND", GetInputBlockND(newParams, 0, input0_rank))); + cldnn_jit.AddConstant(MakeJitConstant("INPUT1_BLOCK_ND", GetInputBlockND(newParams, 1, newParams.indices_rank - 1))); + cldnn_jit.AddConstant(MakeJitConstant("INPUT2_BLOCK_ND", GetInputBlockND(newParams, 2, input2_rank))); + cldnn_jit.AddConstant(MakeJitConstant("INDICES_RANK", newParams.indices_rank)); } std::pair jit = CreateJit(kernelName, cldnn_jit, entry_point); diff --git a/src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl index 8cc60fa3cb8..bfe66e80303 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl @@ -1,3 +1,4 @@ + // 
Copyright (C) 2018-2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -16,6 +17,24 @@ #define ORDER b,f,w,z,y,x #endif +#if INPUT2_DIMS == 4 + #define UPD_ORDER upd_b,upd_f,upd_y,upd_x +#elif INPUT2_DIMS == 5 + #define UPD_ORDER upd_b,upd_f,upd_z,upd_y,upd_x +#elif INPUT2_DIMS == 6 + #define UPD_ORDER upd_b,upd_f,upd_w,upd_z,upd_y,upd_x +#endif + +#if INPUT1_DIMS == 4 + #define IDX_ORDER idx_b,idx_f,idx_y,idx_x +#elif INPUT1_DIMS == 5 + #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x +#elif INPUT1_DIMS == 6 + #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x +#endif + +#define INDICES_MAX_DIM 6 + KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data, const __global INPUT1_TYPE* indices, const __global INPUT2_TYPE* updates, @@ -49,81 +68,122 @@ KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data, #else // Second kernel - const uint blockND[] = {INPUT_BLOCK_ND}; - const uint k = INDICES_LAST_DIM; - const uint size_to_update = blockND[INDICES_LAST_DIM]; - const uint indices_idx = dim2; - const uint indices_offset = indices_idx * k; - uint dst_offset = 0; - - for (uint i = 0; i < k; i++) { - INPUT1_TYPE idxValue = indices[indices_offset + i]; - dst_offset += idxValue * blockND[i + 1]; - } - - uint update_offset = indices_idx * size_to_update; - - for (int i = 0; i < size_to_update; i++) { - uint dst_idx = dst_offset + i; - uint up_idx = update_offset + i; - INPUT2_TYPE val = updates[up_idx]; - - #if HAS_FUSED_OPS - #if OUTPUT_DIMS == 4 - const uint y_pitch = OUTPUT_SIZE_X; - const uint f_pitch = y_pitch * OUTPUT_SIZE_Y; - const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM; - - const uint b_remain = dst_idx % b_pitch; - const uint f_remain = b_remain % f_pitch; - const uint y_remain = f_remain % y_pitch; - - const uint b = dst_idx / b_pitch; - const uint f = b_remain / f_pitch; - const uint y = f_remain / y_pitch; - const uint x = y_remain; - #elif OUTPUT_DIMS == 5 - const uint y_pitch = OUTPUT_SIZE_X; - const uint z_pitch = y_pitch * OUTPUT_SIZE_Y; - const uint f_pitch = z_pitch * OUTPUT_SIZE_Z; - const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM; - - const uint b_remain = dst_idx % b_pitch; - const uint f_remain = b_remain % f_pitch; - const uint z_remain = f_remain % z_pitch; - const uint y_remain = z_remain % y_pitch; - - const uint b = dst_idx / b_pitch; - const uint f = b_remain / f_pitch; - const uint z = f_remain / z_pitch; - const uint y = z_remain / y_pitch; - const uint x = y_remain; - #elif OUTPUT_DIMS == 6 - const uint y_pitch = OUTPUT_SIZE_X; - const uint z_pitch = y_pitch * OUTPUT_SIZE_Y; - const uint w_pitch = z_pitch * OUTPUT_SIZE_Z; - const uint f_pitch = w_pitch * OUTPUT_SIZE_W; - const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM; - - const uint b_remain = dst_idx % b_pitch; - const uint f_remain = b_remain % f_pitch; - const uint w_remain = f_remain % w_pitch; - const uint z_remain = w_remain % z_pitch; - const uint y_remain = z_remain % y_pitch; - - const uint b = dst_idx / b_pitch; - const uint f = b_remain / f_pitch; - const uint w = f_remain / w_pitch; - const uint z = w_remain / z_pitch; - const uint y = z_remain / y_pitch; - const uint x = y_remain; - #endif - - FUSED_OPS_SECOND_KERNEL; - output[dst_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL); - #else - output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS); + const uint dataND[] = {INPUT0_BLOCK_ND}; + const uint updatesND[] = {INPUT2_BLOCK_ND}; + const uint indicesND[] = {INPUT1_BLOCK_ND}; + const uint size_to_update = dataND[INDICES_LAST_DIM]; + + #if INPUT1_DIMS == 4 + 
const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X}; + #elif INPUT1_DIMS == 5 + const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; + #elif INPUT1_DIMS == 6 + const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X}; #endif + + #if INPUT0_DIMS == 4 + const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X}; + #elif INPUT0_DIMS == 5 + const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X}; + #elif INPUT0_DIMS == 6 + const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X}; + #endif + + // Get indices index + uint idx[INDICES_MAX_DIM] = {0}; + uint rmd_idx = dim2; + for (int i = 0; i < INDICES_RANK - 1; ++i) { + idx[i] = rmd_idx / indicesND[1 + i]; + rmd_idx %= indicesND[1 + i]; + } + + uint out[INDICES_MAX_DIM] = {0}; + for (int i = 0; i < indices_dim[INDICES_RANK - 1]; ++i) { + idx[INDICES_RANK - 1] = i; + const uint idx_b = idx[0]; + const uint idx_f = idx[1]; + #if INPUT1_DIMS == 4 + const uint idx_y = idx[2]; + const uint idx_x = idx[3]; + #elif INPUT1_DIMS == 5 + const uint idx_z = idx[2]; + const uint idx_y = idx[3]; + const uint idx_x = idx[4]; + #elif INPUT1_DIMS == 6 + const uint idx_w = idx[2]; + const uint idx_z = idx[3]; + const uint idx_y = idx[4]; + const uint idx_x = idx[5]; + #endif + uint index = GET_UPDATES_INDEX(INPUT1, IDX_ORDER); + out[i] = indices[index]; + + // Check if tensor size is valid + // ex) when data format = bfyx and data shape = { 3, 3, 4, 1 }, indices shape is { 2, 1 } with rank = 2, indices values are { 1.0, 4.0 }, + // the second indices value is invalid as data shape has 'b' of size 3, and therefore 4 cannot be a correct index of data + // If indices value is invalid, saturate value to max valid value (ex. 
4.0 -> 2.0) + if (out[i] >= data_dim[i]) + out[i] = data_dim[i] - 1; + } + + for (int i = 0; i < size_to_update; ++i) { + // Define updates index + uint upd[INDICES_MAX_DIM] = {0}; + for (int j = 0; j < INDICES_RANK - 1; ++j) { + upd[j] = idx[j]; + } + uint data_rmd = i, updates_rmd = i; + for (int j = indices_dim[INDICES_RANK - 1]; j < INPUT0_DIMS; ++j) { + out[j] = data_rmd / dataND[j + 1]; + data_rmd %= dataND[j + 1]; + } + for (int k = INDICES_RANK - 1; k < INPUT2_DIMS; ++k) { + upd[k] = updates_rmd / updatesND[k + 1]; + updates_rmd %= updatesND[k + 1]; + } + // Get update index + const uint upd_b = upd[0]; + const uint upd_f = upd[1]; + #if INPUT2_DIMS == 4 + const uint upd_y = upd[2]; + const uint upd_x = upd[3]; + #elif INPUT2_DIMS == 5 + const uint upd_z = upd[2]; + const uint upd_y = upd[3]; + const uint upd_x = upd[4]; + #elif INPUT2_DIMS == 6 + const uint upd_w = upd[2]; + const uint upd_z = upd[3]; + const uint upd_y = upd[4]; + const uint upd_x = upd[5]; + #endif + uint upd_idx = GET_UPDATES_INDEX(INPUT2, UPD_ORDER); + + // Get output index + const uint b = out[0]; + const uint f = out[1]; + #if INPUT0_DIMS == 4 + const uint y = out[2]; + const uint x = out[3]; + #elif INPUT0_DIMS == 5 + const uint z = out[2]; + const uint y = out[3]; + const uint x = out[4]; + #elif INPUT0_DIMS == 6 + const uint w = out[2]; + const uint z = out[3]; + const uint y = out[4]; + const uint x = out[5]; + #endif + uint out_idx = GET_OUTPUT_INDEX(ORDER); + INPUT2_TYPE val = updates[upd_idx]; + + #if HAS_FUSED_OPS + FUSED_OPS_SECOND_KERNEL; + output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL); + #else + output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS); + #endif } #endif @@ -140,3 +200,15 @@ KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data, #ifdef ORDER #undef ORDER #endif + +#ifdef UPD_ORDER +#undef UPD_ORDER +#endif + +#ifdef IDX_ORDER +#undef IDX_ORDER +#endif + +#ifdef INDICES_MAX_DIM +#undef INDICES_MAX_DIM +#endif diff --git a/src/plugins/intel_gpu/src/plugin/ops/scatter_nd_update.cpp b/src/plugins/intel_gpu/src/plugin/ops/scatter_nd_update.cpp index 50f5ae5379d..0934e2f7891 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/scatter_nd_update.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/scatter_nd_update.cpp @@ -20,6 +20,21 @@ static void CreateScatterNDUpdateOp(Program& p, const std::shared_ptrget_input_shape(1).size(); + auto indices_constant = std::dynamic_pointer_cast(op->get_input_node_shared_ptr(1)); + if (indices_constant) { + auto indices = indices_constant->cast_vector(); + auto indices_last_dim = op->get_input_shape(1)[indices_rank - 1]; + auto data_shape = op->get_input_shape(0); + bool valid = true; + for (int i = 0; i < indices.size(); ++i) { + if (indices[i] >= data_shape[i % indices_last_dim]) + valid = false; + } + + if (!valid) + IE_THROW() << "Invalid indices values"; + } + auto primitive = cldnn::scatter_nd_update(layerName, inputPrimitives[0], inputPrimitives[1], diff --git a/src/plugins/intel_gpu/tests/CMakeLists.txt b/src/plugins/intel_gpu/tests/CMakeLists.txt index b15c123e06a..6fad1deda56 100644 --- a/src/plugins/intel_gpu/tests/CMakeLists.txt +++ b/src/plugins/intel_gpu/tests/CMakeLists.txt @@ -47,7 +47,8 @@ target_link_libraries(${TARGET_NAME} PRIVATE openvino_intel_gpu_graph target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/test_utils/ $ - $) + $ + ${CMAKE_HOME_DIRECTORY}/src/core/reference/include/) if(WIN32) target_link_libraries(${TARGET_NAME} PRIVATE setupapi) elseif((NOT 
ANDROID) AND (UNIX)) diff --git a/src/plugins/intel_gpu/tests/fusions/scatter_nd_update_fusion_test.cpp b/src/plugins/intel_gpu/tests/fusions/scatter_nd_update_fusion_test.cpp index 59c45af5f58..45f6d3f0aa3 100644 --- a/src/plugins/intel_gpu/tests/fusions/scatter_nd_update_fusion_test.cpp +++ b/src/plugins/intel_gpu/tests/fusions/scatter_nd_update_fusion_test.cpp @@ -12,6 +12,9 @@ #include #include +#include +#include +#include using namespace cldnn; using namespace ::tests; @@ -82,16 +85,16 @@ public: std::set> unique_indices; std::vector result; auto indices_shape = p.indices_shape.sizes(get_default_format(p.indices_rank)); - auto last_indices_dim = indices_shape.back(); + auto data_shape = p.input_shape.sizes(p.input_format); + auto last_indices_dim = indices_shape.at(p.indices_rank - 1); - auto count = 1; - for (size_t i = 0; i < indices_shape.size() - 1; i++) - count *= indices_shape[i]; + auto count = p.indices_shape.count() / last_indices_dim; while (unique_indices.size() != count) { std::vector indices; - for (size_t i = 0; i < last_indices_dim; i++) - indices.push_back(generate_random_val(0, indices_shape[i])); + for (size_t i = 0; i < last_indices_dim; i++) { + indices.push_back(static_cast(generate_random_val(0, data_shape[i] - 1))); + } unique_indices.insert(indices); } @@ -173,6 +176,55 @@ public: #define CASE_SCATTER_ND_UPDATE_FP32_6D_5 { 6, 7, 8, 9, 2, 2 }, { 5, 5, 1, 1 }, { 5, 8, 1, 1, 1, 1 }, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx #define CASE_SCATTER_ND_UPDATE_FP32_6D_6 { 6, 7, 8, 9, 2, 2 }, { 5, 6, 1, 1 }, { 5, 1, 1, 1, 1, 1 }, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f16, format::b_fs_yx_fsv16, data_types::f16, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1 { 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 }, 1, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 8, 9 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3 { 6, 7, 8, 9, 10 }, { 5, 3, 1, 1 }, { 5, 9, 1, 1, 8 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4 { 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 }, 2, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5 { 6, 7, 8, 9, 10 }, { 5, 5, 1, 1 }, { 5, 1, 1, 1, 1 }, 2, data_types::f16, format::b_fs_zyx_fsv16, 
data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, 3, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 }, 3, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8 { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, 4, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9 { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 }, 4, data_types::f16, format::b_fs_zyx_fsv16, data_types::f16, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f32, format::b_fs_yx_fsv16, data_types::f32, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1 { 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 }, 1, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 8, 9 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3 { 6, 7, 8, 9, 10 }, { 5, 3, 1, 1 }, { 5, 9, 1, 1, 8 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4 { 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5 { 6, 7, 8, 9, 10 }, { 5, 5, 1, 1 }, { 5, 1, 1, 1, 1 }, 2, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, 3, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7 { 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 }, 3, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8 { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, 4, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9 { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 }, 4, data_types::f32, format::b_fs_zyx_fsv16, data_types::f32, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx +#define 
CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f16, format::bs_fs_yx_bsv32_fsv16, data_types::f16, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1 { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, 1, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2 { 6, 6, 1, 1 }, { 3, 2, 1, 1 }, { 3, 1, 1, 1 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4 { 6, 7, 8, 9 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5 { 6, 7, 8, 9 }, { 6, 2, 1, 1 }, { 6, 9, 1, 8 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6 { 6, 7, 8, 9 }, { 6, 3, 1, 1 }, { 6, 8, 1, 1 }, 2, data_types::f32, format::bs_fs_yx_bsv32_fsv16, data_types::f32, format::bfyx + + class scatter_nd_update_quantize : public ScatterNDUpdatePrimitiveFusingTest {}; TEST_P(scatter_nd_update_quantize, basic) { auto p = GetParam(); @@ -235,6 +287,54 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, scatter_nd_update_quantize, ::testing::Val scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 3 }, scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 3 }, scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1, 2, 3 }, // FP16 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7, 2, 3 }, + 
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1, 2, 3 }, // FP32 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1, 2, 3 }, // FP16 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1, 2, 3 }, // FP32 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6, 2, 3 }, })); class scatter_nd_update_scale_activation_eltwise : public ScatterNDUpdatePrimitiveFusingTest {}; @@ -298,4 +398,52 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 5 }, scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 }, scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_1, 2, 5 }, // FP16 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_4D_6, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_2, 2, 5 }, + 
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_6, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_7, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_8, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_FSV16_5D_9, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_1, 2, 5 }, // FP32 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_4D_6, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_6, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_7, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_8, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_FSV16_5D_9, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_1, 2, 5 }, // FP16 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_BSV32_FSV16_4D_6, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_1, 2, 5 }, // FP32 + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_BSV32_FSV16_4D_6, 2, 5 }, })); diff --git a/src/plugins/intel_gpu/tests/test_cases/scatter_nd_update_gpu_test.cpp b/src/plugins/intel_gpu/tests/test_cases/scatter_nd_update_gpu_test.cpp index 23b869a27ec..27224974750 100644 --- a/src/plugins/intel_gpu/tests/test_cases/scatter_nd_update_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/test_cases/scatter_nd_update_gpu_test.cpp @@ -3,6 +3,7 @@ // #include "test_utils.h" +#include "ngraph/runtime/reference/scatter_nd_update.hpp" #include #include @@ -12,17 +13,489 @@ #include #include +#include +#include +#include +#include using namespace cldnn; using namespace ::tests; +struct scatter_nd_update_basic_test_params +{ + data_types input_type; + data_types indices_type; + 
data_types updates_type; + format input_format; + format indices_format; + format updates_format; + format input_result_format; + format indices_result_format; + format updates_result_format; + tensor input_size; + tensor indices_size; + tensor updates_size; + int indices_rank; +}; + +struct scatter_nd_update_random_test : testing::TestWithParam +{ + format get_default_format(int rank = 4) { + if (rank <= 4) + return cldnn::format::bfyx; + else if (rank == 5) + return cldnn::format::bfzyx; + else + return cldnn::format::bfwzyx; + } + + template + T generate_random_val(int min, int max, int k = 8) { + static std::default_random_engine generator(random_seed); + // 1/k is the resolution of the floating point numbers + std::uniform_int_distribution distribution(k * min, k * max); + T val = (T)distribution(generator); + val /= k; + + return val; + } + + template + std::vector generate_unique_indices(const scatter_nd_update_basic_test_params& p) { + std::set> unique_indices; + std::vector result; + auto indices_shape = p.indices_size.sizes(get_default_format(p.indices_rank)); + auto data_shape = p.input_size.sizes(p.input_format); + auto last_indices_dim = indices_shape.at(p.indices_rank - 1); + + auto count = p.indices_size.count() / last_indices_dim; + + while (unique_indices.size() != count) { + std::vector indices; + for (size_t i = 0; i < last_indices_dim; i++) { + indices.push_back(static_cast(generate_random_val(0, data_shape[i] - 1))); + } + + unique_indices.insert(indices); + } + + std::for_each(unique_indices.begin(), + unique_indices.end(), + [&](const std::vector& indices) { + result.insert(result.end(), indices.begin(), indices.end()); + }); + + return result; + } + + template + void execute_fp16(const scatter_nd_update_basic_test_params& params) + { + auto& engine = get_test_engine(); + + auto input1 = engine.allocate_memory({ params.input_type, params.input_format, params.input_size }); + auto input2 = engine.allocate_memory({ params.indices_type, params.indices_format, params.indices_size }); + auto input3 = engine.allocate_memory({ params.updates_type, params.updates_format, params.updates_size }); + + std::vector input_vec(static_cast(cldnn::format::dimension(params.input_format))); + for (int i = 0; i < input_vec.size(); ++i) + input_vec[i] = static_cast(params.input_size.sizes()[i]); + std::reverse(input_vec.begin() + 2, input_vec.end()); + + std::vector updates_vec(static_cast(cldnn::format::dimension(params.updates_format))); + for (int i = 0; i < updates_vec.size(); ++i) + updates_vec[i] = static_cast(params.updates_size.sizes()[i]); + std::reverse(updates_vec.begin() + 2, updates_vec.end()); + + std::vector indices_vec(static_cast(cldnn::format::dimension(params.indices_format))); + for (size_t i = 0; i < indices_vec.size(); ++i) + indices_vec[i] = static_cast(params.indices_size.sizes()[i]); + std::reverse(indices_vec.begin() + 2, indices_vec.end()); + indices_vec.resize(params.indices_rank); + + auto input_data_fp16 = generate_random_1d(params.input_size.count(), -127, 127); + auto indices_data_fp16 = generate_unique_indices(params); + auto updates_data_fp16 = generate_random_1d(params.updates_size.count(), -127, 127); + + std::vector input_data(params.input_size.count()); + for (int i = 0; i < params.input_size.count(); ++i) + input_data[i] = static_cast(input_data_fp16[i]); + std::vector indices_data(params.indices_size.count()); + for (int i = 0; i < params.indices_size.count(); ++i) + indices_data[i] = static_cast(indices_data_fp16[i]); + std::vector 
updates_data(params.updates_size.count()); + for (int i = 0; i < params.updates_size.count(); ++i) + updates_data[i] = static_cast(updates_data_fp16[i]); + + set_values(input1, input_data_fp16); + set_values(input2, indices_data_fp16); + set_values(input3, updates_data_fp16); + + // execute scatter_nd_update + topology topology( + input_layout("InputData", input1->get_layout()), + input_layout("InputIndices", input2->get_layout()), + input_layout("InputUpdates", input3->get_layout()), + reorder("reorder1", "InputData", params.input_result_format, params.input_type), + reorder("reorder2", "InputIndices", params.indices_result_format, params.indices_type), + reorder("reorder3", "InputUpdates", params.updates_result_format, params.updates_type), + scatter_nd_update("scatter_nd_update", "reorder1", "reorder2", "reorder3", params.indices_rank), + reorder("out", "scatter_nd_update", params.input_format, params.input_type) + ); + + network network(engine, topology); + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + auto output = outputs.at("out").get_memory(); + cldnn::mem_lock outputs_ptr(output, get_test_stream()); + + auto outputs_ref = std::vector(params.input_size.count()); + ngraph::runtime::reference::scatterNdUpdate(input_data.data(), + indices_data.data(), + updates_data.data(), + outputs_ref.data(), + ov::Shape(input_vec.begin(), input_vec.end()), + ov::Shape(indices_vec.begin(), indices_vec.end()), + ov::Shape(updates_vec.begin(), updates_vec.end())); + + for (size_t i = 0; i < outputs_ref.size(); ++i) { + EXPECT_EQ(outputs_ref[i], float16_to_float32(outputs_ptr[i])); + } + } + + template + void execute(const scatter_nd_update_basic_test_params& params) + { + // create input, indices, updates using params + auto& engine = get_test_engine(); + + auto input1 = engine.allocate_memory({ params.input_type, params.input_format, params.input_size }); + auto input2 = engine.allocate_memory({ params.indices_type, params.indices_format, params.indices_size }); + auto input3 = engine.allocate_memory({ params.updates_type, params.updates_format, params.updates_size }); + + std::vector input_vec(static_cast(cldnn::format::dimension(params.input_format))); + for (int i = 0; i < input_vec.size(); ++i) + input_vec[i] = static_cast(params.input_size.sizes()[i]); + std::reverse(input_vec.begin() + 2, input_vec.end()); + + std::vector updates_vec(static_cast(cldnn::format::dimension(params.updates_format))); + for (int i = 0; i < updates_vec.size(); ++i) + updates_vec[i] = static_cast(params.updates_size.sizes()[i]); + std::reverse(updates_vec.begin() + 2, updates_vec.end()); + + std::vector indices_vec(static_cast(cldnn::format::dimension(params.indices_format))); + for (size_t i = 0; i < indices_vec.size(); ++i) + indices_vec[i] = static_cast(params.indices_size.sizes()[i]); + std::reverse(indices_vec.begin() + 2, indices_vec.end()); + indices_vec.resize(params.indices_rank); + + auto input_data = generate_random_1d(params.input_size.count(), -127, 127); + auto indices_data = generate_unique_indices(params); + auto updates_data = generate_random_1d(params.updates_size.count(), -127, 127); + + set_values(input1, input_data); + set_values(input2, indices_data); + set_values(input3, updates_data); + + // execute scatter_nd_update + topology topology( + input_layout("InputData", input1->get_layout()), + input_layout("InputIndices", input2->get_layout()), + 
input_layout("InputUpdates", input3->get_layout()), + reorder("reorder1", "InputData", params.input_result_format, params.input_type), + reorder("reorder2", "InputIndices", params.indices_result_format, params.indices_type), + reorder("reorder3", "InputUpdates", params.updates_result_format, params.updates_type), + scatter_nd_update("scatter_nd_update", "reorder1", "reorder2", "reorder3", params.indices_rank), + reorder("out", "scatter_nd_update", params.input_format, params.input_type) + ); + + network network(engine, topology); + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + auto output = outputs.at("out").get_memory(); + cldnn::mem_lock outputs_ptr(output, get_test_stream()); + + auto outputs_ref = std::vector(params.input_size.count()); + ngraph::runtime::reference::scatterNdUpdate(input_data.data(), + indices_data.data(), + updates_data.data(), + outputs_ref.data(), + ov::Shape(input_vec.begin(), input_vec.end()), + ov::Shape(indices_vec.begin(), indices_vec.end()), + ov::Shape(updates_vec.begin(), updates_vec.end())); + + for (size_t i = 0; i < outputs_ref.size(); ++i) { + EXPECT_EQ(outputs_ref[i], outputs_ptr[i]); + } + } +}; + +TEST_P(scatter_nd_update_random_test, random) +{ + auto param = GetParam(); + if (param.input_type == data_types::u8) + this->execute(param); + else if (param.input_type == data_types::i8) + this->execute(param); + else if (param.input_type == data_types::i32) + this->execute(param); + else if (param.input_type == data_types::i64) + this->execute(param); + else if (param.input_type == data_types::f16) + this->execute_fp16(param); + else if (param.input_type == data_types::f32) + this->execute(param); + else + IE_THROW() << "unidentified data type"; +} + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_bsv32_fsv16_4d_rank_1, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f32, data_types::f32, data_types::f32, + format::bfyx, format::bfyx, format::bfyx, + format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, + { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, + 1 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_bsv32_fsv16_4d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f32, data_types::f32, data_types::f32, + format::bfyx, format::bfyx, format::bfyx, + format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, + { 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_4d_rank_1, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f32, data_types::f32, data_types::f32, + format::bfyx, format::bfyx, format::bfyx, + format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, + { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, + 1 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_4d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f32, data_types::f32, data_types::f32, + format::bfyx, format::bfyx, format::bfyx, + format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, + { 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_2, + 
scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f32, data_types::f32, data_types::f32, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 3, 3, 10 }, { 5, 2, 1, 1 }, { 5, 10, 1, 3, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_3, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f32, data_types::f32, data_types::f32, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, + 3 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp32_fsv16_5d_rank_4, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f32, data_types::f32, data_types::f32, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, + 4 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_4d_rank_1, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfyx, format::bfyx, format::bfyx, + format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, + { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, + 1 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_4d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfyx, format::bfyx, format::bfyx, + format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, + { 48, 24, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_1, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 8, 9, 10 }, { 5, 1, 1, 1 }, { 5, 7, 8, 9, 10 }, + 1 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 8, 9, 10 }, { 5, 4, 1, 1 }, { 5, 8, 1, 1, 1 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_3, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 8, 9, 10 }, { 5, 2, 1, 3 }, { 5, 2, 1, 8, 9 }, + 3 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_fsv16_5d_rank_4, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 8, 9, 10 }, { 5, 2, 4, 3 }, { 5, 2, 1, 8, 3 }, + 4 } + })); + 
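The suites above and below exercise blocked layouts such as b_fs_zyx_fsv16 and bs_fs_yx_bsv32_fsv16, which is why the reworked kernel resolves addresses through GET_OUTPUT_INDEX / GET_UPDATES_INDEX rather than the flat dst_offset arithmetic it replaced. As a rough illustration only (not taken from the clDNN sources; fsv16_offset and its arguments are hypothetical names), a feature-blocked layout like b_fs_yx_fsv16 can be pictured as tiling the feature axis in blocks of 16 with the in-block feature index innermost:

// Rough illustration (assumption, not code from this patch): physical offset of
// logical element (b, f, y, x) in a b_fs_yx_fsv16-like layout, where the feature
// axis is tiled in blocks of 16 and the in-block feature index varies fastest.
#include <cstddef>
#include <iostream>

static size_t fsv16_offset(size_t b, size_t f, size_t y, size_t x,
                           size_t F, size_t Y, size_t X) {
    const size_t fsv = 16;
    const size_t f_blocks = (F + fsv - 1) / fsv;  // feature axis padded up to a multiple of 16
    const size_t f_block = f / fsv;
    const size_t f_in = f % fsv;
    return (((b * f_blocks + f_block) * Y + y) * X + x) * fsv + f_in;
}

int main() {
    // For a 48x24x3x3 tensor (as in the rank-2 cases above), logical (1, 17, 2, 0)
    // falls into feature block 1, slot 1 of that block.
    std::cout << fsv16_offset(1, 17, 2, 0, 24, 3, 3) << '\n';
    return 0;
}

Under such a layout a flat row-major offset no longer matches the physical memory order, which is what the coordinate-based kernel rewrite and the reorder-wrapped test topologies account for.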
+INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_bsv32_fsv16_4d_rank_1, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfyx, format::bfyx, format::bfyx, + format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, + { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, + 1 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_fp16_bsv32_fsv16_4d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::f16, data_types::f16, data_types::f16, + format::bfyx, format::bfyx, format::bfyx, + format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, + { 48, 24, 3, 3 }, {3, 2, 1, 1 }, { 3, 3, 1, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_bsv32_fsv16_4d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::i8, data_types::i8, data_types::i8, + format::bfyx, format::bfyx, format::bfyx, + format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, format::bs_fs_yx_bsv32_fsv16, + { 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_bsv32_fsv32_4d_rank_1, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::i8, data_types::i8, data_types::i8, + format::bfyx, format::bfyx, format::bfyx, + format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32, + { 6, 1, 1, 1 }, { 3, 1, 1, 1 }, { 3, 1, 1, 1 }, + 1 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv32_4d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::i8, data_types::i8, data_types::i8, + format::bfyx, format::bfyx, format::bfyx, + format::b_fs_yx_fsv32, format::b_fs_yx_fsv32, format::b_fs_yx_fsv32, + { 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv32_5d_rank_3, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::i8, data_types::i8, data_types::i8, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv32, format::b_fs_yx_fsv32, format::b_fs_zyx_fsv32, + { 6, 7, 8, 9, 10 }, { 5, 2, 1, 2 }, { 5, 2, 8, 9, 10 }, + 3 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv16_4d_rank_2, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::i8, data_types::i8, data_types::i8, + format::bfyx, format::bfyx, format::bfyx, + format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, format::b_fs_yx_fsv16, + { 41, 23, 3, 3 }, { 3, 2, 1, 1 }, { 3, 3, 1, 3 }, + 2 } + })); + +INSTANTIATE_TEST_SUITE_P(scatter_nd_update_gpu_random_test_i8_fsv16_5d_rank_4, + scatter_nd_update_random_test, + testing::ValuesIn( + std::vector{ + { data_types::i8, data_types::i8, data_types::i8, + format::bfzyx, format::bfyx, format::bfzyx, + format::b_fs_zyx_fsv16, format::b_fs_yx_fsv16, format::b_fs_zyx_fsv16, + { 6, 7, 8, 9, 10 }, { 5, 2, 3, 3 }, { 5, 2, 8, 9, 3 }, + 4 } + })); + + TEST(scatter_nd_update_gpu_fp16_test15, data5_indice3_update5) { auto& engine = get_test_engine(); auto input1 = engine.allocate_memory({ data_types::f16, format::bfzyx, { 2, 2, 2, 4, 3 } }); // data auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 2, 1, 1 } }); // indices - auto input3 = engine.allocate_memory({ 
data_types::f16, format::bfzyx, { 1, 2, 2, 4, 3, 2 } }); // updates + auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 4, 3, 2 } }); // updates set_values(input1, { // 0 @@ -158,7 +631,7 @@ TEST(scatter_nd_update_gpu_fp16_test14, data5_indice2_update3) { FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), - FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), FLOAT16(75.0f), FLOAT16(76.0f), FLOAT16(77.0f), FLOAT16(78.0f), + FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), FLOAT16(75.0f), FLOAT16(76.0f), FLOAT16(77.0f), FLOAT16(78.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), // 1 @@ -166,9 +639,9 @@ TEST(scatter_nd_update_gpu_fp16_test14, data5_indice2_update3) { FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), - FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), FLOAT16(67.0f), FLOAT16(68.0f), + FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), - FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), }; topology topology; @@ -398,7 +871,6 @@ TEST(scatter_nd_update_gpu_fp16_test11, data6_indice1_update6) { FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), - }); std::vector expected_results = { @@ -511,7 +983,6 @@ TEST(scatter_nd_update_gpu_fp16_test10, data5_indice1_update5) { FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), - }); std::vector expected_results = { @@ -595,7 +1066,6 @@ TEST(scatter_nd_update_gpu_fp16_test9, data4_indice1_update4) { FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), - }); std::vector expected_results = { @@ -641,9 +1111,9 @@ 
TEST(scatter_nd_update_gpu_fp16_test9, data4_indice1_update4) { TEST(scatter_nd_update_gpu_fp16_test8, data6_indice2_update5) { auto& engine = get_test_engine(); - auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 3, 4, 2 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 1, 2, 2, 4, 3, 2 } }); // data auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // indices - auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 1, 3, 4, 2 } }); // updates + auto input3 = engine.allocate_memory({ data_types::f16, format::bfwzyx, { 2, 2, 1, 2, 4, 3 } }); // updates set_values(input1, { //0,0 @@ -808,7 +1278,7 @@ TEST(scatter_nd_update_gpu_fp16_test7, data5_indice2_update4) { TEST(scatter_nd_update_gpu_fp16_test6, data4_indice2_update3) { auto& engine = get_test_engine(); - auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 3, 4, 2 } }); // data + auto input1 = engine.allocate_memory({ data_types::f16, format::bfyx, { 2, 3, 2, 4 } }); // data auto input2 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices auto input3 = engine.allocate_memory({ data_types::f16, format::bfyx, { 3, 4, 1, 2 } }); // updates
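Beyond the explicit test cases, the revised kernels lean on two host-provided pieces of JIT data: the INPUTn_BLOCK_ND stride arrays emitted by GetInputBlockND and the saturation of out-of-range index values described in the kernel comment. The standalone C++ sketch below is illustrative only (block_nd, decompose, and the sample shape are hypothetical names, not part of the patch) and reproduces that arithmetic on the host side:

// Illustrative sketch only -- not part of the patch. It mirrors the stride logic
// of GetInputBlockND and the coordinate decomposition / index clamping done by
// the second CL kernel.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Row-major block strides, laid out like INPUTn_BLOCK_ND: block[i] is the number
// of elements covered by dimensions i..rank-1, and block[rank] == 1.
static std::vector<size_t> block_nd(const std::vector<size_t>& dims) {
    std::vector<size_t> block(dims.size() + 1, 1);
    for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i)
        block[i] = dims[i] * block[i + 1];
    return block;
}

// Decompose a flat offset into per-dimension coordinates with those strides,
// as the kernel does for the updates and output tensors.
static std::vector<size_t> decompose(size_t flat, const std::vector<size_t>& block) {
    std::vector<size_t> coord(block.size() - 1, 0);
    for (size_t i = 0; i < coord.size(); ++i) {
        coord[i] = flat / block[i + 1];
        flat %= block[i + 1];
    }
    return coord;
}

int main() {
    // Data shape from the example in the kernel comment: { 3, 3, 4, 1 }.
    const std::vector<size_t> data_shape{3, 3, 4, 1};
    const auto strides = block_nd(data_shape);   // -> { 36, 12, 4, 1, 1 }
    const auto coord = decompose(7, strides);    // flat 7 -> (b, f, y, x) = (0, 1, 3, 0)

    // An index value of 4 on the batch axis is out of range (valid range 0..2),
    // so it is saturated to 2, matching the kernel's clamping rule.
    size_t idx_b = 4;
    idx_b = std::min(idx_b, data_shape[0] - 1);

    std::cout << "strides:";
    for (auto s : strides) std::cout << ' ' << s;
    std::cout << "\ncoord: " << coord[0] << ' ' << coord[1] << ' ' << coord[2] << ' ' << coord[3]
              << "\nclamped batch index: " << idx_b << '\n';
    return 0;
}

For the { 3, 3, 4, 1 } example this prints strides 36 12 4 1 1, maps flat element 7 to (0, 1, 3, 0), and clamps the invalid batch index 4 to 2, the same saturation the kernel comment describes (4.0 -> 2.0).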