[IE clDNN] Implement ScatterNDUpdate op (#4458)

This commit is contained in:
Kelvin Choi 2021-03-10 14:08:20 +09:00 committed by GitHub
parent d86eab4d84
commit 9c60f4f697
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 4964 additions and 1 deletions

View File

@ -158,6 +158,7 @@ REGISTER_FACTORY(v3, EmbeddingSegmentsSum);
REGISTER_FACTORY(v3, ExtractImagePatches); REGISTER_FACTORY(v3, ExtractImagePatches);
REGISTER_FACTORY(v3, ScatterUpdate); REGISTER_FACTORY(v3, ScatterUpdate);
REGISTER_FACTORY(v3, ScatterElementsUpdate); REGISTER_FACTORY(v3, ScatterElementsUpdate);
REGISTER_FACTORY(v3, ScatterNDUpdate);
// REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion // REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion
// ----------------------------- Unsupported v3 ops ----------------------------- // // ----------------------------- Unsupported v3 ops ----------------------------- //
@ -167,7 +168,6 @@ REGISTER_FACTORY(v3, ScatterElementsUpdate);
// REGISTER_FACTORY(v3, NonZero); // REGISTER_FACTORY(v3, NonZero);
// REGISTER_FACTORY(v3, ROIAlign); // REGISTER_FACTORY(v3, ROIAlign);
// REGISTER_FACTORY(v3, ReadValue); // REGISTER_FACTORY(v3, ReadValue);
// REGISTER_FACTORY(v3, ScatterNDUpdate);
// REGISTER_FACTORY(v3, ShapeOf); // REGISTER_FACTORY(v3, ShapeOf);
// REGISTER_FACTORY(v3, TopK); // REGISTER_FACTORY(v3, TopK);

View File

@ -0,0 +1,33 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "cldnn_program.h"
#include "cldnn_common_utils.h"
#include "ngraph/op/scatter_nd_update.hpp"
#include "ngraph/op/constant.hpp"
#include "api/scatter_nd_update.hpp"
namespace CLDNNPlugin {
void CreateScatterNDUpdateOp(Program& p, const std::shared_ptr<ngraph::op::v3::ScatterNDUpdate>& op) {
p.ValidateInputs(op, {3});
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);
auto indices_rank = op->get_input_shape(1).size();
auto primitive = cldnn::scatter_nd_update(layerName,
inputPrimitives[0],
inputPrimitives[1],
inputPrimitives[2],
indices_rank);
p.AddPrimitive(primitive);
p.AddPrimitiveToProfiler(op);
}
REGISTER_FACTORY_IMPL(v3, ScatterNDUpdate);
} // namespace CLDNNPlugin

View File

@ -0,0 +1,53 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <ngraph/opsets/opset3.hpp>
#include "single_layer_tests/scatter_ND_update.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;
namespace {
// map<inputShape map<indicesShape, indicesValue>>
// updateShape is gotten from inputShape and indicesShape
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>> sliceSelectInShape{
{{4, 3, 2, 3, 2}, {{{2, 2, 1}, {3, 2, 0, 1}}}},
{{10, 9, 9, 11}, {{{4, 1}, {1, 3, 5, 7}}, {{1, 2}, {4, 6}}, {{2, 3}, {0, 1, 1, 2, 2, 2}}, {{1, 4}, {5, 5, 4, 9}}}},
{{10, 9, 12, 10, 11}, {{{2, 2, 1}, {5, 6, 2, 8}}, {{2, 3}, {0, 4, 6, 5, 7, 1}}}},
{{15}, {{{2, 1}, {1, 3}}}},
{{15, 14}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}}},
{{15, 14, 13}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}}},
{{15, 14, 13, 12}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}},
{{2, 2, 2}, {2, 3, 1, 8, 7, 5, 6, 5}}}},
{{15, 14, 13, 12, 16}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}},
{{2, 5}, {2, 3, 1, 8, 6, 9, 7, 5, 6, 5}}}},
{{15, 14, 13, 12, 16, 10}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}},
{{1, 2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}}, {{2, 5}, {2, 3, 1, 8, 6, 9, 7, 5, 6, 5}}, {{2, 6}, {2, 3, 1, 8, 6, 5, 9, 7, 5, 6, 5, 7}}}}
};
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};
const auto ScatterNDUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterNDUpdateLayerTest::combineShapes(sliceSelectInShape)),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(smoke_ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,54 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "primitive.hpp"
namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{
/// @brief
/// @details
struct scatter_nd_update : public primitive_base<scatter_nd_update> {
CLDNN_DECLARE_PRIMITIVE(scatter_nd_update)
/// @brief Constructs scatter_nd_update primitive.
/// @param id This primitive id.
/// @param dict Input data primitive id.
/// @param idx Input indexes primitive id.
/// @param idupd Input updates primitive id.
/// @param indices_rank Rank of indices.
scatter_nd_update(const primitive_id& id,
const primitive_id& data,
const primitive_id& idx,
const primitive_id& idupd,
const size_t indices_rank,
const padding& output_padding = padding())
: primitive_base(id, {data, idx, idupd}, output_padding), indices_rank(indices_rank) {}
/// @brief ScatterNDUpdate indices_rank
size_t indices_rank;
};
/// @}
/// @}
/// @}
} // namespace cldnn

View File

@ -58,6 +58,7 @@ enum class KernelType {
ONE_HOT, ONE_HOT,
GATHER, GATHER,
SCATTER_UPDATE, SCATTER_UPDATE,
SCATTER_ND_UPDATE,
SCATTER_ELEMENTS_UPDATE, SCATTER_ELEMENTS_UPDATE,
DEPTH_TO_SPACE, DEPTH_TO_SPACE,
BATCH_TO_SPACE, BATCH_TO_SPACE,

View File

@ -0,0 +1,192 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#include "scatter_nd_update_kernel_ref.h"
#include "kernel_selector_utils.h"
#include <string>
#include <vector>
namespace kernel_selector {
ParamsKey ScatterNDUpdateKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();
return k;
}
static inline std::string GetOrderString(std::vector<std::string>& order) {
std::string order_str = order[0];
for (size_t i = 1; i < order.size(); i++)
order_str += ", " + order[i];
return order_str;
}
static inline std::vector<std::string> GetDefaultOrder(size_t size) {
std::vector<std::string> default_order;
if (size <= 4) {
default_order = {"b", "f", "y", "x"};
} else if (size == 5) {
default_order = {"b", "f", "z", "y", "x"};
} else if (size == 6) {
default_order = {"b", "f", "w", "z", "y", "x"};
}
return default_order;
}
ScatterNDUpdateKernelRef::DispatchData
ScatterNDUpdateKernelRef::SetDefault(const scatter_nd_update_params& params, const optional_params&, bool is_second) const {
DispatchData dispatchData;
if (!is_second) {
const auto& scope = params.output;
dispatchData.gws = { scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v };
} else {
auto indices_rank = params.indices_rank;
const auto& indices = params.inputs[1];
auto indices_dims = indices.LogicalDims();
if (indices_dims.size() > 1) {
std::reverse(indices_dims.begin(), indices_dims.end());
}
dispatchData.indicesLastDim = indices_dims[indices_rank - 1];
size_t indices_set_size = 1;
for (size_t i = 0; i < (indices_rank - 1); i++) {
indices_set_size *= indices_dims[i];
}
dispatchData.gws = {1, 1, indices_set_size};
}
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
return dispatchData;
}
JitConstants ScatterNDUpdateKernelRef::GetJitConstants(const scatter_nd_update_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);
if (!params.fused_ops.empty()) {
FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
FusedOpsConfiguration conf2 = { "_SECOND_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
jit.Merge(MakeFusedOpsJitConstants(params, {conf1, conf2}));
}
return jit;
}
bool ScatterNDUpdateKernelRef::Validate(const Params& p, const optional_params& o) const {
if (p.GetType() != KernelType:: SCATTER_ND_UPDATE || o.GetType() != KernelType::SCATTER_ND_UPDATE) {
return false;
}
const scatter_nd_update_params& params = static_cast<const scatter_nd_update_params&>(p);
auto input_dims = params.inputs[0].LogicalDims();
auto indices_dims = params.inputs[1].LogicalDims();
std::reverse(indices_dims.begin(), indices_dims.end());
auto indices_rank = params.indices_rank;
if (indices_rank < 1) {
return false;
}
if (indices_dims[indices_rank - 1] > input_dims.size()) {
return false;
}
for (auto& fused_op : params.fused_ops) {
if (!IsFusedPrimitiveSupported(fused_op))
return false;
}
return true;
}
static std::string GetInputBlockND(const scatter_nd_update_params& params) {
const auto& input = params.inputs[0];
auto input_dims = input.LogicalDims();
std::reverse(input_dims.begin(), input_dims.end());
while (!input_dims.empty() && input_dims.back() == 1) {
input_dims.pop_back();
}
const int rank = static_cast<int>(input_dims.size());
std::vector<size_t> block_nd(rank + 1);
block_nd[rank] = 1;
for (int idx = (rank - 1); idx >= 0; idx--) {
block_nd[idx] = input_dims[idx] * block_nd[idx + 1];
}
std::stringstream s;
for (int i = 0; i < (rank + 1); i++) {
if (i < rank) {
s << block_nd[i] << ",";
} else {
s << block_nd[i];
}
}
auto str_result = s.str();
return str_result;
}
KernelsData ScatterNDUpdateKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
if (!Validate(params, options)) {
return {};
}
KernelData kd = KernelData::Default<scatter_nd_update_params>(params, 2);
scatter_nd_update_params& newParams = *static_cast<scatter_nd_update_params*>(kd.params.get());
auto cldnn_jit = GetJitConstants(newParams);
// First iter - copy input data to output data
// Second iter - update values specified by updates at specific index position specified by indices
for (int i = 0; i < 2; i++) {
auto dispatchData = SetDefault(newParams, options, (i == 1));
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
if (i == 1) {
cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true"));
cldnn_jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", dispatchData.indicesLastDim));
cldnn_jit.AddConstant(MakeJitConstant("INPUT_BLOCK_ND", GetInputBlockND(newParams)));
}
std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
clKernelData& kernel = kd.kernels[i];
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params));
}
return {kd};
}
} // namespace kernel_selector

View File

@ -0,0 +1,62 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#pragma once
#include "kernel_base_opencl.h"
namespace kernel_selector {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// scatter_nd_update_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct scatter_nd_update_params : public base_params {
scatter_nd_update_params() : base_params(KernelType::SCATTER_ND_UPDATE), indices_rank(0) {}
size_t indices_rank;
virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// scatter_nd_update_optional_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct scatter_nd_update_optional_params : optional_params {
scatter_nd_update_optional_params() : optional_params(KernelType::SCATTER_ND_UPDATE) {}
};
class ScatterNDUpdateKernelRef : public KernelBaseOpenCL {
public:
struct DispatchData : public CommonDispatchData {
size_t indicesLastDim;
};
ScatterNDUpdateKernelRef() : KernelBaseOpenCL("scatter_nd_update_ref") {}
virtual ~ScatterNDUpdateKernelRef() {}
virtual JitConstants GetJitConstants(const scatter_nd_update_params& params) const;
virtual DispatchData SetDefault(const scatter_nd_update_params& params, const optional_params&, bool is_second) const;
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::QUANTIZE,
FusedOpType::SCALE,
FusedOpType::ACTIVATION,
FusedOpType::ELTWISE };
}
protected:
bool Validate(const Params& p, const optional_params& o) const override;
};
} // namespace kernel_selector

View File

@ -0,0 +1,27 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#include "scatter_nd_update_kernel_selector.h"
#include "scatter_nd_update_kernel_ref.h"
namespace kernel_selector {
scatter_nd_update_kernel_selector::scatter_nd_update_kernel_selector() { Attach<ScatterNDUpdateKernelRef>(); }
KernelsData scatter_nd_update_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
return GetNaiveBestKernel(params, options, KernelType::SCATTER_ND_UPDATE);
}
} // namespace kernel_selector

View File

@ -0,0 +1,35 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#pragma once
#include "kernel_selector.h"
namespace kernel_selector {
class scatter_nd_update_kernel_selector : public kernel_selector_base {
public:
static scatter_nd_update_kernel_selector& Instance() {
static scatter_nd_update_kernel_selector instance_;
return instance_;
}
scatter_nd_update_kernel_selector();
virtual ~scatter_nd_update_kernel_selector() {}
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
};
} // namespace kernel_selector

View File

@ -0,0 +1,152 @@
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/include_all.cl"
#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
#if OUTPUT_DIMS == 4
#define ORDER b,f,y,x
#elif OUTPUT_DIMS == 5
#define ORDER b,f,z,y,x
#elif OUTPUT_DIMS == 6
#define ORDER b,f,w,z,y,x
#endif
KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
const __global INPUT1_TYPE* indices,
const __global INPUT2_TYPE* updates,
__global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
, FUSED_OPS_DECLS
#endif
)
{
const uint dim0 = get_global_id(0);
const uint dim1 = get_global_id(1);
const uint dim2 = get_global_id(2);
#ifndef IS_SECOND_ITER // First kernel
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
const uint z = dim1 % OUTPUT_SIZE_Z;
const uint w = dim1 / OUTPUT_SIZE_Z;
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
INPUT0_TYPE val = data[output_idx];
#if HAS_FUSED_OPS
FUSED_OPS_FIRST_KERNEL;
output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_FIRST_KERNEL);
#else
output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
#else // Second kernel
const uint blockND[] = {INPUT_BLOCK_ND};
const uint k = INDICES_LAST_DIM;
const uint size_to_update = blockND[INDICES_LAST_DIM];
const uint indices_idx = dim2;
const uint indices_offset = indices_idx * k;
uint dst_offset = 0;
for (uint i = 0; i < k; i++) {
INPUT1_TYPE idxValue = indices[indices_offset + i];
dst_offset += idxValue * blockND[i + 1];
}
uint update_offset = indices_idx * size_to_update;
for (int i = 0; i < size_to_update; i++) {
uint dst_idx = dst_offset + i;
uint up_idx = update_offset + i;
INPUT2_TYPE val = updates[up_idx];
#if HAS_FUSED_OPS
#if OUTPUT_DIMS == 4
const uint y_pitch = OUTPUT_SIZE_X;
const uint f_pitch = y_pitch * OUTPUT_SIZE_Y;
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
const uint b_remain = dst_idx % b_pitch;
const uint f_remain = b_remain % f_pitch;
const uint y_remain = f_remain % y_pitch;
const uint b = dst_idx / b_pitch;
const uint f = b_remain / f_pitch;
const uint y = f_remain / y_pitch;
const uint x = y_remain;
#elif OUTPUT_DIMS == 5
const uint y_pitch = OUTPUT_SIZE_X;
const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
const uint f_pitch = z_pitch * OUTPUT_SIZE_Z;
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
const uint b_remain = dst_idx % b_pitch;
const uint f_remain = b_remain % f_pitch;
const uint z_remain = f_remain % z_pitch;
const uint y_remain = z_remain % y_pitch;
const uint b = dst_idx / b_pitch;
const uint f = b_remain / f_pitch;
const uint z = f_remain / z_pitch;
const uint y = z_remain / y_pitch;
const uint x = y_remain;
#elif OUTPUT_DIMS == 6
const uint y_pitch = OUTPUT_SIZE_X;
const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
const uint w_pitch = z_pitch * OUTPUT_SIZE_Z;
const uint f_pitch = w_pitch * OUTPUT_SIZE_W;
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
const uint b_remain = dst_idx % b_pitch;
const uint f_remain = b_remain % f_pitch;
const uint w_remain = f_remain % w_pitch;
const uint z_remain = w_remain % z_pitch;
const uint y_remain = z_remain % y_pitch;
const uint b = dst_idx / b_pitch;
const uint f = b_remain / f_pitch;
const uint w = f_remain / w_pitch;
const uint z = w_remain / z_pitch;
const uint y = z_remain / y_pitch;
const uint x = y_remain;
#endif
FUSED_OPS_SECOND_KERNEL;
output[dst_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
#else
output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
}
#endif
}
#ifdef GET_UPDATES_INDEX
#undef GET_UPDATES_INDEX
#endif
#ifdef GET_OUTPUT_INDEX
#undef GET_OUTPUT_INDEX
#endif
#ifdef ORDER
#undef ORDER
#endif

View File

@ -69,6 +69,7 @@ void register_implementations_gpu() {
REGISTER_GPU(roi_pooling); REGISTER_GPU(roi_pooling);
REGISTER_GPU(scale); REGISTER_GPU(scale);
REGISTER_GPU(scatter_update); REGISTER_GPU(scatter_update);
REGISTER_GPU(scatter_nd_update);
REGISTER_GPU(scatter_elements_update); REGISTER_GPU(scatter_elements_update);
REGISTER_GPU(select); REGISTER_GPU(select);
REGISTER_GPU(shuffle_channels); REGISTER_GPU(shuffle_channels);

View File

@ -62,6 +62,7 @@
#include "api/scale.hpp" #include "api/scale.hpp"
#include "api/scatter_update.hpp" #include "api/scatter_update.hpp"
#include "api/scatter_elements_update.hpp" #include "api/scatter_elements_update.hpp"
#include "api/scatter_nd_update.hpp"
#include "api/select.hpp" #include "api/select.hpp"
#include "api/shuffle_channels.hpp" #include "api/shuffle_channels.hpp"
#include "api/softmax.hpp" #include "api/softmax.hpp"
@ -138,6 +139,7 @@ REGISTER_GPU(roi_pooling);
REGISTER_GPU(scale); REGISTER_GPU(scale);
REGISTER_GPU(scatter_update); REGISTER_GPU(scatter_update);
REGISTER_GPU(scatter_elements_update); REGISTER_GPU(scatter_elements_update);
REGISTER_GPU(scatter_nd_update);
REGISTER_GPU(select); REGISTER_GPU(select);
REGISTER_GPU(shuffle_channels); REGISTER_GPU(shuffle_channels);
REGISTER_GPU(softmax); REGISTER_GPU(softmax);

View File

@ -0,0 +1,78 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#include "scatter_nd_update_inst.h"
#include "primitive_gpu_base.h"
#include "implementation_map.h"
#include "kernel_selector_helper.h"
#include "scatter_update/scatter_nd_update_kernel_selector.h"
#include "scatter_update/scatter_nd_update_kernel_ref.h"
#include "error_handler.h"
using namespace cldnn;
namespace cldnn {
namespace gpu {
struct scatter_nd_update_gpu : typed_primitive_gpu_impl<scatter_nd_update> {
using parent = typed_primitive_gpu_impl<scatter_nd_update>;
using parent::parent;
public:
static primitive_impl* create(const scatter_nd_update_node& arg) {
auto scatter_nd_update_params = get_default_params<kernel_selector::scatter_nd_update_params>(arg);
auto scatter_nd_update_optional_params =
get_default_optional_params<kernel_selector::scatter_nd_update_optional_params>(arg.get_program());
scatter_nd_update_params.indices_rank = arg.get_primitive()->indices_rank;
scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout()));
auto& kernel_selector = kernel_selector::scatter_nd_update_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(scatter_nd_update_params, scatter_nd_update_optional_params);
CLDNN_ERROR_BOOL(arg.id(),
"Best_kernel.empty()",
best_kernels.empty(),
"Cannot find a proper kernel with this arguments");
auto scatter_nd_update = new scatter_nd_update_gpu(arg, best_kernels[0]);
return scatter_nd_update;
}
};
namespace detail {
attach_scatter_nd_update_gpu::attach_scatter_nd_update_gpu() {
auto val_fw = scatter_nd_update_gpu::create;
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
}
} // namespace detail
} // namespace gpu
} // namespace cldnn

View File

@ -45,6 +45,7 @@
#include "space_to_depth_inst.h" #include "space_to_depth_inst.h"
#include "gather_inst.h" #include "gather_inst.h"
#include "scatter_update_inst.h" #include "scatter_update_inst.h"
#include "scatter_nd_update_inst.h"
#include "scatter_elements_update_inst.h" #include "scatter_elements_update_inst.h"
#include "reverse_sequence_inst.h" #include "reverse_sequence_inst.h"
#include "shuffle_channels_inst.h" #include "shuffle_channels_inst.h"
@ -206,6 +207,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) {
!input.is_type<softmax>() && !input.is_type<resample>() && !input.is_type<mvn>() && !input.is_type<softmax>() && !input.is_type<resample>() && !input.is_type<mvn>() &&
!input.is_type<depth_to_space>() && !input.is_type<batch_to_space>() && !input.is_type<depth_to_space>() && !input.is_type<batch_to_space>() &&
!input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() && !input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
!input.is_type<scatter_nd_update>() &&
!input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() && !input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
!input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() && !input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
!input.is_type<fused_conv_eltwise>() && !input.is_type<activation>())) !input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
@ -540,6 +542,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
should_fuse |= input_data.is_type<scatter_update>(); should_fuse |= input_data.is_type<scatter_update>();
should_fuse |= input_data.is_type<scatter_nd_update>();
should_fuse |= input_data.is_type<scatter_elements_update>(); should_fuse |= input_data.is_type<scatter_elements_update>();
should_fuse |= input_data.is_type<depth_to_space>(); should_fuse |= input_data.is_type<depth_to_space>();
@ -604,6 +608,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
should_fuse |= input_data.is_type<scatter_update>(); should_fuse |= input_data.is_type<scatter_update>();
should_fuse |= input_data.is_type<scatter_nd_update>();
should_fuse |= input_data.is_type<scatter_elements_update>(); should_fuse |= input_data.is_type<scatter_elements_update>();
should_fuse |= input_data.is_type<depth_to_space>(); should_fuse |= input_data.is_type<depth_to_space>();
@ -690,6 +696,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt(); should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();
should_fuse |= input_data.is_type<scatter_nd_update>() && quantize_node.get_scale_shift_opt();
should_fuse |= input_data.is_type<scatter_elements_update>() && quantize_node.get_scale_shift_opt(); should_fuse |= input_data.is_type<scatter_elements_update>() && quantize_node.get_scale_shift_opt();
should_fuse |= input_data.is_type<permute>() && quantize_node.get_scale_shift_opt(); should_fuse |= input_data.is_type<permute>() && quantize_node.get_scale_shift_opt();
@ -745,6 +753,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
(parents[i]->is_type<space_to_batch>()) || (parents[i]->is_type<space_to_batch>()) ||
(parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) || (parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) ||
(parents[i]->is_type<scale>()) || (parents[i]->is_type<scale>()) ||
(parents[i]->is_type<scatter_nd_update>()) ||
(parents[i]->is_type<scatter_elements_update>()) || (parents[i]->is_type<scatter_elements_update>()) ||
(parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) || (parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) ||
(parents[i]->is_type<depth_to_space>() && dts_supports_fusings(parents[i]->as<depth_to_space>())) || (parents[i]->is_type<depth_to_space>() && dts_supports_fusings(parents[i]->as<depth_to_space>())) ||

View File

@ -0,0 +1,49 @@
/*
// Copyright (c) 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "api/scatter_nd_update.hpp"
#include "primitive_inst.h"
#include <string>
namespace cldnn {
template <>
struct typed_program_node<scatter_nd_update> : public typed_program_node_base<scatter_nd_update> {
using parent = typed_program_node_base<scatter_nd_update>;
public:
using parent::parent;
program_node& input(size_t index = 0) const { return get_dependency(index); }
};
using scatter_nd_update_node = typed_program_node<scatter_nd_update>;
template <>
class typed_primitive_inst<scatter_nd_update> : public typed_primitive_inst_base<scatter_nd_update> {
using parent = typed_primitive_inst_base<scatter_nd_update>;
public:
static layout calc_output_layout(scatter_nd_update_node const& node);
static std::string to_string(scatter_nd_update_node const& node);
public:
typed_primitive_inst(network_impl& network, scatter_nd_update_node const& desc);
};
using scatter_nd_update_inst = typed_primitive_inst<scatter_nd_update>;
} // namespace cldnn

View File

@ -0,0 +1,66 @@
/*
// Copyright (c) 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#include "scatter_nd_update_inst.h"
#include "primitive_type_base.h"
#include "error_handler.h"
#include "json_object.h"
#include <string>
namespace cldnn {
primitive_type_id scatter_nd_update::type_id() {
static primitive_type_base<scatter_nd_update> instance;
return &instance;
}
layout scatter_nd_update_inst::calc_output_layout(scatter_nd_update_node const& node) {
auto input_layout = node.input(0).get_output_layout();
auto output_shape = input_layout.size;
auto input_format = input_layout.format;
auto output_type = input_layout.data_type;
if (node.has_fused_primitives()) {
output_type = node.get_fused_output_layout().data_type;
}
return layout{output_type, input_format, output_shape};
}
std::string scatter_nd_update_inst::to_string(scatter_nd_update_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();
auto& input = node.input();
std::stringstream primitive_description;
json_composite scatter_nd_update_info;
scatter_nd_update_info.add("input id", input.id());
scatter_nd_update_info.add("input shape", node.input(0).get_output_layout().size.to_string());
scatter_nd_update_info.add("indices shape", node.input(1).get_output_layout().size.to_string());
scatter_nd_update_info.add("updates shape", node.input(2).get_output_layout().size.to_string());
node_info->add("scatter_nd_update info", scatter_nd_update_info);
node_info->dump(primitive_description);
return primitive_description.str();
}
scatter_nd_update_inst::typed_primitive_inst(network_impl& network, scatter_nd_update_node const& node) : parent(network, node) {}
} // namespace cldnn

View File

@ -35,6 +35,7 @@
#include "api/permute.hpp" #include "api/permute.hpp"
#include "api/gather.hpp" #include "api/gather.hpp"
#include "api/scatter_update.hpp" #include "api/scatter_update.hpp"
#include "api/scatter_nd_update.hpp"
#include "api/scatter_elements_update.hpp" #include "api/scatter_elements_update.hpp"
#include "api/depth_to_space.hpp" #include "api/depth_to_space.hpp"
#include "api/space_to_depth.hpp" #include "api/space_to_depth.hpp"
@ -7569,3 +7570,224 @@ INSTANTIATE_TEST_CASE_P(DISABLED_fusings_gpu,
reduce_test_params{CASE_REDUCE_I8_3, 2, 4, reduce_mode::mean, {reduce::along_x}, true, "reduce_ref"}, reduce_test_params{CASE_REDUCE_I8_3, 2, 4, reduce_mode::mean, {reduce::along_x}, true, "reduce_ref"},
reduce_test_params{CASE_REDUCE_U8_3, 2, 4, reduce_mode::l2, {reduce::along_x}, true, "reduce_ref"} reduce_test_params{CASE_REDUCE_U8_3, 2, 4, reduce_mode::l2, {reduce::along_x}, true, "reduce_ref"}
}), ); }), );
/* ----------------------------------------------------------------------------------------------------- */
/* ------------------------------------------ ScatterNDUpdate cases ------------------------------ */
/* ----------------------------------------------------------------------------------------------------- */
struct scatter_nd_update_test_params {
tensor input_shape;
tensor indices_shape;
tensor updates_shape;
int max_number_in_indices;
int indices_rank;
data_types data_type;
format input_format;
data_types default_type;
format default_format;
size_t expected_fused_primitives;
size_t expected_not_fused_primitives;
};
#define CASE_SCATTER_ND_UPDATE_FP16_4D_1 {6, 1, 1, 1}, {3, 1, 1, 1}, {3, 1, 1, 1}, 6, 1, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_4D_2 {6, 6, 1, 1}, {3, 2, 1, 1}, {3, 1, 1, 1}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_4D_3 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_4D_4 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_4D_5 {6, 7, 8, 9}, {6, 2, 1, 1}, {6, 9, 1, 8}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_4D_6 {6, 7, 8, 9}, {6, 3, 1, 1}, {6, 8, 1, 1}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_1 {6, 7, 8, 9, 10}, {5, 1, 1, 1, 1}, {5, 7, 8, 9, 10}, 6, 1, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_2 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 1}, {5, 10, 1, 8, 9}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_3 {6, 7, 8, 9, 10}, {5, 3, 1, 1, 1}, {5, 9, 1, 1, 8}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_4 {6, 7, 8, 9, 10}, {5, 4, 1, 1, 1}, {5, 8, 1, 1, 1}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_5 {6, 7, 8, 9, 10}, {5, 5, 1, 1, 1}, {5, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_6 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 2}, {5, 2, 8, 9, 10}, 6, 3, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_7 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 3}, {5, 2, 1, 8, 9}, 6, 3, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_8 {6, 7, 8, 9, 10}, {5, 2, 1, 4, 3}, {5, 2, 1, 8, 3}, 6, 4, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_5D_9 {6, 7, 8, 9, 10}, {5, 2, 1, 3, 3}, {5, 2, 8, 9, 3}, 6, 4, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_6D_1 {6, 7, 8, 9, 10, 11}, {5, 1, 1, 1}, {5, 7, 8, 9, 10, 11}, 6, 1, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_6D_2 {6, 7, 8, 9, 10, 11}, {5, 2, 1, 1}, {5, 11, 1, 8, 9, 10}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_6D_3 {6, 7, 8, 9, 10, 11}, {5, 3, 1, 1}, {5, 10, 1, 1, 8, 9}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_6D_4 {6, 7, 8, 9, 10, 11}, {5, 4, 1, 1}, {5, 9, 1, 1, 1, 8}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_6D_5 {6, 7, 8, 9, 2, 2}, {5, 5, 1, 1}, {5, 8, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP16_6D_6 {6, 7, 8, 9, 2, 2}, {5, 6, 1, 1}, {5, 1, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_4D_1 {6, 1, 1, 1}, {3, 1, 1, 1}, {3, 1, 1, 1}, 6, 1, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_4D_2 {6, 6, 1, 1}, {3, 2, 1, 1}, {3, 1, 1, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_4D_3 {6, 7, 8, 1}, {5, 1, 1, 1}, {5, 7, 8, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_4D_4 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_4D_5 {6, 7, 8, 9}, {6, 2, 1, 1}, {6, 9, 1, 8}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_4D_6 {6, 7, 8, 9}, {6, 3, 1, 1}, {6, 8, 1, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_5D_1 {6, 7, 8, 9, 10}, {5, 1, 1, 1, 1}, {5, 7, 8, 9, 10}, 6, 1, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_5D_2 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 1}, {5, 10, 1, 8, 9}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_5D_3 {6, 7, 8, 9, 10}, {5, 3, 1, 1, 1}, {5, 9, 1, 1, 8}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_5D_4 {6, 7, 8, 9, 10}, {5, 4, 1, 1, 1}, {5, 8, 1, 1, 1}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_5D_5 {6, 7, 8, 9, 10}, {5, 5, 1, 1, 1}, {5, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_1 {6, 7, 8, 9, 10, 11}, {5, 1, 1, 1}, {5, 7, 8, 9, 10, 11}, 6, 1, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_2 {6, 7, 8, 9, 10, 11}, {5, 2, 1, 1}, {5, 11, 1, 8, 9, 10}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_3 {6, 7, 8, 9, 10, 11}, {5, 3, 1, 1}, {5, 10, 1, 1, 8, 9}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_4 {6, 7, 8, 9, 10, 11}, {5, 4, 1, 1}, {5, 9, 1, 1, 1, 8}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_5 {6, 7, 8, 9, 2, 2}, {5, 5, 1, 1}, {5, 8, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
#define CASE_SCATTER_ND_UPDATE_FP32_6D_6 {6, 7, 8, 9, 2, 2}, {5, 6, 1, 1}, {5, 1, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
class ScatterNDUpdatePrimitiveFusingTest : public ::BaseFusingTest<scatter_nd_update_test_params> {
public:
void execute(scatter_nd_update_test_params& p) {
auto input_prim = get_mem(get_input_layout(p));
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
network network_fused(this->engine, this->topology_fused, bo_fused);
network_fused.set_input_data("input", input_prim);
network_not_fused.set_input_data("input", input_prim);
compare(network_not_fused, network_fused, p);
}
layout get_input_layout(scatter_nd_update_test_params& p) {
return layout{ p.data_type, p.input_format, p.input_shape };
}
layout get_indices_layout(scatter_nd_update_test_params& p) {
return layout{ p.data_type, p.input_format, p.indices_shape };
}
layout get_updates_layout(scatter_nd_update_test_params& p) {
return layout{ p.data_type, p.input_format, p.updates_shape };
}
layout get_per_channel_layout(scatter_nd_update_test_params& p) {
return layout{ p.default_type, p.default_format, tensor{1, p.input_shape.feature[0], 1, 1} };
}
};
class scatter_nd_update_quantize : public ScatterNDUpdatePrimitiveFusingTest {};
TEST_P(scatter_nd_update_quantize, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("scatter_nd_update_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices)),
data("scatter_nd_update_updates", get_mem(get_updates_layout(p), 0, 100)),
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), -127)),
data("out_hi", get_mem(get_single_element_layout(p), 127)),
scatter_nd_update("scatter_nd_update_prim", "input", "scatter_nd_update_indices", "scatter_nd_update_updates", p.indices_rank),
quantize("quantize", "scatter_nd_update_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
tolerance = 1.f;
execute(p);
}
INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_quantize,
::testing::ValuesIn(std::vector<scatter_nd_update_test_params>{
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_6, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_1, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_2, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_3, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 3 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 3 },
}), );
class scatter_nd_update_scale_activation_eltwise : public ScatterNDUpdatePrimitiveFusingTest {};
TEST_P(scatter_nd_update_scale_activation_eltwise, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("scatter_nd_update_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices)),
data("scatter_nd_update_updates", get_mem(get_updates_layout(p), 0, 100)),
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.input_shape })),
scatter_nd_update("scatter_nd_update_prim", "input", "scatter_nd_update_indices", "scatter_nd_update_updates", p.indices_rank),
activation("activation", "scatter_nd_update_prim", activation_func::abs),
scale("scale", "activation", "scale_data"),
eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
);
tolerance = 1.f;
execute(p);
}
INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise,
::testing::ValuesIn(std::vector<scatter_nd_update_test_params>{
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_6, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_1, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_2, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_3, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 },
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 },
}), );

File diff suppressed because it is too large Load Diff