[IE clDNN] Implement ScatterNDUpdate op (#4458)
This commit is contained in:
parent
d86eab4d84
commit
9c60f4f697
@ -158,6 +158,7 @@ REGISTER_FACTORY(v3, EmbeddingSegmentsSum);
|
|||||||
REGISTER_FACTORY(v3, ExtractImagePatches);
|
REGISTER_FACTORY(v3, ExtractImagePatches);
|
||||||
REGISTER_FACTORY(v3, ScatterUpdate);
|
REGISTER_FACTORY(v3, ScatterUpdate);
|
||||||
REGISTER_FACTORY(v3, ScatterElementsUpdate);
|
REGISTER_FACTORY(v3, ScatterElementsUpdate);
|
||||||
|
REGISTER_FACTORY(v3, ScatterNDUpdate);
|
||||||
// REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion
|
// REGISTER_FACTORY(v3, NonMaxSuppression); Supported via v3 -> v5 internal conversion
|
||||||
|
|
||||||
// ----------------------------- Unsupported v3 ops ----------------------------- //
|
// ----------------------------- Unsupported v3 ops ----------------------------- //
|
||||||
@ -167,7 +168,6 @@ REGISTER_FACTORY(v3, ScatterElementsUpdate);
|
|||||||
// REGISTER_FACTORY(v3, NonZero);
|
// REGISTER_FACTORY(v3, NonZero);
|
||||||
// REGISTER_FACTORY(v3, ROIAlign);
|
// REGISTER_FACTORY(v3, ROIAlign);
|
||||||
// REGISTER_FACTORY(v3, ReadValue);
|
// REGISTER_FACTORY(v3, ReadValue);
|
||||||
// REGISTER_FACTORY(v3, ScatterNDUpdate);
|
|
||||||
// REGISTER_FACTORY(v3, ShapeOf);
|
// REGISTER_FACTORY(v3, ShapeOf);
|
||||||
// REGISTER_FACTORY(v3, TopK);
|
// REGISTER_FACTORY(v3, TopK);
|
||||||
|
|
||||||
|
33
inference-engine/src/cldnn_engine/ops/scatter_nd_update.cpp
Normal file
33
inference-engine/src/cldnn_engine/ops/scatter_nd_update.cpp
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "cldnn_program.h"
|
||||||
|
#include "cldnn_common_utils.h"
|
||||||
|
|
||||||
|
#include "ngraph/op/scatter_nd_update.hpp"
|
||||||
|
#include "ngraph/op/constant.hpp"
|
||||||
|
|
||||||
|
#include "api/scatter_nd_update.hpp"
|
||||||
|
|
||||||
|
namespace CLDNNPlugin {
|
||||||
|
|
||||||
|
void CreateScatterNDUpdateOp(Program& p, const std::shared_ptr<ngraph::op::v3::ScatterNDUpdate>& op) {
|
||||||
|
p.ValidateInputs(op, {3});
|
||||||
|
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
|
||||||
|
std::string layerName = layer_type_name_ID(op);
|
||||||
|
auto indices_rank = op->get_input_shape(1).size();
|
||||||
|
|
||||||
|
auto primitive = cldnn::scatter_nd_update(layerName,
|
||||||
|
inputPrimitives[0],
|
||||||
|
inputPrimitives[1],
|
||||||
|
inputPrimitives[2],
|
||||||
|
indices_rank);
|
||||||
|
|
||||||
|
p.AddPrimitive(primitive);
|
||||||
|
p.AddPrimitiveToProfiler(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_FACTORY_IMPL(v3, ScatterNDUpdate);
|
||||||
|
|
||||||
|
} // namespace CLDNNPlugin
|
@ -0,0 +1,53 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
|
|
||||||
|
#include "single_layer_tests/scatter_ND_update.hpp"
|
||||||
|
#include "common_test_utils/test_constants.hpp"
|
||||||
|
|
||||||
|
using namespace LayerTestsDefinitions;
|
||||||
|
using namespace ngraph::opset3;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// map<inputShape map<indicesShape, indicesValue>>
|
||||||
|
// updateShape is gotten from inputShape and indicesShape
|
||||||
|
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>> sliceSelectInShape{
|
||||||
|
{{4, 3, 2, 3, 2}, {{{2, 2, 1}, {3, 2, 0, 1}}}},
|
||||||
|
{{10, 9, 9, 11}, {{{4, 1}, {1, 3, 5, 7}}, {{1, 2}, {4, 6}}, {{2, 3}, {0, 1, 1, 2, 2, 2}}, {{1, 4}, {5, 5, 4, 9}}}},
|
||||||
|
{{10, 9, 12, 10, 11}, {{{2, 2, 1}, {5, 6, 2, 8}}, {{2, 3}, {0, 4, 6, 5, 7, 1}}}},
|
||||||
|
{{15}, {{{2, 1}, {1, 3}}}},
|
||||||
|
{{15, 14}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}}},
|
||||||
|
{{15, 14, 13}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}}},
|
||||||
|
{{15, 14, 13, 12}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}},
|
||||||
|
{{2, 2, 2}, {2, 3, 1, 8, 7, 5, 6, 5}}}},
|
||||||
|
{{15, 14, 13, 12, 16}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}},
|
||||||
|
{{2, 5}, {2, 3, 1, 8, 6, 9, 7, 5, 6, 5}}}},
|
||||||
|
{{15, 14, 13, 12, 16, 10}, {{{2, 1}, {1, 3}}, {{2, 2}, {2, 3, 10, 11}}, {{2, 3}, {2, 3, 1, 8, 10, 11}}, {{2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}},
|
||||||
|
{{1, 2, 4}, {2, 3, 1, 8, 7, 5, 6, 5}}, {{2, 5}, {2, 3, 1, 8, 6, 9, 7, 5, 6, 5}}, {{2, 6}, {2, 3, 1, 8, 6, 5, 9, 7, 5, 6, 5, 7}}}}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
const std::vector<InferenceEngine::Precision> inputPrecisions = {
|
||||||
|
InferenceEngine::Precision::FP32,
|
||||||
|
InferenceEngine::Precision::FP16,
|
||||||
|
InferenceEngine::Precision::I32,
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<InferenceEngine::Precision> idxPrecisions = {
|
||||||
|
InferenceEngine::Precision::I32,
|
||||||
|
InferenceEngine::Precision::I64,
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto ScatterNDUpdateCases = ::testing::Combine(
|
||||||
|
::testing::ValuesIn(ScatterNDUpdateLayerTest::combineShapes(sliceSelectInShape)),
|
||||||
|
::testing::ValuesIn(inputPrecisions),
|
||||||
|
::testing::ValuesIn(idxPrecisions),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_GPU)
|
||||||
|
);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(smoke_ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
|
||||||
|
} // namespace
|
54
inference-engine/thirdparty/clDNN/api/scatter_nd_update.hpp
vendored
Normal file
54
inference-engine/thirdparty/clDNN/api/scatter_nd_update.hpp
vendored
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
#pragma once
|
||||||
|
#include "primitive.hpp"
|
||||||
|
|
||||||
|
namespace cldnn {
|
||||||
|
/// @addtogroup cpp_api C++ API
|
||||||
|
/// @{
|
||||||
|
/// @addtogroup cpp_topology Network Topology
|
||||||
|
/// @{
|
||||||
|
/// @addtogroup cpp_primitives Primitives
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// @brief
|
||||||
|
/// @details
|
||||||
|
struct scatter_nd_update : public primitive_base<scatter_nd_update> {
|
||||||
|
CLDNN_DECLARE_PRIMITIVE(scatter_nd_update)
|
||||||
|
|
||||||
|
/// @brief Constructs scatter_nd_update primitive.
|
||||||
|
/// @param id This primitive id.
|
||||||
|
/// @param dict Input data primitive id.
|
||||||
|
/// @param idx Input indexes primitive id.
|
||||||
|
/// @param idupd Input updates primitive id.
|
||||||
|
/// @param indices_rank Rank of indices.
|
||||||
|
scatter_nd_update(const primitive_id& id,
|
||||||
|
const primitive_id& data,
|
||||||
|
const primitive_id& idx,
|
||||||
|
const primitive_id& idupd,
|
||||||
|
const size_t indices_rank,
|
||||||
|
const padding& output_padding = padding())
|
||||||
|
: primitive_base(id, {data, idx, idupd}, output_padding), indices_rank(indices_rank) {}
|
||||||
|
|
||||||
|
/// @brief ScatterNDUpdate indices_rank
|
||||||
|
size_t indices_rank;
|
||||||
|
};
|
||||||
|
/// @}
|
||||||
|
/// @}
|
||||||
|
/// @}
|
||||||
|
} // namespace cldnn
|
@ -58,6 +58,7 @@ enum class KernelType {
|
|||||||
ONE_HOT,
|
ONE_HOT,
|
||||||
GATHER,
|
GATHER,
|
||||||
SCATTER_UPDATE,
|
SCATTER_UPDATE,
|
||||||
|
SCATTER_ND_UPDATE,
|
||||||
SCATTER_ELEMENTS_UPDATE,
|
SCATTER_ELEMENTS_UPDATE,
|
||||||
DEPTH_TO_SPACE,
|
DEPTH_TO_SPACE,
|
||||||
BATCH_TO_SPACE,
|
BATCH_TO_SPACE,
|
||||||
|
@ -0,0 +1,192 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "scatter_nd_update_kernel_ref.h"
|
||||||
|
#include "kernel_selector_utils.h"
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace kernel_selector {
|
||||||
|
|
||||||
|
ParamsKey ScatterNDUpdateKernelRef::GetSupportedKey() const {
|
||||||
|
ParamsKey k;
|
||||||
|
k.EnableInputDataType(Datatype::F16);
|
||||||
|
k.EnableInputDataType(Datatype::F32);
|
||||||
|
k.EnableInputDataType(Datatype::INT32);
|
||||||
|
k.EnableOutputDataType(Datatype::F16);
|
||||||
|
k.EnableOutputDataType(Datatype::F32);
|
||||||
|
k.EnableOutputDataType(Datatype::INT32);
|
||||||
|
k.EnableOutputDataType(Datatype::INT8);
|
||||||
|
k.EnableOutputDataType(Datatype::UINT8);
|
||||||
|
k.EnableInputLayout(DataLayout::bfyx);
|
||||||
|
k.EnableOutputLayout(DataLayout::bfyx);
|
||||||
|
k.EnableInputLayout(DataLayout::bfzyx);
|
||||||
|
k.EnableOutputLayout(DataLayout::bfzyx);
|
||||||
|
k.EnableInputLayout(DataLayout::bfwzyx);
|
||||||
|
k.EnableOutputLayout(DataLayout::bfwzyx);
|
||||||
|
k.EnableTensorOffset();
|
||||||
|
k.EnableTensorPitches();
|
||||||
|
k.EnableBatching();
|
||||||
|
k.EnableDifferentTypes();
|
||||||
|
return k;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline std::string GetOrderString(std::vector<std::string>& order) {
|
||||||
|
std::string order_str = order[0];
|
||||||
|
for (size_t i = 1; i < order.size(); i++)
|
||||||
|
order_str += ", " + order[i];
|
||||||
|
|
||||||
|
return order_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline std::vector<std::string> GetDefaultOrder(size_t size) {
|
||||||
|
std::vector<std::string> default_order;
|
||||||
|
if (size <= 4) {
|
||||||
|
default_order = {"b", "f", "y", "x"};
|
||||||
|
} else if (size == 5) {
|
||||||
|
default_order = {"b", "f", "z", "y", "x"};
|
||||||
|
} else if (size == 6) {
|
||||||
|
default_order = {"b", "f", "w", "z", "y", "x"};
|
||||||
|
}
|
||||||
|
|
||||||
|
return default_order;
|
||||||
|
}
|
||||||
|
|
||||||
|
ScatterNDUpdateKernelRef::DispatchData
|
||||||
|
ScatterNDUpdateKernelRef::SetDefault(const scatter_nd_update_params& params, const optional_params&, bool is_second) const {
|
||||||
|
DispatchData dispatchData;
|
||||||
|
|
||||||
|
if (!is_second) {
|
||||||
|
const auto& scope = params.output;
|
||||||
|
dispatchData.gws = { scope.X().v * scope.Y().v, scope.Z().v * scope.W().v, scope.Feature().v * scope.Batch().v };
|
||||||
|
} else {
|
||||||
|
auto indices_rank = params.indices_rank;
|
||||||
|
const auto& indices = params.inputs[1];
|
||||||
|
auto indices_dims = indices.LogicalDims();
|
||||||
|
|
||||||
|
if (indices_dims.size() > 1) {
|
||||||
|
std::reverse(indices_dims.begin(), indices_dims.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatchData.indicesLastDim = indices_dims[indices_rank - 1];
|
||||||
|
size_t indices_set_size = 1;
|
||||||
|
for (size_t i = 0; i < (indices_rank - 1); i++) {
|
||||||
|
indices_set_size *= indices_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatchData.gws = {1, 1, indices_set_size};
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
|
||||||
|
|
||||||
|
return dispatchData;
|
||||||
|
}
|
||||||
|
|
||||||
|
JitConstants ScatterNDUpdateKernelRef::GetJitConstants(const scatter_nd_update_params& params) const {
|
||||||
|
JitConstants jit = MakeBaseParamsJitConstants(params);
|
||||||
|
|
||||||
|
if (!params.fused_ops.empty()) {
|
||||||
|
FusedOpsConfiguration conf1 = { "_FIRST_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
|
||||||
|
FusedOpsConfiguration conf2 = { "_SECOND_KERNEL", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
|
||||||
|
jit.Merge(MakeFusedOpsJitConstants(params, {conf1, conf2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
return jit;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ScatterNDUpdateKernelRef::Validate(const Params& p, const optional_params& o) const {
|
||||||
|
if (p.GetType() != KernelType:: SCATTER_ND_UPDATE || o.GetType() != KernelType::SCATTER_ND_UPDATE) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const scatter_nd_update_params& params = static_cast<const scatter_nd_update_params&>(p);
|
||||||
|
auto input_dims = params.inputs[0].LogicalDims();
|
||||||
|
auto indices_dims = params.inputs[1].LogicalDims();
|
||||||
|
std::reverse(indices_dims.begin(), indices_dims.end());
|
||||||
|
|
||||||
|
auto indices_rank = params.indices_rank;
|
||||||
|
if (indices_rank < 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indices_dims[indices_rank - 1] > input_dims.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& fused_op : params.fused_ops) {
|
||||||
|
if (!IsFusedPrimitiveSupported(fused_op))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string GetInputBlockND(const scatter_nd_update_params& params) {
|
||||||
|
const auto& input = params.inputs[0];
|
||||||
|
auto input_dims = input.LogicalDims();
|
||||||
|
std::reverse(input_dims.begin(), input_dims.end());
|
||||||
|
while (!input_dims.empty() && input_dims.back() == 1) {
|
||||||
|
input_dims.pop_back();
|
||||||
|
}
|
||||||
|
const int rank = static_cast<int>(input_dims.size());
|
||||||
|
std::vector<size_t> block_nd(rank + 1);
|
||||||
|
block_nd[rank] = 1;
|
||||||
|
for (int idx = (rank - 1); idx >= 0; idx--) {
|
||||||
|
block_nd[idx] = input_dims[idx] * block_nd[idx + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream s;
|
||||||
|
for (int i = 0; i < (rank + 1); i++) {
|
||||||
|
if (i < rank) {
|
||||||
|
s << block_nd[i] << ",";
|
||||||
|
} else {
|
||||||
|
s << block_nd[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto str_result = s.str();
|
||||||
|
return str_result;
|
||||||
|
}
|
||||||
|
|
||||||
|
KernelsData ScatterNDUpdateKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
|
||||||
|
if (!Validate(params, options)) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
KernelData kd = KernelData::Default<scatter_nd_update_params>(params, 2);
|
||||||
|
scatter_nd_update_params& newParams = *static_cast<scatter_nd_update_params*>(kd.params.get());
|
||||||
|
auto cldnn_jit = GetJitConstants(newParams);
|
||||||
|
|
||||||
|
// First iter - copy input data to output data
|
||||||
|
// Second iter - update values specified by updates at specific index position specified by indices
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
auto dispatchData = SetDefault(newParams, options, (i == 1));
|
||||||
|
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
|
||||||
|
|
||||||
|
if (i == 1) {
|
||||||
|
cldnn_jit.AddConstant(MakeJitConstant("IS_SECOND_ITER", "true"));
|
||||||
|
cldnn_jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", dispatchData.indicesLastDim));
|
||||||
|
cldnn_jit.AddConstant(MakeJitConstant("INPUT_BLOCK_ND", GetInputBlockND(newParams)));
|
||||||
|
}
|
||||||
|
std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||||
|
|
||||||
|
clKernelData& kernel = kd.kernels[i];
|
||||||
|
|
||||||
|
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
return {kd};
|
||||||
|
}
|
||||||
|
} // namespace kernel_selector
|
@ -0,0 +1,62 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "kernel_base_opencl.h"
|
||||||
|
|
||||||
|
namespace kernel_selector {
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// scatter_nd_update_params
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
struct scatter_nd_update_params : public base_params {
|
||||||
|
scatter_nd_update_params() : base_params(KernelType::SCATTER_ND_UPDATE), indices_rank(0) {}
|
||||||
|
|
||||||
|
size_t indices_rank;
|
||||||
|
|
||||||
|
virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// scatter_nd_update_optional_params
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
struct scatter_nd_update_optional_params : optional_params {
|
||||||
|
scatter_nd_update_optional_params() : optional_params(KernelType::SCATTER_ND_UPDATE) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ScatterNDUpdateKernelRef : public KernelBaseOpenCL {
|
||||||
|
public:
|
||||||
|
struct DispatchData : public CommonDispatchData {
|
||||||
|
size_t indicesLastDim;
|
||||||
|
};
|
||||||
|
|
||||||
|
ScatterNDUpdateKernelRef() : KernelBaseOpenCL("scatter_nd_update_ref") {}
|
||||||
|
virtual ~ScatterNDUpdateKernelRef() {}
|
||||||
|
virtual JitConstants GetJitConstants(const scatter_nd_update_params& params) const;
|
||||||
|
virtual DispatchData SetDefault(const scatter_nd_update_params& params, const optional_params&, bool is_second) const;
|
||||||
|
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||||
|
ParamsKey GetSupportedKey() const override;
|
||||||
|
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||||
|
return { FusedOpType::QUANTIZE,
|
||||||
|
FusedOpType::SCALE,
|
||||||
|
FusedOpType::ACTIVATION,
|
||||||
|
FusedOpType::ELTWISE };
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool Validate(const Params& p, const optional_params& o) const override;
|
||||||
|
};
|
||||||
|
} // namespace kernel_selector
|
@ -0,0 +1,27 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "scatter_nd_update_kernel_selector.h"
|
||||||
|
#include "scatter_nd_update_kernel_ref.h"
|
||||||
|
|
||||||
|
namespace kernel_selector {
|
||||||
|
|
||||||
|
scatter_nd_update_kernel_selector::scatter_nd_update_kernel_selector() { Attach<ScatterNDUpdateKernelRef>(); }
|
||||||
|
|
||||||
|
KernelsData scatter_nd_update_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
|
||||||
|
return GetNaiveBestKernel(params, options, KernelType::SCATTER_ND_UPDATE);
|
||||||
|
}
|
||||||
|
} // namespace kernel_selector
|
@ -0,0 +1,35 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "kernel_selector.h"
|
||||||
|
|
||||||
|
namespace kernel_selector {
|
||||||
|
class scatter_nd_update_kernel_selector : public kernel_selector_base {
|
||||||
|
public:
|
||||||
|
static scatter_nd_update_kernel_selector& Instance() {
|
||||||
|
static scatter_nd_update_kernel_selector instance_;
|
||||||
|
return instance_;
|
||||||
|
}
|
||||||
|
|
||||||
|
scatter_nd_update_kernel_selector();
|
||||||
|
|
||||||
|
virtual ~scatter_nd_update_kernel_selector() {}
|
||||||
|
|
||||||
|
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
|
||||||
|
};
|
||||||
|
} // namespace kernel_selector
|
152
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl
vendored
Normal file
152
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/scatter_nd_update_ref.cl
vendored
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
// Copyright (c) 2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
#include "include/include_all.cl"
|
||||||
|
|
||||||
|
#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
|
||||||
|
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
|
||||||
|
|
||||||
|
#if OUTPUT_DIMS == 4
|
||||||
|
#define ORDER b,f,y,x
|
||||||
|
#elif OUTPUT_DIMS == 5
|
||||||
|
#define ORDER b,f,z,y,x
|
||||||
|
#elif OUTPUT_DIMS == 6
|
||||||
|
#define ORDER b,f,w,z,y,x
|
||||||
|
#endif
|
||||||
|
|
||||||
|
KERNEL(scatter_nd_update_ref)(const __global INPUT0_TYPE* data,
|
||||||
|
const __global INPUT1_TYPE* indices,
|
||||||
|
const __global INPUT2_TYPE* updates,
|
||||||
|
__global OUTPUT_TYPE* output
|
||||||
|
#if HAS_FUSED_OPS_DECLS
|
||||||
|
, FUSED_OPS_DECLS
|
||||||
|
#endif
|
||||||
|
)
|
||||||
|
{
|
||||||
|
|
||||||
|
const uint dim0 = get_global_id(0);
|
||||||
|
const uint dim1 = get_global_id(1);
|
||||||
|
const uint dim2 = get_global_id(2);
|
||||||
|
|
||||||
|
#ifndef IS_SECOND_ITER // First kernel
|
||||||
|
const uint x = dim0 % OUTPUT_SIZE_X;
|
||||||
|
const uint y = dim0 / OUTPUT_SIZE_X;
|
||||||
|
const uint z = dim1 % OUTPUT_SIZE_Z;
|
||||||
|
const uint w = dim1 / OUTPUT_SIZE_Z;
|
||||||
|
const uint f = dim2 % OUTPUT_FEATURE_NUM;
|
||||||
|
const uint b = dim2 / OUTPUT_FEATURE_NUM;
|
||||||
|
|
||||||
|
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
|
||||||
|
INPUT0_TYPE val = data[output_idx];
|
||||||
|
#if HAS_FUSED_OPS
|
||||||
|
FUSED_OPS_FIRST_KERNEL;
|
||||||
|
output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_FIRST_KERNEL);
|
||||||
|
#else
|
||||||
|
output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#else // Second kernel
|
||||||
|
|
||||||
|
const uint blockND[] = {INPUT_BLOCK_ND};
|
||||||
|
const uint k = INDICES_LAST_DIM;
|
||||||
|
const uint size_to_update = blockND[INDICES_LAST_DIM];
|
||||||
|
const uint indices_idx = dim2;
|
||||||
|
const uint indices_offset = indices_idx * k;
|
||||||
|
uint dst_offset = 0;
|
||||||
|
|
||||||
|
for (uint i = 0; i < k; i++) {
|
||||||
|
INPUT1_TYPE idxValue = indices[indices_offset + i];
|
||||||
|
dst_offset += idxValue * blockND[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
uint update_offset = indices_idx * size_to_update;
|
||||||
|
|
||||||
|
for (int i = 0; i < size_to_update; i++) {
|
||||||
|
uint dst_idx = dst_offset + i;
|
||||||
|
uint up_idx = update_offset + i;
|
||||||
|
INPUT2_TYPE val = updates[up_idx];
|
||||||
|
|
||||||
|
#if HAS_FUSED_OPS
|
||||||
|
#if OUTPUT_DIMS == 4
|
||||||
|
const uint y_pitch = OUTPUT_SIZE_X;
|
||||||
|
const uint f_pitch = y_pitch * OUTPUT_SIZE_Y;
|
||||||
|
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
|
||||||
|
|
||||||
|
const uint b_remain = dst_idx % b_pitch;
|
||||||
|
const uint f_remain = b_remain % f_pitch;
|
||||||
|
const uint y_remain = f_remain % y_pitch;
|
||||||
|
|
||||||
|
const uint b = dst_idx / b_pitch;
|
||||||
|
const uint f = b_remain / f_pitch;
|
||||||
|
const uint y = f_remain / y_pitch;
|
||||||
|
const uint x = y_remain;
|
||||||
|
#elif OUTPUT_DIMS == 5
|
||||||
|
const uint y_pitch = OUTPUT_SIZE_X;
|
||||||
|
const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
|
||||||
|
const uint f_pitch = z_pitch * OUTPUT_SIZE_Z;
|
||||||
|
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
|
||||||
|
|
||||||
|
const uint b_remain = dst_idx % b_pitch;
|
||||||
|
const uint f_remain = b_remain % f_pitch;
|
||||||
|
const uint z_remain = f_remain % z_pitch;
|
||||||
|
const uint y_remain = z_remain % y_pitch;
|
||||||
|
|
||||||
|
const uint b = dst_idx / b_pitch;
|
||||||
|
const uint f = b_remain / f_pitch;
|
||||||
|
const uint z = f_remain / z_pitch;
|
||||||
|
const uint y = z_remain / y_pitch;
|
||||||
|
const uint x = y_remain;
|
||||||
|
#elif OUTPUT_DIMS == 6
|
||||||
|
const uint y_pitch = OUTPUT_SIZE_X;
|
||||||
|
const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
|
||||||
|
const uint w_pitch = z_pitch * OUTPUT_SIZE_Z;
|
||||||
|
const uint f_pitch = w_pitch * OUTPUT_SIZE_W;
|
||||||
|
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
|
||||||
|
|
||||||
|
const uint b_remain = dst_idx % b_pitch;
|
||||||
|
const uint f_remain = b_remain % f_pitch;
|
||||||
|
const uint w_remain = f_remain % w_pitch;
|
||||||
|
const uint z_remain = w_remain % z_pitch;
|
||||||
|
const uint y_remain = z_remain % y_pitch;
|
||||||
|
|
||||||
|
const uint b = dst_idx / b_pitch;
|
||||||
|
const uint f = b_remain / f_pitch;
|
||||||
|
const uint w = f_remain / w_pitch;
|
||||||
|
const uint z = w_remain / z_pitch;
|
||||||
|
const uint y = z_remain / y_pitch;
|
||||||
|
const uint x = y_remain;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
FUSED_OPS_SECOND_KERNEL;
|
||||||
|
output[dst_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
|
||||||
|
#else
|
||||||
|
output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef GET_UPDATES_INDEX
|
||||||
|
#undef GET_UPDATES_INDEX
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GET_OUTPUT_INDEX
|
||||||
|
#undef GET_OUTPUT_INDEX
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef ORDER
|
||||||
|
#undef ORDER
|
||||||
|
#endif
|
@ -69,6 +69,7 @@ void register_implementations_gpu() {
|
|||||||
REGISTER_GPU(roi_pooling);
|
REGISTER_GPU(roi_pooling);
|
||||||
REGISTER_GPU(scale);
|
REGISTER_GPU(scale);
|
||||||
REGISTER_GPU(scatter_update);
|
REGISTER_GPU(scatter_update);
|
||||||
|
REGISTER_GPU(scatter_nd_update);
|
||||||
REGISTER_GPU(scatter_elements_update);
|
REGISTER_GPU(scatter_elements_update);
|
||||||
REGISTER_GPU(select);
|
REGISTER_GPU(select);
|
||||||
REGISTER_GPU(shuffle_channels);
|
REGISTER_GPU(shuffle_channels);
|
||||||
|
@ -62,6 +62,7 @@
|
|||||||
#include "api/scale.hpp"
|
#include "api/scale.hpp"
|
||||||
#include "api/scatter_update.hpp"
|
#include "api/scatter_update.hpp"
|
||||||
#include "api/scatter_elements_update.hpp"
|
#include "api/scatter_elements_update.hpp"
|
||||||
|
#include "api/scatter_nd_update.hpp"
|
||||||
#include "api/select.hpp"
|
#include "api/select.hpp"
|
||||||
#include "api/shuffle_channels.hpp"
|
#include "api/shuffle_channels.hpp"
|
||||||
#include "api/softmax.hpp"
|
#include "api/softmax.hpp"
|
||||||
@ -138,6 +139,7 @@ REGISTER_GPU(roi_pooling);
|
|||||||
REGISTER_GPU(scale);
|
REGISTER_GPU(scale);
|
||||||
REGISTER_GPU(scatter_update);
|
REGISTER_GPU(scatter_update);
|
||||||
REGISTER_GPU(scatter_elements_update);
|
REGISTER_GPU(scatter_elements_update);
|
||||||
|
REGISTER_GPU(scatter_nd_update);
|
||||||
REGISTER_GPU(select);
|
REGISTER_GPU(select);
|
||||||
REGISTER_GPU(shuffle_channels);
|
REGISTER_GPU(shuffle_channels);
|
||||||
REGISTER_GPU(softmax);
|
REGISTER_GPU(softmax);
|
||||||
|
78
inference-engine/thirdparty/clDNN/src/gpu/scatter_nd_update_gpu.cpp
vendored
Normal file
78
inference-engine/thirdparty/clDNN/src/gpu/scatter_nd_update_gpu.cpp
vendored
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "scatter_nd_update_inst.h"
|
||||||
|
#include "primitive_gpu_base.h"
|
||||||
|
#include "implementation_map.h"
|
||||||
|
#include "kernel_selector_helper.h"
|
||||||
|
#include "scatter_update/scatter_nd_update_kernel_selector.h"
|
||||||
|
#include "scatter_update/scatter_nd_update_kernel_ref.h"
|
||||||
|
#include "error_handler.h"
|
||||||
|
|
||||||
|
using namespace cldnn;
|
||||||
|
|
||||||
|
namespace cldnn {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
struct scatter_nd_update_gpu : typed_primitive_gpu_impl<scatter_nd_update> {
|
||||||
|
using parent = typed_primitive_gpu_impl<scatter_nd_update>;
|
||||||
|
using parent::parent;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static primitive_impl* create(const scatter_nd_update_node& arg) {
|
||||||
|
auto scatter_nd_update_params = get_default_params<kernel_selector::scatter_nd_update_params>(arg);
|
||||||
|
auto scatter_nd_update_optional_params =
|
||||||
|
get_default_optional_params<kernel_selector::scatter_nd_update_optional_params>(arg.get_program());
|
||||||
|
|
||||||
|
scatter_nd_update_params.indices_rank = arg.get_primitive()->indices_rank;
|
||||||
|
|
||||||
|
scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
|
||||||
|
scatter_nd_update_params.inputs.push_back(convert_data_tensor(arg.input(2).get_output_layout()));
|
||||||
|
|
||||||
|
auto& kernel_selector = kernel_selector::scatter_nd_update_kernel_selector::Instance();
|
||||||
|
auto best_kernels = kernel_selector.GetBestKernels(scatter_nd_update_params, scatter_nd_update_optional_params);
|
||||||
|
|
||||||
|
CLDNN_ERROR_BOOL(arg.id(),
|
||||||
|
"Best_kernel.empty()",
|
||||||
|
best_kernels.empty(),
|
||||||
|
"Cannot find a proper kernel with this arguments");
|
||||||
|
|
||||||
|
auto scatter_nd_update = new scatter_nd_update_gpu(arg, best_kernels[0]);
|
||||||
|
|
||||||
|
return scatter_nd_update;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
attach_scatter_nd_update_gpu::attach_scatter_nd_update_gpu() {
|
||||||
|
auto val_fw = scatter_nd_update_gpu::create;
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
|
||||||
|
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
|
||||||
|
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
|
||||||
|
implementation_map<scatter_nd_update>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace cldnn
|
@ -45,6 +45,7 @@
|
|||||||
#include "space_to_depth_inst.h"
|
#include "space_to_depth_inst.h"
|
||||||
#include "gather_inst.h"
|
#include "gather_inst.h"
|
||||||
#include "scatter_update_inst.h"
|
#include "scatter_update_inst.h"
|
||||||
|
#include "scatter_nd_update_inst.h"
|
||||||
#include "scatter_elements_update_inst.h"
|
#include "scatter_elements_update_inst.h"
|
||||||
#include "reverse_sequence_inst.h"
|
#include "reverse_sequence_inst.h"
|
||||||
#include "shuffle_channels_inst.h"
|
#include "shuffle_channels_inst.h"
|
||||||
@ -206,6 +207,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) {
|
|||||||
!input.is_type<softmax>() && !input.is_type<resample>() && !input.is_type<mvn>() &&
|
!input.is_type<softmax>() && !input.is_type<resample>() && !input.is_type<mvn>() &&
|
||||||
!input.is_type<depth_to_space>() && !input.is_type<batch_to_space>() &&
|
!input.is_type<depth_to_space>() && !input.is_type<batch_to_space>() &&
|
||||||
!input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
|
!input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
|
||||||
|
!input.is_type<scatter_nd_update>() &&
|
||||||
!input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
|
!input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
|
||||||
!input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
|
!input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
|
||||||
!input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
|
!input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
|
||||||
@ -540,6 +542,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
|||||||
|
|
||||||
should_fuse |= input_data.is_type<scatter_update>();
|
should_fuse |= input_data.is_type<scatter_update>();
|
||||||
|
|
||||||
|
should_fuse |= input_data.is_type<scatter_nd_update>();
|
||||||
|
|
||||||
should_fuse |= input_data.is_type<scatter_elements_update>();
|
should_fuse |= input_data.is_type<scatter_elements_update>();
|
||||||
|
|
||||||
should_fuse |= input_data.is_type<depth_to_space>();
|
should_fuse |= input_data.is_type<depth_to_space>();
|
||||||
@ -604,6 +608,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
|||||||
|
|
||||||
should_fuse |= input_data.is_type<scatter_update>();
|
should_fuse |= input_data.is_type<scatter_update>();
|
||||||
|
|
||||||
|
should_fuse |= input_data.is_type<scatter_nd_update>();
|
||||||
|
|
||||||
should_fuse |= input_data.is_type<scatter_elements_update>();
|
should_fuse |= input_data.is_type<scatter_elements_update>();
|
||||||
|
|
||||||
should_fuse |= input_data.is_type<depth_to_space>();
|
should_fuse |= input_data.is_type<depth_to_space>();
|
||||||
@ -690,6 +696,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
|||||||
|
|
||||||
should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();
|
should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();
|
||||||
|
|
||||||
|
should_fuse |= input_data.is_type<scatter_nd_update>() && quantize_node.get_scale_shift_opt();
|
||||||
|
|
||||||
should_fuse |= input_data.is_type<scatter_elements_update>() && quantize_node.get_scale_shift_opt();
|
should_fuse |= input_data.is_type<scatter_elements_update>() && quantize_node.get_scale_shift_opt();
|
||||||
|
|
||||||
should_fuse |= input_data.is_type<permute>() && quantize_node.get_scale_shift_opt();
|
should_fuse |= input_data.is_type<permute>() && quantize_node.get_scale_shift_opt();
|
||||||
@ -745,6 +753,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
|||||||
(parents[i]->is_type<space_to_batch>()) ||
|
(parents[i]->is_type<space_to_batch>()) ||
|
||||||
(parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) ||
|
(parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) ||
|
||||||
(parents[i]->is_type<scale>()) ||
|
(parents[i]->is_type<scale>()) ||
|
||||||
|
(parents[i]->is_type<scatter_nd_update>()) ||
|
||||||
(parents[i]->is_type<scatter_elements_update>()) ||
|
(parents[i]->is_type<scatter_elements_update>()) ||
|
||||||
(parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) ||
|
(parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) ||
|
||||||
(parents[i]->is_type<depth_to_space>() && dts_supports_fusings(parents[i]->as<depth_to_space>())) ||
|
(parents[i]->is_type<depth_to_space>() && dts_supports_fusings(parents[i]->as<depth_to_space>())) ||
|
||||||
|
49
inference-engine/thirdparty/clDNN/src/include/scatter_nd_update_inst.h
vendored
Normal file
49
inference-engine/thirdparty/clDNN/src/include/scatter_nd_update_inst.h
vendored
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
#pragma once
|
||||||
|
#include "api/scatter_nd_update.hpp"
|
||||||
|
#include "primitive_inst.h"
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace cldnn {
|
||||||
|
template <>
|
||||||
|
struct typed_program_node<scatter_nd_update> : public typed_program_node_base<scatter_nd_update> {
|
||||||
|
using parent = typed_program_node_base<scatter_nd_update>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using parent::parent;
|
||||||
|
|
||||||
|
program_node& input(size_t index = 0) const { return get_dependency(index); }
|
||||||
|
};
|
||||||
|
|
||||||
|
using scatter_nd_update_node = typed_program_node<scatter_nd_update>;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class typed_primitive_inst<scatter_nd_update> : public typed_primitive_inst_base<scatter_nd_update> {
|
||||||
|
using parent = typed_primitive_inst_base<scatter_nd_update>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static layout calc_output_layout(scatter_nd_update_node const& node);
|
||||||
|
static std::string to_string(scatter_nd_update_node const& node);
|
||||||
|
|
||||||
|
public:
|
||||||
|
typed_primitive_inst(network_impl& network, scatter_nd_update_node const& desc);
|
||||||
|
};
|
||||||
|
|
||||||
|
using scatter_nd_update_inst = typed_primitive_inst<scatter_nd_update>;
|
||||||
|
} // namespace cldnn
|
66
inference-engine/thirdparty/clDNN/src/scatter_nd_update.cpp
vendored
Normal file
66
inference-engine/thirdparty/clDNN/src/scatter_nd_update.cpp
vendored
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/*
|
||||||
|
// Copyright (c) 2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "scatter_nd_update_inst.h"
|
||||||
|
|
||||||
|
#include "primitive_type_base.h"
|
||||||
|
#include "error_handler.h"
|
||||||
|
#include "json_object.h"
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace cldnn {
|
||||||
|
primitive_type_id scatter_nd_update::type_id() {
|
||||||
|
static primitive_type_base<scatter_nd_update> instance;
|
||||||
|
return &instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
layout scatter_nd_update_inst::calc_output_layout(scatter_nd_update_node const& node) {
|
||||||
|
auto input_layout = node.input(0).get_output_layout();
|
||||||
|
|
||||||
|
auto output_shape = input_layout.size;
|
||||||
|
auto input_format = input_layout.format;
|
||||||
|
auto output_type = input_layout.data_type;
|
||||||
|
|
||||||
|
if (node.has_fused_primitives()) {
|
||||||
|
output_type = node.get_fused_output_layout().data_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
return layout{output_type, input_format, output_shape};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string scatter_nd_update_inst::to_string(scatter_nd_update_node const& node) {
|
||||||
|
auto desc = node.get_primitive();
|
||||||
|
auto node_info = node.desc_to_json();
|
||||||
|
auto& input = node.input();
|
||||||
|
|
||||||
|
std::stringstream primitive_description;
|
||||||
|
|
||||||
|
json_composite scatter_nd_update_info;
|
||||||
|
scatter_nd_update_info.add("input id", input.id());
|
||||||
|
scatter_nd_update_info.add("input shape", node.input(0).get_output_layout().size.to_string());
|
||||||
|
scatter_nd_update_info.add("indices shape", node.input(1).get_output_layout().size.to_string());
|
||||||
|
scatter_nd_update_info.add("updates shape", node.input(2).get_output_layout().size.to_string());
|
||||||
|
|
||||||
|
node_info->add("scatter_nd_update info", scatter_nd_update_info);
|
||||||
|
node_info->dump(primitive_description);
|
||||||
|
|
||||||
|
return primitive_description.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
scatter_nd_update_inst::typed_primitive_inst(network_impl& network, scatter_nd_update_node const& node) : parent(network, node) {}
|
||||||
|
|
||||||
|
} // namespace cldnn
|
@ -35,6 +35,7 @@
|
|||||||
#include "api/permute.hpp"
|
#include "api/permute.hpp"
|
||||||
#include "api/gather.hpp"
|
#include "api/gather.hpp"
|
||||||
#include "api/scatter_update.hpp"
|
#include "api/scatter_update.hpp"
|
||||||
|
#include "api/scatter_nd_update.hpp"
|
||||||
#include "api/scatter_elements_update.hpp"
|
#include "api/scatter_elements_update.hpp"
|
||||||
#include "api/depth_to_space.hpp"
|
#include "api/depth_to_space.hpp"
|
||||||
#include "api/space_to_depth.hpp"
|
#include "api/space_to_depth.hpp"
|
||||||
@ -7569,3 +7570,224 @@ INSTANTIATE_TEST_CASE_P(DISABLED_fusings_gpu,
|
|||||||
reduce_test_params{CASE_REDUCE_I8_3, 2, 4, reduce_mode::mean, {reduce::along_x}, true, "reduce_ref"},
|
reduce_test_params{CASE_REDUCE_I8_3, 2, 4, reduce_mode::mean, {reduce::along_x}, true, "reduce_ref"},
|
||||||
reduce_test_params{CASE_REDUCE_U8_3, 2, 4, reduce_mode::l2, {reduce::along_x}, true, "reduce_ref"}
|
reduce_test_params{CASE_REDUCE_U8_3, 2, 4, reduce_mode::l2, {reduce::along_x}, true, "reduce_ref"}
|
||||||
}), );
|
}), );
|
||||||
|
|
||||||
|
|
||||||
|
/* ----------------------------------------------------------------------------------------------------- */
|
||||||
|
/* ------------------------------------------ ScatterNDUpdate cases ------------------------------ */
|
||||||
|
/* ----------------------------------------------------------------------------------------------------- */
|
||||||
|
struct scatter_nd_update_test_params {
|
||||||
|
tensor input_shape;
|
||||||
|
tensor indices_shape;
|
||||||
|
tensor updates_shape;
|
||||||
|
int max_number_in_indices;
|
||||||
|
int indices_rank;
|
||||||
|
data_types data_type;
|
||||||
|
format input_format;
|
||||||
|
data_types default_type;
|
||||||
|
format default_format;
|
||||||
|
size_t expected_fused_primitives;
|
||||||
|
size_t expected_not_fused_primitives;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_4D_1 {6, 1, 1, 1}, {3, 1, 1, 1}, {3, 1, 1, 1}, 6, 1, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_4D_2 {6, 6, 1, 1}, {3, 2, 1, 1}, {3, 1, 1, 1}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_4D_3 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_4D_4 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_4D_5 {6, 7, 8, 9}, {6, 2, 1, 1}, {6, 9, 1, 8}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_4D_6 {6, 7, 8, 9}, {6, 3, 1, 1}, {6, 8, 1, 1}, 6, 2, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
|
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_1 {6, 7, 8, 9, 10}, {5, 1, 1, 1, 1}, {5, 7, 8, 9, 10}, 6, 1, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_2 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 1}, {5, 10, 1, 8, 9}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_3 {6, 7, 8, 9, 10}, {5, 3, 1, 1, 1}, {5, 9, 1, 1, 8}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_4 {6, 7, 8, 9, 10}, {5, 4, 1, 1, 1}, {5, 8, 1, 1, 1}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_5 {6, 7, 8, 9, 10}, {5, 5, 1, 1, 1}, {5, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_6 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 2}, {5, 2, 8, 9, 10}, 6, 3, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_7 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 3}, {5, 2, 1, 8, 9}, 6, 3, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_8 {6, 7, 8, 9, 10}, {5, 2, 1, 4, 3}, {5, 2, 1, 8, 3}, 6, 4, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_5D_9 {6, 7, 8, 9, 10}, {5, 2, 1, 3, 3}, {5, 2, 8, 9, 3}, 6, 4, data_types::f16, format::bfzyx, data_types::f16, format::bfyx
|
||||||
|
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_6D_1 {6, 7, 8, 9, 10, 11}, {5, 1, 1, 1}, {5, 7, 8, 9, 10, 11}, 6, 1, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_6D_2 {6, 7, 8, 9, 10, 11}, {5, 2, 1, 1}, {5, 11, 1, 8, 9, 10}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_6D_3 {6, 7, 8, 9, 10, 11}, {5, 3, 1, 1}, {5, 10, 1, 1, 8, 9}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_6D_4 {6, 7, 8, 9, 10, 11}, {5, 4, 1, 1}, {5, 9, 1, 1, 1, 8}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_6D_5 {6, 7, 8, 9, 2, 2}, {5, 5, 1, 1}, {5, 8, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP16_6D_6 {6, 7, 8, 9, 2, 2}, {5, 6, 1, 1}, {5, 1, 1, 1, 1, 1}, 6, 2, data_types::f16, format::bfwzyx, data_types::f16, format::bfyx
|
||||||
|
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_4D_1 {6, 1, 1, 1}, {3, 1, 1, 1}, {3, 1, 1, 1}, 6, 1, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_4D_2 {6, 6, 1, 1}, {3, 2, 1, 1}, {3, 1, 1, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_4D_3 {6, 7, 8, 1}, {5, 1, 1, 1}, {5, 7, 8, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_4D_4 {6, 7, 8, 9}, {5, 1, 1, 1}, {5, 7, 8, 9}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_4D_5 {6, 7, 8, 9}, {6, 2, 1, 1}, {6, 9, 1, 8}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_4D_6 {6, 7, 8, 9}, {6, 3, 1, 1}, {6, 8, 1, 1}, 6, 2, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
|
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_5D_1 {6, 7, 8, 9, 10}, {5, 1, 1, 1, 1}, {5, 7, 8, 9, 10}, 6, 1, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_5D_2 {6, 7, 8, 9, 10}, {5, 2, 1, 1, 1}, {5, 10, 1, 8, 9}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_5D_3 {6, 7, 8, 9, 10}, {5, 3, 1, 1, 1}, {5, 9, 1, 1, 8}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_5D_4 {6, 7, 8, 9, 10}, {5, 4, 1, 1, 1}, {5, 8, 1, 1, 1}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_5D_5 {6, 7, 8, 9, 10}, {5, 5, 1, 1, 1}, {5, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfzyx, data_types::f32, format::bfyx
|
||||||
|
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_6D_1 {6, 7, 8, 9, 10, 11}, {5, 1, 1, 1}, {5, 7, 8, 9, 10, 11}, 6, 1, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_6D_2 {6, 7, 8, 9, 10, 11}, {5, 2, 1, 1}, {5, 11, 1, 8, 9, 10}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_6D_3 {6, 7, 8, 9, 10, 11}, {5, 3, 1, 1}, {5, 10, 1, 1, 8, 9}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_6D_4 {6, 7, 8, 9, 10, 11}, {5, 4, 1, 1}, {5, 9, 1, 1, 1, 8}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_6D_5 {6, 7, 8, 9, 2, 2}, {5, 5, 1, 1}, {5, 8, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
|
||||||
|
#define CASE_SCATTER_ND_UPDATE_FP32_6D_6 {6, 7, 8, 9, 2, 2}, {5, 6, 1, 1}, {5, 1, 1, 1, 1, 1}, 6, 2, data_types::f32, format::bfwzyx, data_types::f32, format::bfyx
|
||||||
|
|
||||||
|
class ScatterNDUpdatePrimitiveFusingTest : public ::BaseFusingTest<scatter_nd_update_test_params> {
|
||||||
|
public:
|
||||||
|
void execute(scatter_nd_update_test_params& p) {
|
||||||
|
auto input_prim = get_mem(get_input_layout(p));
|
||||||
|
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
|
||||||
|
network network_fused(this->engine, this->topology_fused, bo_fused);
|
||||||
|
network_fused.set_input_data("input", input_prim);
|
||||||
|
network_not_fused.set_input_data("input", input_prim);
|
||||||
|
compare(network_not_fused, network_fused, p);
|
||||||
|
}
|
||||||
|
|
||||||
|
layout get_input_layout(scatter_nd_update_test_params& p) {
|
||||||
|
return layout{ p.data_type, p.input_format, p.input_shape };
|
||||||
|
}
|
||||||
|
|
||||||
|
layout get_indices_layout(scatter_nd_update_test_params& p) {
|
||||||
|
return layout{ p.data_type, p.input_format, p.indices_shape };
|
||||||
|
}
|
||||||
|
|
||||||
|
layout get_updates_layout(scatter_nd_update_test_params& p) {
|
||||||
|
return layout{ p.data_type, p.input_format, p.updates_shape };
|
||||||
|
}
|
||||||
|
|
||||||
|
layout get_per_channel_layout(scatter_nd_update_test_params& p) {
|
||||||
|
return layout{ p.default_type, p.default_format, tensor{1, p.input_shape.feature[0], 1, 1} };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class scatter_nd_update_quantize : public ScatterNDUpdatePrimitiveFusingTest {};
|
||||||
|
TEST_P(scatter_nd_update_quantize, basic) {
|
||||||
|
auto p = GetParam();
|
||||||
|
create_topologies(input_layout("input", get_input_layout(p)),
|
||||||
|
data("scatter_nd_update_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices)),
|
||||||
|
data("scatter_nd_update_updates", get_mem(get_updates_layout(p), 0, 100)),
|
||||||
|
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||||
|
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||||
|
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
||||||
|
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
||||||
|
scatter_nd_update("scatter_nd_update_prim", "input", "scatter_nd_update_indices", "scatter_nd_update_updates", p.indices_rank),
|
||||||
|
quantize("quantize", "scatter_nd_update_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
|
||||||
|
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
|
||||||
|
);
|
||||||
|
tolerance = 1.f;
|
||||||
|
execute(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_quantize,
|
||||||
|
::testing::ValuesIn(std::vector<scatter_nd_update_test_params>{
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_1, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_2, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_3, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_4, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_5, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_6, 2, 3 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_1, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_2, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_3, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_4, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 3 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_3, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_4, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_5, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_6, 2, 3 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_1, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_2, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_3, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_4, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_5, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_6, 2, 3 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_1, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_2, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_3, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_4, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_5, 2, 3 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_1, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_2, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_3, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 3 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 3 },
|
||||||
|
}), );
|
||||||
|
|
||||||
|
class scatter_nd_update_scale_activation_eltwise : public ScatterNDUpdatePrimitiveFusingTest {};
|
||||||
|
TEST_P(scatter_nd_update_scale_activation_eltwise, basic) {
|
||||||
|
auto p = GetParam();
|
||||||
|
create_topologies(input_layout("input", get_input_layout(p)),
|
||||||
|
data("scatter_nd_update_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices)),
|
||||||
|
data("scatter_nd_update_updates", get_mem(get_updates_layout(p), 0, 100)),
|
||||||
|
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
|
||||||
|
data("eltwise_data", get_mem(layout{ p.data_type, p.input_format, p.input_shape })),
|
||||||
|
scatter_nd_update("scatter_nd_update_prim", "input", "scatter_nd_update_indices", "scatter_nd_update_updates", p.indices_rank),
|
||||||
|
activation("activation", "scatter_nd_update_prim", activation_func::abs),
|
||||||
|
scale("scale", "activation", "scale_data"),
|
||||||
|
eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),
|
||||||
|
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
|
||||||
|
);
|
||||||
|
|
||||||
|
tolerance = 1.f;
|
||||||
|
execute(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise,
|
||||||
|
::testing::ValuesIn(std::vector<scatter_nd_update_test_params>{
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_1, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_2, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_3, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_4, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_5, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_4D_6, 2, 5 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_1, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_2, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_3, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_4, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 5 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_3, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_4, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_5, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_6, 2, 5 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_1, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_2, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_3, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_4, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_5, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_4D_6, 2, 5 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_1, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_2, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_3, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_4, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_5D_5, 2, 5 },
|
||||||
|
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_1, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_2, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_3, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_4, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 },
|
||||||
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 },
|
||||||
|
}), );
|
||||||
|
3927
inference-engine/thirdparty/clDNN/tests/test_cases/scatter_nd_update_gpu_test.cpp
vendored
Normal file
3927
inference-engine/thirdparty/clDNN/tests/test_cases/scatter_nd_update_gpu_test.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user