diff --git a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp index ff2aafa3f3c..6d897460c88 100644 --- a/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp +++ b/inference-engine/src/cldnn_engine/cldnn_primitives_list.hpp @@ -158,6 +158,7 @@ REGISTER_FACTORY(v3, EmbeddingSegmentsSum); REGISTER_FACTORY(v3, ExtractImagePatches); REGISTER_FACTORY(v3, ScatterUpdate); REGISTER_FACTORY(v3, ScatterElementsUpdate); +REGISTER_FACTORY(v3, ScatterNDUpdate); // REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion // ----------------------------- Unsupported v3 ops ----------------------------- // @@ -167,7 +168,6 @@ REGISTER_FACTORY(v3, ScatterElementsUpdate); // REGISTER_FACTORY(v3, NonZero); // REGISTER_FACTORY(v3, ROIAlign); // REGISTER_FACTORY(v3, ReadValue); -// REGISTER_FACTORY(v3, ScatterNDUpdate); // REGISTER_FACTORY(v3, ShapeOf); // REGISTER_FACTORY(v3, TopK); diff --git a/inference-engine/src/cldnn_engine/ops/scatter_nd_update.cpp b/inference-engine/src/cldnn_engine/ops/scatter_nd_update.cpp new file mode 100644 index 00000000000..cd3b06194df --- /dev/null +++ b/inference-engine/src/cldnn_engine/ops/scatter_nd_update.cpp @@ -0,0 +1,33 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "cldnn_program.h" +#include "cldnn_common_utils.h" + +#include "ngraph/op/scatter_nd_update.hpp" +#include "ngraph/op/constant.hpp" + +#include "api/scatter_nd_update.hpp" + +namespace CLDNNPlugin { + +void CreateScatterNDUpdateOp(Program& p, const std::shared_ptr& op) { + p.ValidateInputs(op, {3}); + auto inputPrimitives = p.GetInputPrimitiveIDs(op); + std::string layerName = layer_type_name_ID(op); + auto indices_rank = op->get_input_shape(1).size(); + + auto primitive = cldnn::scatter_nd_update(layerName, + inputPrimitives[0], + inputPrimitives[1], + inputPrimitives[2], + indices_rank); + + p.AddPrimitive(primitive); + 
p.AddPrimitiveToProfiler(op); +} + +REGISTER_FACTORY_IMPL(v3, ScatterNDUpdate); + +} // namespace CLDNNPlugin diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/scatter_nd_update.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/scatter_nd_update.cpp new file mode 100644 index 00000000000..5a1b6cf3e5c --- /dev/null +++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/single_layer_tests/scatter_nd_update.cpp @@ -0,0 +1,53 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "single_layer_tests/scatter_ND_update.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; +using namespace ngraph::opset3; + +namespace { + +// map> +// updateShape is gotten from inputShape and indicesShape +std::map, std::map, std::vector>> sliceSelectInShape{ + {{4, 3, 2, 3, 2}, {{{2, 2, 1}, {3, 2, 0, 1}}}}, + {{10, 9, 9, 11}, {{{4, 1}, {1, 3, 5, 7}}, {{1, 2}, {4, 6}}, {{2, 3}, {0, 1, 1, 2, 2, 2}}, {{1, 4}, {5, 5, 4, 9}}}}, + {{10, 9, 12, 10, 11}, {{{2, 2, 1}, {5, 6, 2, 8}}, {{2, 3}, {0, 4, 6, 5, 7, 1}}}}, + {{15}, {{{2, 1}, {1, 3}}}}, + {{15, 14}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}}}, + {{15, 14, 13}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}}}, + {{15, 14, 13, 12}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}}, + {{2, 2, 2}, {2, 3, 1, 8, 7, 5, 6, 5}}}}, + {{15, 14, 13, 12, 16}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}}, + {{2, 5}, {2, 3, 1, 8, 6, 9, 7, 5, 6, 5}}}}, + {{15, 14, 13, 12, 16, 10}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}}, + {{1, 2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}}, {{2, 5}, {2, 3, 1, 8, 6, 9, 7, 5, 6, 5}}, {{2, 6}, 
{2, 3, 1, 8, 6, 5, 9, 7, 5, 6, 5, 7}}}} +}; + + +const std::vector inputPrecisions = { + InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16, + InferenceEngine::Precision::I32, +}; + +const std::vector idxPrecisions = { + InferenceEngine::Precision::I32, + InferenceEngine::Precision::I64, +}; + +const auto ScatterNDUpdateCases = ::testing::Combine( + ::testing::ValuesIn(ScatterNDUpdateLayerTest::combineShapes(sliceSelectInShape)), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU) +); + +INSTANTIATE_TEST_CASE_P(smoke_ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName); +} // namespace diff --git a/inference-engine/thirdparty/clDNN/api/scatter_nd_update.hpp b/inference-engine/thirdparty/clDNN/api/scatter_nd_update.hpp new file mode 100644 index 00000000000..d418dffcbde --- /dev/null +++ b/inference-engine/thirdparty/clDNN/api/scatter_nd_update.hpp @@ -0,0 +1,54 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once +#include "primitive.hpp" + +namespace cldnn { +/// @addtogroup cpp_api C++ API +/// @{ +/// @addtogroup cpp_topology Network Topology +/// @{ +/// @addtogroup cpp_primitives Primitives +/// @{ + +/// @brief +/// @details +struct scatter_nd_update : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(scatter_nd_update) + + /// @brief Constructs scatter_nd_update primitive. + /// @param id This primitive id. + /// @param dict Input data primitive id. + /// @param idx Input indexes primitive id. + /// @param idupd Input updates primitive id. + /// @param indices_rank Rank of indices. + scatter_nd_update(const primitive_id& id, + const primitive_id& data, + const primitive_id& idx, + const primitive_id& idupd, + const size_t indices_rank, + const padding& output_padding = padding()) + : primitive_base(id, {data, idx, idupd}, output_padding), indices_rank(indices_rank) {} + + /// @brief ScatterNDUpdate indices_rank + size_t indices_rank; +}; +/// @} +/// @} +/// @} +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h index 20096bdcfcf..94c1aaffc91 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h @@ -58,6 +58,7 @@ enum class KernelType { ONE_HOT, GATHER, SCATTER_UPDATE, + SCATTER_ND_UPDATE, SCATTER_ELEMENTS_UPDATE, DEPTH_TO_SPACE, BATCH_TO_SPACE, diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.cpp new file mode 100644 index 00000000000..84a34f3c96c --- /dev/null +++ 
b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.cpp @@ -0,0 +1,192 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "scatter_nd_update_kernel_ref.h" +#include "kernel_selector_utils.h" +#include +#include + +namespace kernel_selector { + +ParamsKey ScatterNDUpdateKernelRef::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableInputLayout(DataLayout::bfzyx); + k.EnableOutputLayout(DataLayout::bfzyx); + k.EnableInputLayout(DataLayout::bfwzyx); + k.EnableOutputLayout(DataLayout::bfwzyx); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDifferentTypes(); + return k; +} + +static inline std::string GetOrderString(std::vector& order) { + std::string order_str = order[0]; + for (size_t i = 1; i < order.size(); i++) + order_str += ", " + order[i]; + + return order_str; +} + +static inline std::vector GetDefaultOrder(size_t size) { + std::vector default_order; + if (size <= 4) { + default_order = 
{"b", "f", "y", "x"}; + } else if (size == 5) { + default_order = {"b", "f", "z", "y", "x"}; + } else if (size == 6) { + default_order = {"b", "f", "w", "z", "y", "x"}; + } + + return default_order; +} + +ScatterNDUpdateKernelRef::DispatchData +ScatterNDUpdateKernelRef::SetDefault(const scatter_nd_update_params& params, const optional_params&, bool is_second) const { + DispatchData dispatchData; + + if (!is_second) { + const auto& scope = params.output; + dispatchData.gws = { scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v }; + } else { + auto indices_rank = params.indices_rank; + const auto& indices = params.inputs[1]; + auto indices_dims = indices.LogicalDims(); + + if (indices_dims.size() > 1) { + std::reverse(indices_dims.begin(), indices_dims.end()); + } + + dispatchData.indicesLastDim = indices_dims[indices_rank - 1]; + size_t indices_set_size = 1; + for (size_t i = 0; i < (indices_rank - 1); i++) { + indices_set_size *= indices_dims[i]; + } + + dispatchData.gws = {1, 1, indices_set_size}; + } + + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); + + return dispatchData; +} + +JitConstants ScatterNDUpdateKernelRef::GetJitConstants(const scatter_nd_update_params& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + if (!params.fused_ops.empty()) { + FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() }; + FusedOpsConfiguration conf2 = { "_SECOND_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() }; + jit.Merge(MakeFusedOpsJitConstants(params, {conf1, conf2})); + } + + return jit; +} + +bool ScatterNDUpdateKernelRef::Validate(const Params& p, const optional_params& o) const { + if (p.GetType() != KernelType:: SCATTER_ND_UPDATE || o.GetType() != KernelType::SCATTER_ND_UPDATE) { + return false; + } + + const scatter_nd_update_params& params 
= static_cast(p); + auto input_dims = params.inputs[0].LogicalDims(); + auto indices_dims = params.inputs[1].LogicalDims(); + std::reverse(indices_dims.begin(), indices_dims.end()); + + auto indices_rank = params.indices_rank; + if (indices_rank < 1) { + return false; + } + + if (indices_dims[indices_rank - 1] > input_dims.size()) { + return false; + } + + for (auto& fused_op : params.fused_ops) { + if (!IsFusedPrimitiveSupported(fused_op)) + return false; + } + + return true; +} + +static std::string GetInputBlockND(const scatter_nd_update_params& params) { + const auto& input = params.inputs[0]; + auto input_dims = input.LogicalDims(); + std::reverse(input_dims.begin(), input_dims.end()); + while (!input_dims.empty() && input_dims.back() == 1) { + input_dims.pop_back(); + } + const int rank = static_cast(input_dims.size()); + std::vector block_nd(rank + 1); + block_nd[rank] = 1; + for (int idx = (rank - 1); idx >= 0; idx--) { + block_nd[idx] = input_dims[idx] * block_nd[idx + 1]; + } + + std::stringstream s; + for (int i = 0; i < (rank + 1); i++) { + if (i < rank) { + s << block_nd[i] << ","; + } else { + s << block_nd[i]; + } + } + auto str_result = s.str(); + return str_result; +} + +KernelsData ScatterNDUpdateKernelRef::GetKernelsData(const Params& params, const optional_params& options) const { + if (!Validate(params, options)) { + return {}; + } + + KernelData kd = KernelData::Default(params, 2); + scatter_nd_update_params& newParams = *static_cast(kd.params.get()); + auto cldnn_jit = GetJitConstants(newParams); + + // First iter - copy input data to output data + // Second iter - update values specified by updates at specific index position specified by indices + for (int i = 0; i < 2; i++) { + auto dispatchData = SetDefault(newParams, options, (i == 1)); + auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options); + + if (i == 1) { + cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true")); + 
cldnn_jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", dispatchData.indicesLastDim)); + cldnn_jit.AddConstant(MakeJitConstant("INPUT_BLOCK_ND", GetInputBlockND(newParams))); + } + std::string jit = CreateJit(kernelName, cldnn_jit, entry_point); + + clKernelData& kernel = kd.kernels[i]; + + FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params)); + } + + return {kd}; +} +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.h new file mode 100644 index 00000000000..55b74dfd003 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_ref.h @@ -0,0 +1,62 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+*/ + +#pragma once + +#include "kernel_base_opencl.h" + +namespace kernel_selector { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// scatter_nd_update_params +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct scatter_nd_update_params : public base_params { + scatter_nd_update_params() : base_params(KernelType::SCATTER_ND_UPDATE), indices_rank(0) {} + + size_t indices_rank; + + virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// scatter_nd_update_optional_params +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct scatter_nd_update_optional_params : optional_params { + scatter_nd_update_optional_params() : optional_params(KernelType::SCATTER_ND_UPDATE) {} +}; + +class ScatterNDUpdateKernelRef : public KernelBaseOpenCL { +public: + struct DispatchData : public CommonDispatchData { + size_t indicesLastDim; + }; + + ScatterNDUpdateKernelRef() : KernelBaseOpenCL("scatter_nd_update_ref") {} + virtual ~ScatterNDUpdateKernelRef() {} + virtual JitConstants GetJitConstants(const scatter_nd_update_params& params) const; + virtual DispatchData SetDefault(const scatter_nd_update_params& params, const optional_params&, bool is_second) const; + KernelsData GetKernelsData(const Params& params, const optional_params& options) const override; + ParamsKey GetSupportedKey() const override; + std::vector GetSupportedFusedOps() const override { + return { FusedOpType::QUANTIZE, + FusedOpType::SCALE, + FusedOpType::ACTIVATION, + FusedOpType::ELTWISE }; + } + +protected: + bool Validate(const Params& p, const optional_params& o) const override; +}; +} // namespace kernel_selector diff --git 
a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_selector.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_selector.cpp new file mode 100644 index 00000000000..59affe7d7dd --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_selector.cpp @@ -0,0 +1,27 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+*/ + +#include "scatter_nd_update_kernel_selector.h" +#include "scatter_nd_update_kernel_ref.h" + +namespace kernel_selector { + +scatter_nd_update_kernel_selector::scatter_nd_update_kernel_selector() { Attach(); } + +KernelsData scatter_nd_update_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const { + return GetNaiveBestKernel(params, options, KernelType::SCATTER_ND_UPDATE); +} +} // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_selector.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_selector.h new file mode 100644 index 00000000000..0e09ca6c831 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/scatter_update/scatter_nd_update_kernel_selector.h @@ -0,0 +1,35 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
*/

#pragma once

#include "kernel_selector.h"

namespace kernel_selector {
// Kernel selector for the ScatterNDUpdate primitive; a process-wide singleton
// that picks among registered implementations (currently only the ref kernel).
class scatter_nd_update_kernel_selector : public kernel_selector_base {
public:
    // Meyers singleton: constructed on first use, shared by all callers.
    static scatter_nd_update_kernel_selector& Instance() {
        static scatter_nd_update_kernel_selector instance_;
        return instance_;
    }

    scatter_nd_update_kernel_selector();

    virtual ~scatter_nd_update_kernel_selector() {}

    KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
};
}  // namespace kernel_selector

diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl
new file mode 100644
index 00000000000..68f95a47f49
--- /dev/null
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "include/include_all.cl"

// Helper index macros (GET_UPDATES_INDEX is currently unused but kept for parity
// with the other scatter kernels).
#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)

#if OUTPUT_DIMS == 4
    #define ORDER b,f,y,x
#elif OUTPUT_DIMS == 5
    #define ORDER b,f,z,y,x
#elif OUTPUT_DIMS == 6
    #define ORDER b,f,w,z,y,x
#endif

// Two-pass kernel selected at JIT time: without IS_SECOND_ITER it copies the
// data input into the output element-wise; with IS_SECOND_ITER it scatters the
// update slices at the offsets addressed by the indices input.
KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
                              const __global INPUT1_TYPE* indices,
                              const __global INPUT2_TYPE* updates,
                              __global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
                              , FUSED_OPS_DECLS
#endif
)
{

    const uint dim0 = get_global_id(0);
    const uint dim1 = get_global_id(1);
    const uint dim2 = get_global_id(2);

#ifndef IS_SECOND_ITER // First kernel
    // gws was laid out as {x*y, z*w, f*b}; decompose each id back into coords.
    const uint x = dim0 % OUTPUT_SIZE_X;
    const uint y = dim0 / OUTPUT_SIZE_X;
    const uint z = dim1 % OUTPUT_SIZE_Z;
    const uint w = dim1 / OUTPUT_SIZE_Z;
    const uint f = dim2 % OUTPUT_FEATURE_NUM;
    const uint b = dim2 / OUTPUT_FEATURE_NUM;

    const uint output_idx = GET_OUTPUT_INDEX(ORDER);
    INPUT0_TYPE val = data[output_idx];
    #if HAS_FUSED_OPS
        FUSED_OPS_FIRST_KERNEL;
        output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_FIRST_KERNEL);
    #else
        output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
    #endif

#else // Second kernel

    // blockND[i] = product of data dims i..rank-1 (suffix products), injected
    // by the host as INPUT_BLOCK_ND; blockND[INDICES_LAST_DIM] is the number of
    // contiguous elements each index tuple addresses.
    const uint blockND[] = {INPUT_BLOCK_ND};
    const uint k = INDICES_LAST_DIM;
    const uint size_to_update = blockND[INDICES_LAST_DIM];
    const uint indices_idx = dim2;
    const uint indices_offset = indices_idx * k;
    uint dst_offset = 0;

    // Fold the k-coordinate index tuple into a flat destination offset.
    for (uint i = 0; i < k; i++) {
        INPUT1_TYPE idxValue = indices[indices_offset + i];
        dst_offset += idxValue * blockND[i + 1];
    }

    uint update_offset = indices_idx * size_to_update;

    // NOTE(review): loop counter is signed int compared against uint
    // size_to_update -- harmless for realistic sizes but worth confirming.
    for (int i = 0; i < size_to_update; i++) {
        uint dst_idx = dst_offset + i;
        uint up_idx = update_offset + i;
        INPUT2_TYPE val = updates[up_idx];

        #if HAS_FUSED_OPS
            // Fused ops need b/f/(w/z/)y/x coordinates; recover them from the
            // flat dst_idx via pitch arithmetic for the active rank.
            #if OUTPUT_DIMS == 4
                const uint y_pitch = OUTPUT_SIZE_X;
                const uint f_pitch = y_pitch * OUTPUT_SIZE_Y;
                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;

                const uint b_remain = dst_idx % b_pitch;
                const uint f_remain = b_remain % f_pitch;
                const uint y_remain = f_remain % y_pitch;

                const uint b = dst_idx / b_pitch;
                const uint f = b_remain / f_pitch;
                const uint y = f_remain / y_pitch;
                const uint x = y_remain;
            #elif OUTPUT_DIMS == 5
                const uint y_pitch = OUTPUT_SIZE_X;
                const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
                const uint f_pitch = z_pitch * OUTPUT_SIZE_Z;
                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;

                const uint b_remain = dst_idx % b_pitch;
                const uint f_remain = b_remain % f_pitch;
                const uint z_remain = f_remain % z_pitch;
                const uint y_remain = z_remain % y_pitch;

                const uint b = dst_idx / b_pitch;
                const uint f = b_remain / f_pitch;
                const uint z = f_remain / z_pitch;
                const uint y = z_remain / y_pitch;
                const uint x = y_remain;
            #elif OUTPUT_DIMS == 6
                const uint y_pitch = OUTPUT_SIZE_X;
                const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
                const uint w_pitch = z_pitch * OUTPUT_SIZE_Z;
                const uint f_pitch = w_pitch * OUTPUT_SIZE_W;
                const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;

                const uint b_remain = dst_idx % b_pitch;
                const uint f_remain = b_remain % f_pitch;
                const uint w_remain = f_remain % w_pitch;
                const uint z_remain = w_remain % z_pitch;
                const uint y_remain = z_remain % y_pitch;

                const uint b = dst_idx / b_pitch;
                const uint f = b_remain / f_pitch;
                const uint w = f_remain / w_pitch;
                const uint z = w_remain / z_pitch;
                const uint y = z_remain / y_pitch;
                const uint x = y_remain;
            #endif

            FUSED_OPS_SECOND_KERNEL;
            output[dst_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
        #else
            output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
        #endif
    }
#endif

}

#ifdef GET_UPDATES_INDEX
#undef GET_UPDATES_INDEX
#endif

#ifdef GET_OUTPUT_INDEX
#undef GET_OUTPUT_INDEX
#endif

#ifdef ORDER
#undef ORDER
#endif

diff --git
a/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.cpp index fab9ed83113..c81d96d370b 100644 --- a/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.cpp +++ b/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.cpp @@ -69,6 +69,7 @@ void register_implementations_gpu() { REGISTER_GPU(roi_pooling); REGISTER_GPU(scale); REGISTER_GPU(scatter_update); + REGISTER_GPU(scatter_nd_update); REGISTER_GPU(scatter_elements_update); REGISTER_GPU(select); REGISTER_GPU(shuffle_channels); diff --git a/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.hpp b/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.hpp index 37795e88812..a0283f15806 100644 --- a/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.hpp +++ b/inference-engine/thirdparty/clDNN/src/gpu/register_gpu.hpp @@ -62,6 +62,7 @@ #include "api/scale.hpp" #include "api/scatter_update.hpp" #include "api/scatter_elements_update.hpp" +#include "api/scatter_nd_update.hpp" #include "api/select.hpp" #include "api/shuffle_channels.hpp" #include "api/softmax.hpp" @@ -138,6 +139,7 @@ REGISTER_GPU(roi_pooling); REGISTER_GPU(scale); REGISTER_GPU(scatter_update); REGISTER_GPU(scatter_elements_update); +REGISTER_GPU(scatter_nd_update); REGISTER_GPU(select); REGISTER_GPU(shuffle_channels); REGISTER_GPU(softmax); diff --git a/inference-engine/thirdparty/clDNN/src/gpu/scatter_nd_update_gpu.cpp b/inference-engine/thirdparty/clDNN/src/gpu/scatter_nd_update_gpu.cpp new file mode 100644 index 00000000000..bd4f222cbde --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/gpu/scatter_nd_update_gpu.cpp @@ -0,0 +1,78 @@ +/* +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "scatter_nd_update_inst.h" +#include "primitive_gpu_base.h" +#include "implementation_map.h" +#include "kernel_selector_helper.h" +#include "scatter_update/scatter_nd_update_kernel_selector.h" +#include "scatter_update/scatter_nd_update_kernel_ref.h" +#include "error_handler.h" + +using namespace cldnn; + +namespace cldnn { +namespace gpu { + +struct scatter_nd_update_gpu : typed_primitive_gpu_impl { + using parent = typed_primitive_gpu_impl; + using parent::parent; + +public: + static primitive_impl* create(const scatter_nd_update_node& arg) { + auto scatter_nd_update_params = get_default_params(arg); + auto scatter_nd_update_optional_params = + get_default_optional_params(arg.get_program()); + + scatter_nd_update_params.indices_rank = arg.get_primitive()->indices_rank; + + scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout())); + scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout())); + + auto& kernel_selector = kernel_selector::scatter_nd_update_kernel_selector::Instance(); + auto best_kernels = kernel_selector.GetBestKernels(scatter_nd_update_params, scatter_nd_update_optional_params); + + CLDNN_ERROR_BOOL(arg.id(), + "Best_kernel.empty()", + best_kernels.empty(), + "Cannot find a proper kernel with this arguments"); + + auto scatter_nd_update = new scatter_nd_update_gpu(arg, best_kernels[0]); + + return scatter_nd_update; + } +}; + +namespace detail { + +attach_scatter_nd_update_gpu::attach_scatter_nd_update_gpu() { + auto val_fw = 
scatter_nd_update_gpu::create; + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw); + + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw); + implementation_map::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw); +} + +} // namespace detail +} // namespace gpu +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp index 742c86aa3e2..763c2a8bbfd 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/prepare_primitive_fusing.cpp @@ -45,6 +45,7 @@ #include "space_to_depth_inst.h" #include "gather_inst.h" #include "scatter_update_inst.h" +#include "scatter_nd_update_inst.h" #include "scatter_elements_update_inst.h" #include "reverse_sequence_inst.h" #include "shuffle_channels_inst.h" @@ -206,6 +207,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) { !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && + !input.is_type() && !input.is_type() && 
!input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type() && !input.is_type())) @@ -540,6 +542,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -604,6 +608,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); + should_fuse |= input_data.is_type(); should_fuse |= input_data.is_type(); @@ -690,6 +696,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); + should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); should_fuse |= input_data.is_type() && quantize_node.get_scale_shift_opt(); @@ -745,6 +753,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) { (parents[i]->is_type()) || (parents[i]->is_type() && eltwise_supports_fusings(parents[i]->as())) || (parents[i]->is_type()) || + (parents[i]->is_type()) || (parents[i]->is_type()) || (parents[i]->is_type() && pooling_supports_fusings(parents[i]->as())) || (parents[i]->is_type() && dts_supports_fusings(parents[i]->as())) || diff --git a/inference-engine/thirdparty/clDNN/src/include/scatter_nd_update_inst.h b/inference-engine/thirdparty/clDNN/src/include/scatter_nd_update_inst.h new file mode 100644 index 00000000000..75bf2c04a05 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/include/scatter_nd_update_inst.h @@ -0,0 +1,49 @@ +/* +// Copyright (c) 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma once +#include "api/scatter_nd_update.hpp" +#include "primitive_inst.h" +#include + +namespace cldnn { +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + +public: + using parent::parent; + + program_node& input(size_t index = 0) const { return get_dependency(index); } +}; + +using scatter_nd_update_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + +public: + static layout calc_output_layout(scatter_nd_update_node const& node); + static std::string to_string(scatter_nd_update_node const& node); + +public: + typed_primitive_inst(network_impl& network, scatter_nd_update_node const& desc); +}; + +using scatter_nd_update_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/src/scatter_nd_update.cpp b/inference-engine/thirdparty/clDNN/src/scatter_nd_update.cpp new file mode 100644 index 00000000000..46d2defb0c8 --- /dev/null +++ b/inference-engine/thirdparty/clDNN/src/scatter_nd_update.cpp @@ -0,0 +1,66 @@ +/* +// Copyright (c) 2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +*/ + +#include "scatter_nd_update_inst.h" + +#include "primitive_type_base.h" +#include "error_handler.h" +#include "json_object.h" +#include + +namespace cldnn { +primitive_type_id scatter_nd_update::type_id() { + static primitive_type_base instance; + return &instance; +} + + +layout scatter_nd_update_inst::calc_output_layout(scatter_nd_update_node const& node) { + auto input_layout = node.input(0).get_output_layout(); + + auto output_shape = input_layout.size; + auto input_format = input_layout.format; + auto output_type = input_layout.data_type; + + if (node.has_fused_primitives()) { + output_type = node.get_fused_output_layout().data_type; + } + + return layout{output_type, input_format, output_shape}; +} + +std::string scatter_nd_update_inst::to_string(scatter_nd_update_node const& node) { + auto desc = node.get_primitive(); + auto node_info = node.desc_to_json(); + auto& input = node.input(); + + std::stringstream primitive_description; + + json_composite scatter_nd_update_info; + scatter_nd_update_info.add("input id", input.id()); + scatter_nd_update_info.add("input shape", node.input(0).get_output_layout().size.to_string()); + scatter_nd_update_info.add("indices shape", node.input(1).get_output_layout().size.to_string()); + scatter_nd_update_info.add("updates shape", node.input(2).get_output_layout().size.to_string()); + + node_info->add("scatter_nd_update info", scatter_nd_update_info); + node_info->dump(primitive_description); + + return primitive_description.str(); +} + +scatter_nd_update_inst::typed_primitive_inst(network_impl& 
network, scatter_nd_update_node const& node) : parent(network, node) {} + +} // namespace cldnn diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp index 42ffd7e66e4..bde4906ff88 100644 --- a/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/fusings_gpu_test.cpp @@ -35,6 +35,7 @@ #include "api/permute.hpp" #include "api/gather.hpp" #include "api/scatter_update.hpp" +#include "api/scatter_nd_update.hpp" #include "api/scatter_elements_update.hpp" #include "api/depth_to_space.hpp" #include "api/space_to_depth.hpp" @@ -7569,3 +7570,224 @@ INSTANTIATE_TEST_CASE_P(DISABLED_fusings_gpu, reduce_test_params{CASE_REDUCE_I8_3, 2, 4, reduce_mode::mean, {reduce::along_x}, true, "reduce_ref"}, reduce_test_params{CASE_REDUCE_U8_3, 2, 4, reduce_mode::l2, {reduce::along_x}, true, "reduce_ref"} }), ); + + +/* ----------------------------------------------------------------------------------------------------- */ +/* ------------------------------------------ ScatterNDUpdate cases ------------------------------ */ +/* ----------------------------------------------------------------------------------------------------- */ +struct scatter_nd_update_test_params { + tensor input_shape; + tensor indices_shape; + tensor updates_shape; + int max_number_in_indices; + int indices_rank; + data_types data_type; + format input_format; + data_types default_type; + format default_format; + size_t expected_fused_primitives; + size_t expected_not_fused_primitives; +}; + +#define CASE_SCATTER_ND_UPDATE_FP16_4D_1 {6, 1, 1, 1}, {3, 1, 1, 1}, {3, 1, 1, 1}, 6, 1, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_4D_2 {6, 6, 1, 1}, {3, 2, 1, 1}, {3, 1, 1, 1}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_4D_3 {6, 7, 8, 
9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_4D_4 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_4D_5 {6, 7, 8, 9}, {6, 2, 1, 1}, {6, 9, 1, 8}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_4D_6 {6, 7, 8, 9}, {6, 3, 1, 1}, {6, 8, 1, 1}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP16_5D_1 {6, 7, 8, 9, 10}, {5, 1, 1, 1, 1}, {5, 7, 8, 9, 10}, 6, 1, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_2 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 1}, {5, 10, 1, 8, 9}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_3 {6, 7, 8, 9, 10}, {5, 3, 1, 1, 1}, {5, 9, 1, 1, 8}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_4 {6, 7, 8, 9, 10}, {5, 4, 1, 1, 1}, {5, 8, 1, 1, 1}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_5 {6, 7, 8, 9, 10}, {5, 5, 1, 1, 1}, {5, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_6 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 2}, {5, 2, 8, 9, 10}, 6, 3, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_7 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 3}, {5, 2, 1, 8, 9}, 6, 3, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_8 {6, 7, 8, 9, 10}, {5, 2, 1, 4, 3}, {5, 2, 1, 8, 3}, 6, 4, data_types::f16, format::bfzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_5D_9 {6, 7, 8, 9, 10}, {5, 2, 1, 3, 3}, {5, 2, 8, 9, 3}, 6, 4, data_types::f16, format::bfzyx, 
data_types::f16, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP16_6D_1 {6, 7, 8, 9, 10, 11}, {5, 1, 1, 1}, {5, 7, 8, 9, 10, 11}, 6, 1, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_6D_2 {6, 7, 8, 9, 10, 11}, {5, 2, 1, 1}, {5, 11, 1, 8, 9, 10}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_6D_3 {6, 7, 8, 9, 10, 11}, {5, 3, 1, 1}, {5, 10, 1, 1, 8, 9}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_6D_4 {6, 7, 8, 9, 10, 11}, {5, 4, 1, 1}, {5, 9, 1, 1, 1, 8}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_6D_5 {6, 7, 8, 9, 2, 2}, {5, 5, 1, 1}, {5, 8, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP16_6D_6 {6, 7, 8, 9, 2, 2}, {5, 6, 1, 1}, {5, 1, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP32_4D_1 {6, 1, 1, 1}, {3, 1, 1, 1}, {3, 1, 1, 1}, 6, 1, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_4D_2 {6, 6, 1, 1}, {3, 2, 1, 1}, {3, 1, 1, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_4D_3 {6, 7, 8, 1}, {5, 1, 1, 1}, {5, 7, 8, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_4D_4 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_4D_5 {6, 7, 8, 9}, {6, 2, 1, 1}, {6, 9, 1, 8}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_4D_6 {6, 7, 8, 9}, {6, 3, 1, 1}, {6, 8, 1, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP32_5D_1 {6, 7, 8, 9, 
10}, {5, 1, 1, 1, 1}, {5, 7, 8, 9, 10}, 6, 1, data_types::f32, format::bfzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_5D_2 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 1}, {5, 10, 1, 8, 9}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_5D_3 {6, 7, 8, 9, 10}, {5, 3, 1, 1, 1}, {5, 9, 1, 1, 8}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_5D_4 {6, 7, 8, 9, 10}, {5, 4, 1, 1, 1}, {5, 8, 1, 1, 1}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_5D_5 {6, 7, 8, 9, 10}, {5, 5, 1, 1, 1}, {5, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx + +#define CASE_SCATTER_ND_UPDATE_FP32_6D_1 {6, 7, 8, 9, 10, 11}, {5, 1, 1, 1}, {5, 7, 8, 9, 10, 11}, 6, 1, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_6D_2 {6, 7, 8, 9, 10, 11}, {5, 2, 1, 1}, {5, 11, 1, 8, 9, 10}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_6D_3 {6, 7, 8, 9, 10, 11}, {5, 3, 1, 1}, {5, 10, 1, 1, 8, 9}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_6D_4 {6, 7, 8, 9, 10, 11}, {5, 4, 1, 1}, {5, 9, 1, 1, 1, 8}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_6D_5 {6, 7, 8, 9, 2, 2}, {5, 5, 1, 1}, {5, 8, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx +#define CASE_SCATTER_ND_UPDATE_FP32_6D_6 {6, 7, 8, 9, 2, 2}, {5, 6, 1, 1}, {5, 1, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx + +class ScatterNDUpdatePrimitiveFusingTest : public ::BaseFusingTest { +public: + void execute(scatter_nd_update_test_params& p) { + auto input_prim = get_mem(get_input_layout(p)); + network network_not_fused(this->engine, 
this->topology_non_fused, bo_not_fused); + network network_fused(this->engine, this->topology_fused, bo_fused); + network_fused.set_input_data("input", input_prim); + network_not_fused.set_input_data("input", input_prim); + compare(network_not_fused, network_fused, p); + } + + layout get_input_layout(scatter_nd_update_test_params& p) { + return layout{ p.data_type, p.input_format, p.input_shape }; + } + + layout get_indices_layout(scatter_nd_update_test_params& p) { + return layout{ p.data_type, p.input_format, p.indices_shape }; + } + + layout get_updates_layout(scatter_nd_update_test_params& p) { + return layout{ p.data_type, p.input_format, p.updates_shape }; + } + + layout get_per_channel_layout(scatter_nd_update_test_params& p) { + return layout{ p.default_type, p.default_format, tensor{1, p.input_shape.feature[0], 1, 1} }; + } +}; + +class scatter_nd_update_quantize : public ScatterNDUpdatePrimitiveFusingTest {}; +TEST_P(scatter_nd_update_quantize, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + data("scatter_nd_update_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices)), + data("scatter_nd_update_updates", get_mem(get_updates_layout(p), 0, 100)), + data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)), + data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), + data("out_lo", get_mem(get_single_element_layout(p), -127)), + data("out_hi", get_mem(get_single_element_layout(p), 127)), + scatter_nd_update("scatter_nd_update_prim", "input", "scatter_nd_update_indices", "scatter_nd_update_updates", p.indices_rank), + quantize("quantize", "scatter_nd_update_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8), + reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) + ); + tolerance = 1.f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_quantize, + ::testing::ValuesIn(std::vector{ + scatter_nd_update_test_params{ 
CASE_SCATTER_ND_UPDATE_FP16_4D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_6, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_6, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_6, 2, 3 }, + 
+ scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_5, 2, 3 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_1, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_2, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_3, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 3 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 3 }, +}), ); + +class scatter_nd_update_scale_activation_eltwise : public ScatterNDUpdatePrimitiveFusingTest {}; +TEST_P(scatter_nd_update_scale_activation_eltwise, basic) { + auto p = GetParam(); + create_topologies(input_layout("input", get_input_layout(p)), + data("scatter_nd_update_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices)), + data("scatter_nd_update_updates", get_mem(get_updates_layout(p), 0, 100)), + data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)), + data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.input_shape })), + scatter_nd_update("scatter_nd_update_prim", "input", "scatter_nd_update_indices", "scatter_nd_update_updates", p.indices_rank), + activation("activation", "scatter_nd_update_prim", activation_func::abs), + scale("scale", "activation", "scale_data"), + eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type), + reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32) + ); + + tolerance = 1.f; + execute(p); +} + +INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise, + ::testing::ValuesIn(std::vector{ + 
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_6, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_6, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_5, 2, 5 }, + scatter_nd_update_test_params{ 
CASE_SCATTER_ND_UPDATE_FP32_4D_6, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_5, 2, 5 }, + + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_1, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_2, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_3, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 }, + scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 }, +}), ); diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/scatter_nd_update_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/scatter_nd_update_gpu_test.cpp new file mode 100644 index 00000000000..579e7276d0d --- /dev/null +++ b/inference-engine/thirdparty/clDNN/tests/test_cases/scatter_nd_update_gpu_test.cpp @@ -0,0 +1,3927 @@ +// Copyright (c) 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace cldnn; +using namespace ::tests; + + +TEST(scatter_nd_update_gpu_fp16_test15, data5_indice3_update5) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 2, 4, 3 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 2, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 2, 4, 3, 2 } }); // updates + + set_values(input1, { + // 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + // 1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), 
FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(1.0f), + FLOAT16(0.0f), + }); + + set_values(input3, { + // 0 + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + // 1 + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + }); + + std::vector expected_results = { + // 0 + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), 
FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + // 1 + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + + FLOAT16(91.0f), FLOAT16(2.0f), FLOAT16(83.0f), FLOAT16(4.0f), FLOAT16(71.0f), FLOAT16(2.0f), FLOAT16(63.0f), FLOAT16(4.0f), + FLOAT16(95.0f), FLOAT16(6.0f), FLOAT16(87.0f), FLOAT16(8.0f), FLOAT16(75.0f), FLOAT16(6.0f), FLOAT16(67.0f), FLOAT16(8.0f), + FLOAT16(99.0f), FLOAT16(10.0f), FLOAT16(811.0f), FLOAT16(12.0f), FLOAT16(79.0f), FLOAT16(10.0f), FLOAT16(611.0f), FLOAT16(12.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 3) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = 
network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test14, data5_indice2_update3) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 2, 4, 3 } }); // data 2x2x3x4x2 (bfzyx) + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 4, 1, 1, 2 } }); // updates + + set_values(input1, { + // 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + // 1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), 
FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(2.0f), + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(0.0f), + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), + }); + + set_values(input3, { + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), FLOAT16(67.0f), FLOAT16(68.0f), + FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), FLOAT16(75.0f), FLOAT16(76.0f), FLOAT16(77.0f), FLOAT16(78.0f), + }); + + std::vector expected_results = { + // 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), FLOAT16(75.0f), FLOAT16(76.0f), FLOAT16(77.0f), FLOAT16(78.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + // 1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + 
FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), FLOAT16(67.0f), FLOAT16(68.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test13, data4_indice2_update2) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 4, 2 } }); // data 2x3x2x4 (bfyx) + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 4, 1, 1 } }); // updates + + set_values(input1, { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + 
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(0.0f), + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(1.0f), + FLOAT16(0.0f), FLOAT16(2.0f), FLOAT16(1.0f), + }); + + set_values(input3, { + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), + FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), + FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + }); + + std::vector expected_results = { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + 
network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test12, data3_indice3_update1) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 1, 4 } }); // data 3x3x4 (bfy) + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 4, 3, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 4, 1, 1, 1 } }); // updates + + set_values(input1, { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(2.0f), FLOAT16(0.0f), FLOAT16(0.0f), + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(0.0f), + }); + + set_values(input3, { + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), + }); + + std::vector expected_results = { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(54.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + 
FLOAT16(5.0f), FLOAT16(53.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(52.0f), + + FLOAT16(51.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test11, data6_indice1_update6) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 3, 4, 2 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 1, 1, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 2, 3, 4, 2 } }); // updates + + set_values(input1, { + // 0, 0, 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + // 0, 0, 1 + FLOAT16(1.0f), FLOAT16(2.0f), 
FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + // 0, 1, 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + + // 1, 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + // 1, 1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + 
FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(1.0f), + }); + + set_values(input3, { + // 0 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(50.0f), FLOAT16(51.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), 
FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + }); + + std::vector expected_results = { + // 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + // 1 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), 
FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(50.0f), FLOAT16(51.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + 
network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test10, data5_indice1_update5) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 3, 4, 2 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 3, 4, 2 } }); // updates + + set_values(input1, { + // 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + // 1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), 
FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(1.0f), FLOAT16(0.0f), + }); + + set_values(input3, { + // 0 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(50.0f), FLOAT16(51.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + // 1 + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + }); + + 
std::vector expected_results = { + // 0 + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + FLOAT16(150.0f), FLOAT16(151.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + // 1 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(50.0f), FLOAT16(51.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + 
+ network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test9, data4_indice1_update4) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 4, 2 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 4, 2 } }); // updates + + set_values(input1, { + // 0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + // 1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + // 2 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), 
FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(2.0f), FLOAT16(0.0f), + }); + + set_values(input3, { + // 0 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + // 1 + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + }); + + std::vector expected_results = { + // 0 + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + // 1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + // 2 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), 
FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test8, data6_indice2_update5) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 1, 2, 2, 3, 4, 2 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 1, 3, 4, 2 } }); // updates + + set_values(input1, { + //0,0 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), 
FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + //0,1 + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), + FLOAT16(0.0f), FLOAT16(0.0f) + }); + + set_values(input3, { + // 0 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + + // 1 + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), 
FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + }); + + std::vector expected_results = { + // 0,0 + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + // 0,1 + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + 
FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test7, data5_indice2_update4) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 1, 2, 3, 4, 2 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 1, 3, 4 } }); // updates + + + set_values(input1, { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + 
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), + FLOAT16(0.0f), FLOAT16(0.0f) + }); + + set_values(input3, { + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + }); + + std::vector expected_results = { + FLOAT16(151.0f), FLOAT16(152.0f), FLOAT16(153.0f), FLOAT16(154.0f), FLOAT16(155.0f), FLOAT16(156.0f), FLOAT16(157.0f), FLOAT16(158.0f), + FLOAT16(159.0f), FLOAT16(160.0f), FLOAT16(161.0f), FLOAT16(162.0f), FLOAT16(163.0f), FLOAT16(164.0f), FLOAT16(165.0f), FLOAT16(166.0f), + FLOAT16(167.0f), FLOAT16(168.0f), FLOAT16(169.0f), FLOAT16(170.0f), FLOAT16(171.0f), FLOAT16(172.0f), FLOAT16(173.0f), FLOAT16(174.0f), + + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), 
FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16_test6, data4_indice2_update3) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 4, 2 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 4, 1, 2 } }); // updates + + + set_values(input1, { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), 
FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(1.0f), FLOAT16(1.0f), + FLOAT16(1.0f), FLOAT16(0.0f), + FLOAT16(0.0f), FLOAT16(2.0f) + }); + + set_values(input3, { + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + }); + + std::vector expected_results = { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(67.0f), FLOAT16(68.0f), FLOAT16(69.0f), FLOAT16(70.0f), FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(59.0f), FLOAT16(60.0f), FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), FLOAT16(65.0f), FLOAT16(66.0f), + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), FLOAT16(55.0f), FLOAT16(56.0f), FLOAT16(57.0f), FLOAT16(58.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + 
network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test5, data3_indice2_update2) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 4 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 4, 1, 1 } }); // updates + + + set_values(input1, { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(1.0f), FLOAT16(1.0f), + FLOAT16(1.0f), FLOAT16(0.0f), + FLOAT16(0.0f), FLOAT16(2.0f) + }); + + set_values(input3, { + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), + FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), + FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + }); + + std::vector expected_results = { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(71.0f), FLOAT16(72.0f), FLOAT16(73.0f), FLOAT16(74.0f), + + FLOAT16(61.0f), FLOAT16(62.0f), FLOAT16(63.0f), FLOAT16(64.0f), + FLOAT16(51.0f), FLOAT16(52.0f), FLOAT16(53.0f), FLOAT16(54.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }; + + topology topology; + topology.add(input_layout("InputData", 
input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test4, data2_indice2_update1) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 4, 1, 1 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 1, 1, 1 } }); // updates + + + set_values(input1, { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(2.0f), FLOAT16(1.0f), + FLOAT16(0.0f), FLOAT16(3.0f), + FLOAT16(0.0f), FLOAT16(2.0f) + }); + + set_values(input3, { + FLOAT16(21.0f), FLOAT16(22.0f), FLOAT16(23.0f) + }); + + std::vector expected_results = { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(23.0f), FLOAT16(22.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(21.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", 
input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test3, data3_indice1_update3) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 4, 1 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 4, 1 } }); // updates + + + set_values(input1, { + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + }); + + set_values(input2, { + FLOAT16(2.0f), FLOAT16(0.0f) + }); + + set_values(input3, { + FLOAT16(21.0f), FLOAT16(22.0f), FLOAT16(23.0f), FLOAT16(24.0f), + FLOAT16(25.0f), FLOAT16(26.0f), FLOAT16(27.0f), FLOAT16(28.0f), + FLOAT16(29.0f), FLOAT16(30.0f), FLOAT16(31.0f), FLOAT16(32.0f), + + FLOAT16(41.0f), FLOAT16(42.0f), FLOAT16(43.0f), FLOAT16(44.0f), + FLOAT16(45.0f), FLOAT16(46.0f), FLOAT16(47.0f), 
FLOAT16(48.0f), + FLOAT16(49.0f), FLOAT16(50.0f), FLOAT16(51.0f), FLOAT16(52.0f), + }); + + std::vector expected_results = { + FLOAT16(41.0f), FLOAT16(42.0f), FLOAT16(43.0f), FLOAT16(44.0f), + FLOAT16(45.0f), FLOAT16(46.0f), FLOAT16(47.0f), FLOAT16(48.0f), + FLOAT16(49.0f), FLOAT16(50.0f), FLOAT16(51.0f), FLOAT16(52.0f), + + FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), + FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f), + FLOAT16(9.0f), FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), + + FLOAT16(21.0f), FLOAT16(22.0f), FLOAT16(23.0f), FLOAT16(24.0f), + FLOAT16(25.0f), FLOAT16(26.0f), FLOAT16(27.0f), FLOAT16(28.0f), + FLOAT16(29.0f), FLOAT16(30.0f), FLOAT16(31.0f), FLOAT16(32.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16_test2, data2_indice1_update2) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 4, 1, 1 } }); // data + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 4, 1, 1 } }); // updates + + + set_values(input1, { + FLOAT16(13.0f), FLOAT16(12.0f), FLOAT16(11.0f), 
FLOAT16(10.0f), + FLOAT16(9.0f), FLOAT16(8.0f), FLOAT16(7.0f), FLOAT16(6.0f), + FLOAT16(5.0f), FLOAT16(4.0f), FLOAT16(3.0f), FLOAT16(2.0f) + }); + + set_values(input2, { + FLOAT16(2.0f), FLOAT16(0.0f) + }); + + set_values(input3, { + FLOAT16(20.0f), FLOAT16(21.0f), FLOAT16(22.0f), FLOAT16(23.0f), + FLOAT16(24.0f), FLOAT16(25.0f), FLOAT16(26.0f), FLOAT16(27.0f) + }); + + std::vector expected_results = { + FLOAT16(24.0f), FLOAT16(25.0f), FLOAT16(26.0f), FLOAT16(27.0f), + FLOAT16(9.0f), FLOAT16(8.0f), FLOAT16(7.0f), FLOAT16(6.0f), + FLOAT16(20.0f), FLOAT16(21.0f), FLOAT16(22.0f), FLOAT16(23.0f), + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16_test1, data1_indice1_update1) { + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 8, 1, 1, 1 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 4, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 4, 1, 1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(9.0f), FLOAT16(8.0f), FLOAT16(7.0f), FLOAT16(6.0f), FLOAT16(5.0f), FLOAT16(4.0f), FLOAT16(3.0f), FLOAT16(2.0f) + }); + + set_values(input2, { + FLOAT16(2.0f), 
FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(7.0f) + }); + + set_values(input3, { + FLOAT16(10.0f), FLOAT16(11.0f), FLOAT16(12.0f), FLOAT16(13.0f) + }); + + std::vector expected_results = { + 9.f, 8.f, 10.f, 6.f, 11.f, 12.f, 3.f, 13.f + }; + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + + +TEST(scatter_nd_update_gpu_fp16, d6661_i2311) { + // Dictionary : 6x6x6x1 + // Indexes : 2x3x1x1 + // Updates : 2x1x1x1 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 6, 6, 1, 6 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // Updates + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + FLOAT16(124.f), FLOAT16(125.f), 
FLOAT16(126.f), FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(136.f), FLOAT16(137.f), FLOAT16(138.f), FLOAT16(139.f), FLOAT16(140.f), FLOAT16(141.f), + FLOAT16(142.f), FLOAT16(143.f), FLOAT16(144.f), FLOAT16(145.f), FLOAT16(146.f), FLOAT16(147.f), + FLOAT16(148.f), FLOAT16(149.f), FLOAT16(150.f), FLOAT16(151.f), FLOAT16(152.f), FLOAT16(153.f), + FLOAT16(154.f), FLOAT16(155.f), FLOAT16(156.f), FLOAT16(157.f), FLOAT16(158.f), FLOAT16(159.f), + FLOAT16(160.f), FLOAT16(161.f), FLOAT16(162.f), FLOAT16(163.f), FLOAT16(164.f), FLOAT16(165.f), + FLOAT16(166.f), FLOAT16(167.f), FLOAT16(168.f), FLOAT16(169.f), FLOAT16(170.f), FLOAT16(171.f), + + FLOAT16(172.f), FLOAT16(173.f), FLOAT16(174.f), FLOAT16(175.f), FLOAT16(176.f), FLOAT16(177.f), + FLOAT16(178.f), FLOAT16(179.f), FLOAT16(180.f), FLOAT16(181.f), FLOAT16(182.f), FLOAT16(183.f), + FLOAT16(184.f), FLOAT16(185.f), FLOAT16(186.f), FLOAT16(187.f), FLOAT16(188.f), FLOAT16(189.f), + FLOAT16(190.f), FLOAT16(191.f), FLOAT16(192.f), FLOAT16(193.f), FLOAT16(194.f), FLOAT16(195.f), + FLOAT16(196.f), FLOAT16(197.f), FLOAT16(198.f), FLOAT16(199.f), FLOAT16(200.f), FLOAT16(201.f), + FLOAT16(202.f), FLOAT16(203.f), FLOAT16(204.f), FLOAT16(205.f), FLOAT16(206.f), FLOAT16(207.f), + + FLOAT16(208.f), FLOAT16(209.f), FLOAT16(210.f), FLOAT16(211.f), FLOAT16(212.f), FLOAT16(213.f), + FLOAT16(214.f), FLOAT16(215.f), FLOAT16(216.f), FLOAT16(217.f), FLOAT16(218.f), FLOAT16(219.f), + FLOAT16(220.f), FLOAT16(221.f), FLOAT16(222.f), FLOAT16(223.f), FLOAT16(224.f), FLOAT16(225.f), + FLOAT16(226.f), FLOAT16(227.f), FLOAT16(228.f), FLOAT16(229.f), FLOAT16(230.f), FLOAT16(231.f), + FLOAT16(232.f), FLOAT16(233.f), FLOAT16(234.f), FLOAT16(235.f), FLOAT16(236.f), FLOAT16(237.f), + FLOAT16(238.f), FLOAT16(239.f), FLOAT16(240.f), FLOAT16(241.f), FLOAT16(242.f), FLOAT16(243.f), + + FLOAT16(244.f), FLOAT16(245.f), FLOAT16(246.f), FLOAT16(247.f), 
FLOAT16(248.f), FLOAT16(249.f), + FLOAT16(250.f), FLOAT16(251.f), FLOAT16(252.f), FLOAT16(253.f), FLOAT16(254.f), FLOAT16(255.f), + FLOAT16(256.f), FLOAT16(257.f), FLOAT16(258.f), FLOAT16(259.f), FLOAT16(260.f), FLOAT16(261.f), + FLOAT16(262.f), FLOAT16(263.f), FLOAT16(264.f), FLOAT16(265.f), FLOAT16(266.f), FLOAT16(267.f), + FLOAT16(268.f), FLOAT16(269.f), FLOAT16(270.f), FLOAT16(271.f), FLOAT16(272.f), FLOAT16(273.f), + FLOAT16(274.f), FLOAT16(275.f), FLOAT16(276.f), FLOAT16(277.f), FLOAT16(278.f), FLOAT16(279.f), + + FLOAT16(280.f), FLOAT16(281.f), FLOAT16(282.f), FLOAT16(283.f), FLOAT16(284.f), FLOAT16(285.f), + FLOAT16(286.f), FLOAT16(287.f), FLOAT16(288.f), FLOAT16(289.f), FLOAT16(290.f), FLOAT16(291.f), + FLOAT16(292.f), FLOAT16(293.f), FLOAT16(294.f), FLOAT16(295.f), FLOAT16(296.f), FLOAT16(297.f), + FLOAT16(298.f), FLOAT16(299.f), FLOAT16(300.f), FLOAT16(301.f), FLOAT16(302.f), FLOAT16(303.f), + FLOAT16(304.f), FLOAT16(305.f), FLOAT16(306.f), FLOAT16(307.f), FLOAT16(308.f), FLOAT16(309.f), + FLOAT16(310.f), FLOAT16(311.f), FLOAT16(312.f), FLOAT16(313.f), FLOAT16(314.f), FLOAT16(315.f), + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(2.0f), + FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f) + }); + + set_values(input3, { + FLOAT16(999.0f), FLOAT16(888.0f) + }); + + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector 
expected_results = { + 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, + 106.f, 107.f, 999.f, 109.f, 110.f, 111.f, + 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, + 118.f, 119.f, 120.f, 121.f, 122.f, 123.f, + 124.f, 125.f, 126.f, 127.f, 128.f, 129.f, + 130.f, 131.f, 132.f, 133.f, 134.f, 135.f, + + 136.f, 137.f, 138.f, 139.f, 140.f, 141.f, + 142.f, 143.f, 144.f, 145.f, 146.f, 147.f, + 148.f, 149.f, 150.f, 151.f, 152.f, 153.f, + 154.f, 155.f, 156.f, 157.f, 158.f, 159.f, + 160.f, 161.f, 162.f, 163.f, 164.f, 165.f, + 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, + + 172.f, 173.f, 174.f, 175.f, 176.f, 177.f, + 178.f, 179.f, 180.f, 181.f, 182.f, 183.f, + 184.f, 185.f, 186.f, 187.f, 188.f, 189.f, + 190.f, 191.f, 192.f, 193.f, 194.f, 195.f, + 196.f, 197.f, 198.f, 199.f, 200.f, 201.f, + 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, + + 208.f, 209.f, 210.f, 211.f, 212.f, 213.f, + 214.f, 215.f, 216.f, 217.f, 218.f, 219.f, + 220.f, 221.f, 222.f, 223.f, 224.f, 225.f, + 226.f, 227.f, 228.f, 229.f, 230.f, 231.f, + 232.f, 233.f, 234.f, 235.f, 236.f, 888.f, + 238.f, 239.f, 240.f, 241.f, 242.f, 243.f, + + 244.f, 245.f, 246.f, 247.f, 248.f, 249.f, + 250.f, 251.f, 252.f, 253.f, 254.f, 255.f, + 256.f, 257.f, 258.f, 259.f, 260.f, 261.f, + 262.f, 263.f, 264.f, 265.f, 266.f, 267.f, + 268.f, 269.f, 270.f, 271.f, 272.f, 273.f, + 274.f, 275.f, 276.f, 277.f, 278.f, 279.f, + + 280.f, 281.f, 282.f, 283.f, 284.f, 285.f, + 286.f, 287.f, 288.f, 289.f, 290.f, 291.f, + 292.f, 293.f, 294.f, 295.f, 296.f, 297.f, + 298.f, 299.f, 300.f, 301.f, 302.f, 303.f, + 304.f, 305.f, 306.f, 307.f, 308.f, 309.f, + 310.f, 311.f, 312.f, 313.f, 314.f, 315.f, + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16, d6661_i2211) { + // Dictionary : 6x6x6x1 + // Indexes : 2x2x1x1 + // Updates : 2x6x1x1 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { 
data_types::f16, format::bfyx, { 6, 6, 1, 6 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 6, 1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(136.f), FLOAT16(137.f), FLOAT16(138.f), FLOAT16(139.f), FLOAT16(140.f), FLOAT16(141.f), + FLOAT16(142.f), FLOAT16(143.f), FLOAT16(144.f), FLOAT16(145.f), FLOAT16(146.f), FLOAT16(147.f), + FLOAT16(148.f), FLOAT16(149.f), FLOAT16(150.f), FLOAT16(151.f), FLOAT16(152.f), FLOAT16(153.f), + FLOAT16(154.f), FLOAT16(155.f), FLOAT16(156.f), FLOAT16(157.f), FLOAT16(158.f), FLOAT16(159.f), + FLOAT16(160.f), FLOAT16(161.f), FLOAT16(162.f), FLOAT16(163.f), FLOAT16(164.f), FLOAT16(165.f), + FLOAT16(166.f), FLOAT16(167.f), FLOAT16(168.f), FLOAT16(169.f), FLOAT16(170.f), FLOAT16(171.f), + + FLOAT16(172.f), FLOAT16(173.f), FLOAT16(174.f), FLOAT16(175.f), FLOAT16(176.f), FLOAT16(177.f), + FLOAT16(178.f), FLOAT16(179.f), FLOAT16(180.f), FLOAT16(181.f), FLOAT16(182.f), FLOAT16(183.f), + FLOAT16(184.f), FLOAT16(185.f), FLOAT16(186.f), FLOAT16(187.f), FLOAT16(188.f), FLOAT16(189.f), + FLOAT16(190.f), FLOAT16(191.f), FLOAT16(192.f), FLOAT16(193.f), FLOAT16(194.f), FLOAT16(195.f), + FLOAT16(196.f), FLOAT16(197.f), FLOAT16(198.f), FLOAT16(199.f), FLOAT16(200.f), FLOAT16(201.f), + FLOAT16(202.f), 
FLOAT16(203.f), FLOAT16(204.f), FLOAT16(205.f), FLOAT16(206.f), FLOAT16(207.f), + + FLOAT16(208.f), FLOAT16(209.f), FLOAT16(210.f), FLOAT16(211.f), FLOAT16(212.f), FLOAT16(213.f), + FLOAT16(214.f), FLOAT16(215.f), FLOAT16(216.f), FLOAT16(217.f), FLOAT16(218.f), FLOAT16(219.f), + FLOAT16(220.f), FLOAT16(221.f), FLOAT16(222.f), FLOAT16(223.f), FLOAT16(224.f), FLOAT16(225.f), + FLOAT16(226.f), FLOAT16(227.f), FLOAT16(228.f), FLOAT16(229.f), FLOAT16(230.f), FLOAT16(231.f), + FLOAT16(232.f), FLOAT16(233.f), FLOAT16(234.f), FLOAT16(235.f), FLOAT16(236.f), FLOAT16(237.f), + FLOAT16(238.f), FLOAT16(239.f), FLOAT16(240.f), FLOAT16(241.f), FLOAT16(242.f), FLOAT16(243.f), + + FLOAT16(244.f), FLOAT16(245.f), FLOAT16(246.f), FLOAT16(247.f), FLOAT16(248.f), FLOAT16(249.f), + FLOAT16(250.f), FLOAT16(251.f), FLOAT16(252.f), FLOAT16(253.f), FLOAT16(254.f), FLOAT16(255.f), + FLOAT16(256.f), FLOAT16(257.f), FLOAT16(258.f), FLOAT16(259.f), FLOAT16(260.f), FLOAT16(261.f), + FLOAT16(262.f), FLOAT16(263.f), FLOAT16(264.f), FLOAT16(265.f), FLOAT16(266.f), FLOAT16(267.f), + FLOAT16(268.f), FLOAT16(269.f), FLOAT16(270.f), FLOAT16(271.f), FLOAT16(272.f), FLOAT16(273.f), + FLOAT16(274.f), FLOAT16(275.f), FLOAT16(276.f), FLOAT16(277.f), FLOAT16(278.f), FLOAT16(279.f), + + FLOAT16(280.f), FLOAT16(281.f), FLOAT16(282.f), FLOAT16(283.f), FLOAT16(284.f), FLOAT16(285.f), + FLOAT16(286.f), FLOAT16(287.f), FLOAT16(288.f), FLOAT16(289.f), FLOAT16(290.f), FLOAT16(291.f), + FLOAT16(292.f), FLOAT16(293.f), FLOAT16(294.f), FLOAT16(295.f), FLOAT16(296.f), FLOAT16(297.f), + FLOAT16(298.f), FLOAT16(299.f), FLOAT16(300.f), FLOAT16(301.f), FLOAT16(302.f), FLOAT16(303.f), + FLOAT16(304.f), FLOAT16(305.f), FLOAT16(306.f), FLOAT16(307.f), FLOAT16(308.f), FLOAT16(309.f), + FLOAT16(310.f), FLOAT16(311.f), FLOAT16(312.f), FLOAT16(313.f), FLOAT16(314.f), FLOAT16(315.f), + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), + FLOAT16(3.0f), FLOAT16(4.0f), + }); + + set_values(input3, { + FLOAT16(999.0f), 
FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, + 999.f, 999.f, 999.f, 999.f, 999.f, 999.f, + 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, + 118.f, 119.f, 120.f, 121.f, 122.f, 123.f, + 124.f, 125.f, 126.f, 127.f, 128.f, 129.f, + 130.f, 131.f, 132.f, 133.f, 134.f, 135.f, + + 136.f, 137.f, 138.f, 139.f, 140.f, 141.f, + 142.f, 143.f, 144.f, 145.f, 146.f, 147.f, + 148.f, 149.f, 150.f, 151.f, 152.f, 153.f, + 154.f, 155.f, 156.f, 157.f, 158.f, 159.f, + 160.f, 161.f, 162.f, 163.f, 164.f, 165.f, + 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, + + 172.f, 173.f, 174.f, 175.f, 176.f, 177.f, + 178.f, 179.f, 180.f, 181.f, 182.f, 183.f, + 184.f, 185.f, 186.f, 187.f, 188.f, 189.f, + 190.f, 191.f, 192.f, 193.f, 194.f, 195.f, + 196.f, 197.f, 198.f, 199.f, 200.f, 201.f, + 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, + + 208.f, 209.f, 210.f, 211.f, 212.f, 213.f, + 214.f, 215.f, 216.f, 217.f, 218.f, 219.f, + 220.f, 221.f, 222.f, 223.f, 224.f, 225.f, + 226.f, 227.f, 228.f, 229.f, 230.f, 231.f, + 888.f, 888.f, 888.f, 888.f, 888.f, 888.f, + 238.f, 239.f, 240.f, 241.f, 242.f, 243.f, + + 244.f, 245.f, 246.f, 247.f, 248.f, 
249.f, + 250.f, 251.f, 252.f, 253.f, 254.f, 255.f, + 256.f, 257.f, 258.f, 259.f, 260.f, 261.f, + 262.f, 263.f, 264.f, 265.f, 266.f, 267.f, + 268.f, 269.f, 270.f, 271.f, 272.f, 273.f, + 274.f, 275.f, 276.f, 277.f, 278.f, 279.f, + + 280.f, 281.f, 282.f, 283.f, 284.f, 285.f, + 286.f, 287.f, 288.f, 289.f, 290.f, 291.f, + 292.f, 293.f, 294.f, 295.f, 296.f, 297.f, + 298.f, 299.f, 300.f, 301.f, 302.f, 303.f, + 304.f, 305.f, 306.f, 307.f, 308.f, 309.f, + 310.f, 311.f, 312.f, 313.f, 314.f, 315.f, + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16, d6661_i2111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 6, 6, 1, 6 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 6, 1, 6 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(136.f), FLOAT16(137.f), FLOAT16(138.f), FLOAT16(139.f), FLOAT16(140.f), FLOAT16(141.f), + FLOAT16(142.f), FLOAT16(143.f), FLOAT16(144.f), FLOAT16(145.f), FLOAT16(146.f), FLOAT16(147.f), + FLOAT16(148.f), FLOAT16(149.f), 
FLOAT16(150.f), FLOAT16(151.f), FLOAT16(152.f), FLOAT16(153.f), + FLOAT16(154.f), FLOAT16(155.f), FLOAT16(156.f), FLOAT16(157.f), FLOAT16(158.f), FLOAT16(159.f), + FLOAT16(160.f), FLOAT16(161.f), FLOAT16(162.f), FLOAT16(163.f), FLOAT16(164.f), FLOAT16(165.f), + FLOAT16(166.f), FLOAT16(167.f), FLOAT16(168.f), FLOAT16(169.f), FLOAT16(170.f), FLOAT16(171.f), + + FLOAT16(172.f), FLOAT16(173.f), FLOAT16(174.f), FLOAT16(175.f), FLOAT16(176.f), FLOAT16(177.f), + FLOAT16(178.f), FLOAT16(179.f), FLOAT16(180.f), FLOAT16(181.f), FLOAT16(182.f), FLOAT16(183.f), + FLOAT16(184.f), FLOAT16(185.f), FLOAT16(186.f), FLOAT16(187.f), FLOAT16(188.f), FLOAT16(189.f), + FLOAT16(190.f), FLOAT16(191.f), FLOAT16(192.f), FLOAT16(193.f), FLOAT16(194.f), FLOAT16(195.f), + FLOAT16(196.f), FLOAT16(197.f), FLOAT16(198.f), FLOAT16(199.f), FLOAT16(200.f), FLOAT16(201.f), + FLOAT16(202.f), FLOAT16(203.f), FLOAT16(204.f), FLOAT16(205.f), FLOAT16(206.f), FLOAT16(207.f), + + FLOAT16(208.f), FLOAT16(209.f), FLOAT16(210.f), FLOAT16(211.f), FLOAT16(212.f), FLOAT16(213.f), + FLOAT16(214.f), FLOAT16(215.f), FLOAT16(216.f), FLOAT16(217.f), FLOAT16(218.f), FLOAT16(219.f), + FLOAT16(220.f), FLOAT16(221.f), FLOAT16(222.f), FLOAT16(223.f), FLOAT16(224.f), FLOAT16(225.f), + FLOAT16(226.f), FLOAT16(227.f), FLOAT16(228.f), FLOAT16(229.f), FLOAT16(230.f), FLOAT16(231.f), + FLOAT16(232.f), FLOAT16(233.f), FLOAT16(234.f), FLOAT16(235.f), FLOAT16(236.f), FLOAT16(237.f), + FLOAT16(238.f), FLOAT16(239.f), FLOAT16(240.f), FLOAT16(241.f), FLOAT16(242.f), FLOAT16(243.f), + + FLOAT16(244.f), FLOAT16(245.f), FLOAT16(246.f), FLOAT16(247.f), FLOAT16(248.f), FLOAT16(249.f), + FLOAT16(250.f), FLOAT16(251.f), FLOAT16(252.f), FLOAT16(253.f), FLOAT16(254.f), FLOAT16(255.f), + FLOAT16(256.f), FLOAT16(257.f), FLOAT16(258.f), FLOAT16(259.f), FLOAT16(260.f), FLOAT16(261.f), + FLOAT16(262.f), FLOAT16(263.f), FLOAT16(264.f), FLOAT16(265.f), FLOAT16(266.f), FLOAT16(267.f), + FLOAT16(268.f), FLOAT16(269.f), FLOAT16(270.f), FLOAT16(271.f), 
FLOAT16(272.f), FLOAT16(273.f), + FLOAT16(274.f), FLOAT16(275.f), FLOAT16(276.f), FLOAT16(277.f), FLOAT16(278.f), FLOAT16(279.f), + + FLOAT16(280.f), FLOAT16(281.f), FLOAT16(282.f), FLOAT16(283.f), FLOAT16(284.f), FLOAT16(285.f), + FLOAT16(286.f), FLOAT16(287.f), FLOAT16(288.f), FLOAT16(289.f), FLOAT16(290.f), FLOAT16(291.f), + FLOAT16(292.f), FLOAT16(293.f), FLOAT16(294.f), FLOAT16(295.f), FLOAT16(296.f), FLOAT16(297.f), + FLOAT16(298.f), FLOAT16(299.f), FLOAT16(300.f), FLOAT16(301.f), FLOAT16(302.f), FLOAT16(303.f), + FLOAT16(304.f), FLOAT16(305.f), FLOAT16(306.f), FLOAT16(307.f), FLOAT16(308.f), FLOAT16(309.f), + FLOAT16(310.f), FLOAT16(311.f), FLOAT16(312.f), FLOAT16(313.f), FLOAT16(314.f), FLOAT16(315.f), + }); + + set_values(input2, { + FLOAT16(0.0f), + FLOAT16(3.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(777.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(777.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(777.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(777.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(777.0f), + + FLOAT16(666.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(666.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(666.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(666.0f), FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(666.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), 
FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(666.0f), + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 777.f, 999.f, 999.f, 999.f, 999.f, 999.f, + 999.f, 777.f, 999.f, 999.f, 999.f, 999.f, + 999.f, 999.f, 777.f, 999.f, 999.f, 999.f, + 999.f, 999.f, 999.f, 777.f, 999.f, 999.f, + 999.f, 999.f, 999.f, 999.f, 777.f, 999.f, + 999.f, 999.f, 999.f, 999.f, 999.f, 777.f, + + 136.f, 137.f, 138.f, 139.f, 140.f, 141.f, + 142.f, 143.f, 144.f, 145.f, 146.f, 147.f, + 148.f, 149.f, 150.f, 151.f, 152.f, 153.f, + 154.f, 155.f, 156.f, 157.f, 158.f, 159.f, + 160.f, 161.f, 162.f, 163.f, 164.f, 165.f, + 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, + + 172.f, 173.f, 174.f, 175.f, 176.f, 177.f, + 178.f, 179.f, 180.f, 181.f, 182.f, 183.f, + 184.f, 185.f, 186.f, 187.f, 188.f, 189.f, + 190.f, 191.f, 192.f, 193.f, 194.f, 195.f, + 196.f, 197.f, 198.f, 199.f, 200.f, 201.f, + 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, + + 666.f, 888.f, 888.f, 888.f, 888.f, 888.f, + 888.f, 666.f, 888.f, 888.f, 888.f, 888.f, + 888.f, 888.f, 666.f, 888.f, 888.f, 888.f, + 888.f, 888.f, 888.f, 666.f, 888.f, 888.f, + 888.f, 888.f, 888.f, 888.f, 666.f, 888.f, + 888.f, 888.f, 888.f, 888.f, 888.f, 666.f, + + 244.f, 245.f, 246.f, 247.f, 248.f, 249.f, + 250.f, 251.f, 252.f, 253.f, 254.f, 255.f, + 256.f, 257.f, 258.f, 259.f, 260.f, 261.f, + 262.f, 263.f, 264.f, 265.f, 
266.f, 267.f, + 268.f, 269.f, 270.f, 271.f, 272.f, 273.f, + 274.f, 275.f, 276.f, 277.f, 278.f, 279.f, + + 280.f, 281.f, 282.f, 283.f, 284.f, 285.f, + 286.f, 287.f, 288.f, 289.f, 290.f, 291.f, + 292.f, 293.f, 294.f, 295.f, 296.f, 297.f, + 298.f, 299.f, 300.f, 301.f, 302.f, 303.f, + 304.f, 305.f, 306.f, 307.f, 308.f, 309.f, + 310.f, 311.f, 312.f, 313.f, 314.f, 315.f, + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d3232_i2411) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 4, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(1.0f), + FLOAT16(2.0f), FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + 
topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + 104.f, 105.f, + + 106.f, 107.f, + 108.f, 109.f, + 110.f, 777.f, + + 112.f, 113.f, + 114.f, 115.f, + 116.f, 117.f, + + 118.f, 119.f, + 120.f, 121.f, + 122.f, 123.f, + + 124.f, 125.f, + 126.f, 127.f, + 128.f, 129.f, + + 130.f, 131.f, + 132.f, 133.f, + 134.f, 999.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d3232_i2311) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), 
FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(2.0f), + FLOAT16(2.0f), FLOAT16(1.0f), FLOAT16(2.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + 104.f, 105.f, + + 106.f, 107.f, + 108.f, 109.f, + 777.f, 777.f, + + 112.f, 113.f, + 114.f, 115.f, + 116.f, 117.f, + + 118.f, 119.f, + 120.f, 121.f, + 122.f, 123.f, + + 124.f, 125.f, + 126.f, 127.f, + 128.f, 129.f, + + 130.f, 131.f, + 132.f, 133.f, + 999.f, 999.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d3232_i2211) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // Indexes 
+ auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 2 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), + FLOAT16(2.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + 104.f, 105.f, + + 777.f, 777.f, + 777.f, 777.f, + 777.f, 777.f, + + 112.f, 113.f, + 114.f, 115.f, + 116.f, 117.f, + + 118.f, 119.f, + 120.f, 121.f, + 122.f, 123.f, + 
+ 124.f, 125.f, + 126.f, 127.f, + 128.f, 129.f, + + 999.f, 999.f, + 999.f, 999.f, + 999.f, 999.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d3232_i2111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 2, 3 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), + FLOAT16(2.0f) + }); + + set_values(input3, { + FLOAT16(666.0f), FLOAT16(666.0f), + FLOAT16(666.0f), FLOAT16(666.0f), + FLOAT16(666.0f), FLOAT16(666.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + 
topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 666.f, 666.f, + 666.f, 666.f, + 666.f, 666.f, + + 777.f, 777.f, + 777.f, 777.f, + 777.f, 777.f, + + 112.f, 113.f, + 114.f, 115.f, + 116.f, 117.f, + + 118.f, 119.f, + 120.f, 121.f, + 122.f, 123.f, + + 888.f, 888.f, + 888.f, 888.f, + 888.f, 888.f, + + 999.f, 999.f, + 999.f, 999.f, + 999.f, 999.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16, d32323_i25111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 3, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 5, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 1, 1, 1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), 
FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 2 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 3 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(2.0f), + FLOAT16(2.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); 
+ topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 127.f, 128.f, 777.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 2 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 127.f, 128.f, 129.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 3 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 127.f, 999.f, 129.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d32323_i24111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 3, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 4, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 1, 
1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 2 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 3 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), + 
FLOAT16(2.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 777.f, 777.f, 777.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 2 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 127.f, 128.f, 129.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 3 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 999.f, 999.f, 999.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d32323_i23111) { + // Dictionary : 6x6x6x1 + 
// Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 3, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 1, 1, 3 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 2 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 3 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), 
FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), + FLOAT16(2.0f), FLOAT16(1.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 777.f, 777.f, 777.f, + 777.f, 777.f, 777.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 2 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 127.f, 128.f, 129.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 3 + 100.f, 
101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 999.f, 999.f, 999.f, + 999.f, 999.f, 999.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d32323_i22111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 3, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 1, 3, 2 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 2 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + 
FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 3 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), + FLOAT16(2.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(555.0f), FLOAT16(555.0f), FLOAT16(555.0f), + FLOAT16(555.0f), FLOAT16(555.0f), FLOAT16(555.0f), + + FLOAT16(666.0f), FLOAT16(666.0f), FLOAT16(666.0f), + FLOAT16(666.0f), FLOAT16(666.0f), FLOAT16(666.0f), + + FLOAT16(444.0f), FLOAT16(444.0f), FLOAT16(444.0f), + FLOAT16(444.0f), FLOAT16(444.0f), FLOAT16(444.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + 
network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 555.f, 555.f, 555.f, + 555.f, 555.f, 555.f, + + 666.f, 666.f, 666.f, + 666.f, 666.f, 666.f, + + 444.f, 444.f, 444.f, + 444.f, 444.f, 444.f, + + // 2 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 127.f, 128.f, 129.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 3 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 777.f, 777.f, 777.f, + 777.f, 777.f, 777.f, + + 888.f, 888.f, 888.f, + 888.f, 888.f, 888.f, + + 999.f, 999.f, 999.f, + 999.f, 999.f, 999.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d32323_i21111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 3, 2, 3 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 1, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 2, 3, 2, 3 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + 
FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 2 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f), + + // 3 + FLOAT16(100.f), FLOAT16(101.f), FLOAT16(102.f), + FLOAT16(103.f), FLOAT16(104.f), FLOAT16(105.f), + + FLOAT16(106.f), FLOAT16(107.f), FLOAT16(108.f), + FLOAT16(109.f), FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), FLOAT16(114.f), + FLOAT16(115.f), FLOAT16(116.f), FLOAT16(117.f), + + FLOAT16(118.f), FLOAT16(119.f), FLOAT16(120.f), + FLOAT16(121.f), FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), FLOAT16(126.f), + FLOAT16(127.f), FLOAT16(128.f), FLOAT16(129.f), + + FLOAT16(130.f), FLOAT16(131.f), FLOAT16(132.f), + FLOAT16(133.f), FLOAT16(134.f), FLOAT16(135.f) + }); + + set_values(input2, { + FLOAT16(0.0f), + FLOAT16(2.0f) + }); + + set_values(input3, { + FLOAT16(555.0f), FLOAT16(555.0f), FLOAT16(555.0f), + FLOAT16(555.0f), FLOAT16(555.0f), FLOAT16(555.0f), + + FLOAT16(666.0f), FLOAT16(666.0f), FLOAT16(666.0f), 
+ FLOAT16(666.0f), FLOAT16(666.0f), FLOAT16(666.0f), + + FLOAT16(444.0f), FLOAT16(444.0f), FLOAT16(444.0f), + FLOAT16(444.0f), FLOAT16(444.0f), FLOAT16(444.0f), + + FLOAT16(555.0f), FLOAT16(555.0f), FLOAT16(555.0f), + FLOAT16(555.0f), FLOAT16(555.0f), FLOAT16(555.0f), + + FLOAT16(666.0f), FLOAT16(666.0f), FLOAT16(666.0f), + FLOAT16(666.0f), FLOAT16(666.0f), FLOAT16(666.0f), + + FLOAT16(444.0f), FLOAT16(444.0f), FLOAT16(444.0f), + FLOAT16(444.0f), FLOAT16(444.0f), FLOAT16(444.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + FLOAT16(888.0f), FLOAT16(888.0f), FLOAT16(888.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 555.f, 555.f, 555.f, + 555.f, 555.f, 555.f, + + 666.f, 666.f, 666.f, + 666.f, 666.f, 666.f, + + 444.f, 444.f, 444.f, + 444.f, 444.f, 444.f, + + 555.f, 555.f, 
555.f, + 555.f, 555.f, 555.f, + + 666.f, 666.f, 666.f, + 666.f, 666.f, 666.f, + + 444.f, 444.f, 444.f, + 444.f, 444.f, 444.f, + + // 2 + 100.f, 101.f, 102.f, + 103.f, 104.f, 105.f, + + 106.f, 107.f, 108.f, + 109.f, 110.f, 111.f, + + 112.f, 113.f, 114.f, + 115.f, 116.f, 117.f, + + 118.f, 119.f, 120.f, + 121.f, 122.f, 123.f, + + 124.f, 125.f, 126.f, + 127.f, 128.f, 129.f, + + 130.f, 131.f, 132.f, + 133.f, 134.f, 135.f, + + // 3 + 777.f, 777.f, 777.f, + 777.f, 777.f, 777.f, + + 888.f, 888.f, 888.f, + 888.f, 888.f, 888.f, + + 999.f, 999.f, 999.f, + 999.f, 999.f, 999.f, + + 777.f, 777.f, 777.f, + 777.f, 777.f, 777.f, + + 888.f, 888.f, 888.f, + 888.f, 888.f, 888.f, + + 999.f, 999.f, 999.f, + 999.f, 999.f, 999.f + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d222222_i261111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + // memory order is bfxyzw + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 2, 2, 2 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 6, 1, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 1, 1, 1, 1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f),//1 + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f),//2 + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f),//3 + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + + FLOAT16(128.f), FLOAT16(129.f), + 
FLOAT16(130.f), FLOAT16(131.f),//4 + + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f),//5 + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f),//6 + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f),//7 + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f),//8 + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(0.0f), + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(0.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + + 104.f, 105.f, + 106.f, 107.f,//1 + + 108.f, 109.f, + 110.f, 111.f, + + 112.f, 113.f, + 114.f, 115.f,//2 + + 116.f, 117.f, + 118.f, 119.f, + + 120.f, 121.f, + 122.f, 123.f,//3 + + 124.f, 125.f, + 126.f, 127.f, + + 128.f, 129.f, + 777.f, 131.f,//4 + + 132.f, 133.f, + 134.f, 135.f, + + 100.f, 101.f, + 102.f, 103.f,//5 + + 104.f, 105.f, + 106.f, 107.f, + + 108.f, 109.f, + 
110.f, 111.f,//6 + + 112.f, 113.f, + 114.f, 115.f, + + 116.f, 117.f, + 118.f, 119.f,//7 + + 120.f, 121.f, + 122.f, 123.f, + + 124.f, 125.f, + 999.f, 127.f,//8 + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d222222_i251111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + // memory order is bfxyzw + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 2, 2, 2 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 5, 1, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 1, 1, 1, 1 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f),//1 + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f),//2 + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f),//3 + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + + FLOAT16(128.f), FLOAT16(129.f), + FLOAT16(130.f), FLOAT16(131.f),//4 + + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f),//5 + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f),//6 + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f),//7 + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), 
FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f),//8 + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + + 104.f, 105.f, + 106.f, 107.f,//1 + + 108.f, 109.f, + 110.f, 111.f, + + 112.f, 113.f, + 114.f, 115.f,//2 + + 116.f, 117.f, + 118.f, 119.f, + + 120.f, 121.f, + 122.f, 123.f,//3 + + 124.f, 125.f, + 126.f, 127.f, + + 128.f, 129.f, + 777.f, 777.f,//4 + + 132.f, 133.f, + 134.f, 135.f, + + 100.f, 101.f, + 102.f, 103.f,//5 + + 104.f, 105.f, + 106.f, 107.f, + + 108.f, 109.f, + 110.f, 111.f,//6 + + 112.f, 113.f, + 114.f, 115.f, + + 116.f, 117.f, + 118.f, 119.f,//7 + + 120.f, 121.f, + 122.f, 123.f, + + 124.f, 125.f, + 999.f, 999.f,//8 + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d222222_i241111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + // memory order is bfxyzw + auto input1 = 
memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 2, 2, 2 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 4, 1, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 1, 1, 1, 2 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f),//1 + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f),//2 + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f),//3 + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + + FLOAT16(128.f), FLOAT16(129.f), + FLOAT16(130.f), FLOAT16(131.f),//4 + + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f),//5 + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f),//6 + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f),//7 + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f),//8 + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + 
topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + + 104.f, 105.f, + 106.f, 107.f,//1 + + 108.f, 109.f, + 110.f, 111.f, + + 112.f, 113.f, + 114.f, 115.f,//2 + + 116.f, 117.f, + 118.f, 119.f, + + 120.f, 121.f, + 122.f, 123.f,//3 + + 124.f, 125.f, + 126.f, 127.f, + + 777.f, 777.f, + 777.f, 777.f,//4 + + 132.f, 133.f, + 134.f, 135.f, + + 100.f, 101.f, + 102.f, 103.f,//5 + + 104.f, 105.f, + 106.f, 107.f, + + 108.f, 109.f, + 110.f, 111.f,//6 + + 112.f, 113.f, + 114.f, 115.f, + + 116.f, 117.f, + 118.f, 119.f,//7 + + 120.f, 121.f, + 122.f, 123.f, + + 999.f, 999.f, + 999.f, 999.f,//8 + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + + +TEST(scatter_nd_update_gpu_fp16, d222222_i231111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + // memory order is bfxyzw + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 2, 2, 2 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 1, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 1, 1, 2, 2 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f),//1 + + 
FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f),//2 + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f),//3 + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + + FLOAT16(128.f), FLOAT16(129.f), + FLOAT16(130.f), FLOAT16(131.f),//4 + + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f),//5 + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f),//6 + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f),//7 + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f),//8 + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), FLOAT16(1.0f), + FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = 
network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + + 104.f, 105.f, + 106.f, 107.f,//1 + + 108.f, 109.f, + 110.f, 111.f, + + 112.f, 113.f, + 114.f, 115.f,//2 + + 116.f, 117.f, + 118.f, 119.f, + + 120.f, 121.f, + 122.f, 123.f,//3 + + 777.f, 777.f, + 777.f, 777.f, + + 777.f, 777.f, + 777.f, 777.f,//4 + + 132.f, 133.f, + 134.f, 135.f, + + 100.f, 101.f, + 102.f, 103.f,//5 + + 104.f, 105.f, + 106.f, 107.f, + + 108.f, 109.f, + 110.f, 111.f,//6 + + 112.f, 113.f, + 114.f, 115.f, + + 116.f, 117.f, + 118.f, 119.f,//7 + + 999.f, 999.f, + 999.f, 999.f, + + 999.f, 999.f, + 999.f, 999.f,//8 + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + + +TEST(scatter_nd_update_gpu_fp16, d222222_i221111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + // memory order is bfxyzw + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 2, 2, 2 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 1, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 1, 2, 2, 2 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f),//1 + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f),//2 + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f),//3 + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + + FLOAT16(128.f), FLOAT16(129.f), + 
FLOAT16(130.f), FLOAT16(131.f),//4 + + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f),//5 + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f),//6 + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f),//7 + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f),//8 + }); + + set_values(input2, { + FLOAT16(0.0f), FLOAT16(1.0f), + FLOAT16(1.0f), FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + 
std::vector expected_results = { + 100.f, 101.f, + 102.f, 103.f, + + 104.f, 105.f, + 106.f, 107.f,//1 + + 108.f, 109.f, + 110.f, 111.f, + + 112.f, 113.f, + 114.f, 115.f,//2 + + 777.f, 777.f, + 777.f, 777.f, + + 777.f, 777.f, + 777.f, 777.f,//3 + + 777.f, 777.f, + 777.f, 777.f, + + 777.f, 777.f, + 777.f, 777.f,//4 + + 132.f, 133.f, + 134.f, 135.f, + + 100.f, 101.f, + 102.f, 103.f,//5 + + 104.f, 105.f, + 106.f, 107.f, + + 108.f, 109.f, + 110.f, 111.f,//6 + + 999.f, 999.f, + 999.f, 999.f, + + 999.f, 999.f, + 999.f, 999.f,//7 + + 999.f, 999.f, + 999.f, 999.f, + + 999.f, 999.f, + 999.f, 999.f,//8 + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +} + +TEST(scatter_nd_update_gpu_fp16, d222222_i211111) { + // Dictionary : 6x6x6x1 + // Indexes : 2x1x1x1 + // Updates : 2x6x1x6 + // Output : 6x6x6x1 + // Input values in fp16 + // + + engine engine; + + // memory order is bfxyzw + auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 2, 2, 2 } }); // Dictionary + auto input2 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 1, 1, 1, 1, 1 } }); // Indexes + auto input3 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 2, 2, 2, 2, 2 } }); // Updates + + + set_values(input1, { + FLOAT16(100.f), FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f), + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f),//1 + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f), + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f),//2 + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f), + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f),//3 + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f), + + FLOAT16(128.f), FLOAT16(129.f), + FLOAT16(130.f), FLOAT16(131.f),//4 + + FLOAT16(132.f), FLOAT16(133.f), + FLOAT16(134.f), FLOAT16(135.f), + + FLOAT16(100.f), 
FLOAT16(101.f), + FLOAT16(102.f), FLOAT16(103.f),//5 + + FLOAT16(104.f), FLOAT16(105.f), + FLOAT16(106.f), FLOAT16(107.f), + + FLOAT16(108.f), FLOAT16(109.f), + FLOAT16(110.f), FLOAT16(111.f),//6 + + FLOAT16(112.f), FLOAT16(113.f), + FLOAT16(114.f), FLOAT16(115.f), + + FLOAT16(116.f), FLOAT16(117.f), + FLOAT16(118.f), FLOAT16(119.f),//7 + + FLOAT16(120.f), FLOAT16(121.f), + FLOAT16(122.f), FLOAT16(123.f), + + FLOAT16(124.f), FLOAT16(125.f), + FLOAT16(126.f), FLOAT16(127.f),//8 + }); + + set_values(input2, { + FLOAT16(0.0f), + FLOAT16(1.0f) + }); + + set_values(input3, { + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(777.0f), FLOAT16(777.0f), + FLOAT16(777.0f), FLOAT16(777.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f), + + FLOAT16(999.0f), FLOAT16(999.0f), + FLOAT16(999.0f), FLOAT16(999.0f) + }); + + topology topology; + topology.add(input_layout("InputData", input1.get_layout())); + topology.add(input_layout("InputIndices", input2.get_layout())); + topology.add(input_layout("InputUpdates", input3.get_layout())); + 
topology.add( + scatter_nd_update("scatter_nd_update", "InputData", "InputIndices", "InputUpdates", 2) + ); + + network network(engine, topology); + + + network.set_input_data("InputData", input1); + network.set_input_data("InputIndices", input2); + network.set_input_data("InputUpdates", input3); + + auto outputs = network.execute(); + + + auto output = outputs.at("scatter_nd_update").get_memory(); + auto output_ptr = output.pointer(); + + std::vector expected_results = { + 777.f, 777.f, + 777.f, 777.f, + + 777.f, 777.f, + 777.f, 777.f,//1 + + 777.f, 777.f, + 777.f, 777.f, + + 777.f, 777.f, + 777.f, 777.f,//2 + + 777.f, 777.f, + 777.f, 777.f, + + 777.f, 777.f, + 777.f, 777.f,//3 + + 777.f, 777.f, + 777.f, 777.f, + + 777.f, 777.f, + 777.f, 777.f,//4 + + 999.f, 999.f, + 999.f, 999.f, + + 999.f, 999.f, + 999.f, 999.f,//5 + + 999.f, 999.f, + 999.f, 999.f, + + 999.f, 999.f, + 999.f, 999.f,//6 + + 999.f, 999.f, + 999.f, 999.f, + + 999.f, 999.f, + 999.f, 999.f,//7 + + 999.f, 999.f, + 999.f, 999.f, + + 999.f, 999.f, + 999.f, 999.f,//8 + }; + + for (size_t i = 0; i < expected_results.size(); ++i) { + EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i])); + } +}