[IE CLDNN] Add new operation GatherND-5 to IE clDNN plugin (#4857)
parent 9a46851a33
commit 9f93509b1c
@@ -193,10 +193,10 @@ REGISTER_FACTORY(v5, LogSoftmax);
 REGISTER_FACTORY(v5, LSTMSequence);
 //REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion
 REGISTER_FACTORY(v5, Round);
+REGISTER_FACTORY(v5, GatherND);

 // ----------------------------- Unsupported v5 ops ----------------------------- //
 // REGISTER_FACTORY(v5, BatchNormInference);
-// REGISTER_FACTORY(v5, GatherND);
 // REGISTER_FACTORY(v5, GRUSequence);
 // REGISTER_FACTORY(v5, Loop);
 // REGISTER_FACTORY(v5, RNNSequence);
inference-engine/src/cldnn_engine/ops/gather_nd.cpp (new file)
@@ -0,0 +1,36 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "cldnn_program.h"
#include "cldnn_common_utils.h"

#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/constant.hpp"

#include "api/gather_nd.hpp"

namespace CLDNNPlugin {

void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v5::GatherND>& op) {
    p.ValidateInputs(op, {2});
    auto inputPrimitives = p.GetInputPrimitiveIDs(op);
    std::string layerName = layer_type_name_ID(op);

    int32_t indices_rank = static_cast<int32_t>(op->get_input_shape(1).size());

    auto batch_dims = op->get_batch_dims();

    auto primitive = cldnn::gather_nd(layerName,
                                      inputPrimitives[0],
                                      inputPrimitives[1],
                                      indices_rank,
                                      batch_dims);

    p.AddPrimitive(primitive);
    p.AddPrimitiveToProfiler(op);
}

REGISTER_FACTORY_IMPL(v5, GatherND);

}  // namespace CLDNNPlugin
@@ -0,0 +1,81 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <ngraph/opsets/opset5.hpp>

#include "single_layer_tests/gather_nd.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::opset5;

namespace {

const std::vector<InferenceEngine::Precision> inputPrecisions = {
        InferenceEngine::Precision::FP32,
        InferenceEngine::Precision::FP16,
        InferenceEngine::Precision::I32,
};

const std::vector<InferenceEngine::Precision> idxPrecisions = {
        InferenceEngine::Precision::I32,
        InferenceEngine::Precision::I64,
};

// set1
const auto gatherNDArgsSubset1 = ::testing::Combine(
    ::testing::ValuesIn(std::vector<std::vector<size_t>>(
        { {2, 2}, {2, 3, 4} })),                                // Data shape
    ::testing::ValuesIn(std::vector<std::vector<size_t>>(
        { {2, 1}, {2, 1, 1} })),                                // Indices shape
    ::testing::ValuesIn(std::vector<int>({ 0, 1 }))             // Batch dims
);

INSTANTIATE_TEST_CASE_P(smoke_GatherND_set1, GatherNDLayerTest,
    ::testing::Combine(
        gatherNDArgsSubset1,
        ::testing::ValuesIn(inputPrecisions),
        ::testing::ValuesIn(idxPrecisions),
        ::testing::Values(CommonTestUtils::DEVICE_GPU),
        ::testing::Values<Config>({})),
    GatherNDLayerTest::getTestCaseName);

// set2
const auto gatherNDArgsSubset2 = ::testing::Combine(
    ::testing::ValuesIn(std::vector<std::vector<size_t>>(
        { {15, 12, 20, 15, 2}, {15, 12, 18, 7, 17} })),         // Data shape
    ::testing::ValuesIn(std::vector<std::vector<size_t>>(
        { {15, 12, 2}, {15, 12, 5, 9, 1, 3} })),                // Indices shape
    ::testing::ValuesIn(std::vector<int>({ 1, 2 }))             // Batch dims
);

INSTANTIATE_TEST_CASE_P(smoke_GatherND_set2, GatherNDLayerTest,
    ::testing::Combine(
        gatherNDArgsSubset2,
        ::testing::ValuesIn(inputPrecisions),
        ::testing::ValuesIn(idxPrecisions),
        ::testing::Values(CommonTestUtils::DEVICE_GPU),
        ::testing::Values<Config>({})),
    GatherNDLayerTest::getTestCaseName);

// set3
const auto gatherNDArgsSubset3 = ::testing::Combine(
    ::testing::ValuesIn(std::vector<std::vector<size_t>>(
        { {4, 3, 2, 5, 5, 2}, {4, 3, 2, 5, 7, 2} })),           // Data shape
    ::testing::ValuesIn(std::vector<std::vector<size_t>>(
        { {4, 3, 2, 5, 1}, {4, 3, 2, 5, 6, 2} })),              // Indices shape
    ::testing::ValuesIn(std::vector<int>({ 3, 4 }))             // Batch dims
);

INSTANTIATE_TEST_CASE_P(smoke_GatherND_set3, GatherNDLayerTest,
    ::testing::Combine(
        gatherNDArgsSubset3,
        ::testing::ValuesIn(inputPrecisions),
        ::testing::ValuesIn(idxPrecisions),
        ::testing::Values(CommonTestUtils::DEVICE_GPU),
        ::testing::Values<Config>({})),
    GatherNDLayerTest::getTestCaseName);

}  // namespace
inference-engine/thirdparty/clDNN/api/gather_nd.hpp (new file)
@@ -0,0 +1,57 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "primitive.hpp"

namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{

/// @brief
/// @details
struct gather_nd : public primitive_base<gather_nd> {
    CLDNN_DECLARE_PRIMITIVE(gather_nd)

    /// @brief Constructs gather_nd primitive.
    /// @param id This primitive id.
    /// @param data Input data primitive id.
    /// @param indices Input indexes primitive id.
    /// @param indices_rank Rank of indices.
    /// @param batch_dims batch_dims as an attribute of GatherND. Optional.
    gather_nd(const primitive_id& id,
              const primitive_id& data,
              const primitive_id& indices,
              const uint8_t indices_rank,
              const uint8_t batch_dims = 0,
              const padding& output_padding = padding())
        : primitive_base(id, {data, indices}, output_padding), indices_rank(indices_rank), batch_dims(batch_dims) {}

    /// @brief GatherND indices_rank
    uint8_t indices_rank;

    /// @brief GatherND batch_dims
    uint8_t batch_dims;
};
/// @}
/// @}
/// @}
}  // namespace cldnn
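For orientation, a minimal sketch of how this new primitive is wired into a clDNN topology, in the same style as the fusion tests added later in this change; the primitive ids, layouts and attribute values below are illustrative only and are not part of this commit.

// Hypothetical usage sketch: ids, layouts and attribute values are examples.
#include "api/topology.hpp"
#include "api/input_layout.hpp"
#include "api/gather_nd.hpp"

void add_gather_nd(cldnn::topology& topology,
                   const cldnn::layout& data_layout,
                   const cldnn::layout& indices_layout) {
    // Two inputs: the data tensor and the indices tensor.
    topology.add(cldnn::input_layout("data", data_layout));
    topology.add(cldnn::input_layout("indices", indices_layout));
    // indices_rank and batch_dims mirror the attributes of the ngraph GatherND-5 op being lowered.
    topology.add(cldnn::gather_nd("gather_nd", "data", "indices",
                                  /*indices_rank=*/3, /*batch_dims=*/1));
}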
@@ -47,6 +47,7 @@ enum class KernelType {
     CONTRACT,
     ONE_HOT,
     GATHER,
+    GATHER_ND,
     SCATTER_UPDATE,
     SCATTER_ND_UPDATE,
     SCATTER_ELEMENTS_UPDATE,
@@ -0,0 +1,210 @@
/*
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0
*/

#include "gather_nd_kernel_ref.h"
#include "kernel_selector_utils.h"
#include <string>
#include <vector>

namespace kernel_selector {

ParamsKey GatherNDKernelRef::GetSupportedKey() const {
    ParamsKey k;
    k.EnableInputDataType(Datatype::F16);
    k.EnableInputDataType(Datatype::F32);
    k.EnableInputDataType(Datatype::INT32);
    k.EnableOutputDataType(Datatype::F16);
    k.EnableOutputDataType(Datatype::F32);
    k.EnableOutputDataType(Datatype::INT32);
    k.EnableOutputDataType(Datatype::INT8);
    k.EnableOutputDataType(Datatype::UINT8);
    k.EnableInputLayout(DataLayout::bfyx);
    k.EnableOutputLayout(DataLayout::bfyx);
    k.EnableInputLayout(DataLayout::bfzyx);
    k.EnableOutputLayout(DataLayout::bfzyx);
    k.EnableInputLayout(DataLayout::bfwzyx);
    k.EnableOutputLayout(DataLayout::bfwzyx);
    k.EnableTensorOffset();
    k.EnableTensorPitches();
    k.EnableBatching();
    k.EnableDifferentTypes();
    return k;
}

static inline std::string GetOrderString(std::vector<std::string>& order) {
    std::string order_str = order[0];
    for (size_t i = 1; i < order.size(); i++)
        order_str += ", " + order[i];

    return order_str;
}

static inline std::vector<std::string> GetDefaultOrder(size_t size) {
    std::vector<std::string> default_order;
    if (size <= 4) {
        default_order = { "b", "f", "y", "x" };
    } else if (size == 5) {
        default_order = { "b", "f", "z", "y", "x" };
    } else if (size == 6) {
        default_order = { "b", "f", "w", "z", "y", "x" };
    }

    return default_order;
}

CommonDispatchData GatherNDKernelRef::SetDefault(const gather_nd_params& params, const optional_params&) const {
    CommonDispatchData dispatchData;

    auto indices_dims = params.inputs[1].LogicalDims();

    if (indices_dims.size() > 1) {
        std::reverse(indices_dims.begin(), indices_dims.end());
    }

    indices_dims[params.indices_rank - 1] = 1; // set last dim of indices to 1

    switch (params.inputs[1].GetLayout()) {
    case DataLayout::bfyx:
        dispatchData.gws = { indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] };
        break;

    case DataLayout::bfzyx:
        dispatchData.gws = { indices_dims[4] * indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] };
        break;

    case DataLayout::bfwzyx:
        dispatchData.gws = { indices_dims[5] * indices_dims[4], indices_dims[3] * indices_dims[2], indices_dims[1] * indices_dims[0] };
        break;

    default:
        throw std::invalid_argument("Unsupported data layout for gather_nd primitive");
        break;
    }

    dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);

    return dispatchData;
}

static size_t GetIndicesLastDim(const gather_nd_params& params) {
    // get indices dims
    auto indices_dims = params.inputs[1].LogicalDims();

    if (indices_dims.size() > 1) {
        std::reverse(indices_dims.begin(), indices_dims.end());
    }

    auto indices_last_dim = indices_dims[params.indices_rank - 1];

    return indices_last_dim;
}

static size_t GetSliceSize(const gather_nd_params& params) {
    // get input dims
    auto input_dims = params.inputs[0].LogicalDims();

    if (input_dims.size() > 1) {
        std::reverse(input_dims.begin(), input_dims.end());
    }

    // get last dim of indices
    auto indices_last_dim = GetIndicesLastDim(params);

    // calculate slice size which is used in kernel to copy
    size_t wi_slice_size = 1;
    for (size_t i = params.batch_dims + indices_last_dim; i < input_dims.size(); i++) {
        wi_slice_size *= input_dims[i];
    }

    return wi_slice_size;
}

JitConstants GatherNDKernelRef::GetJitConstants(const gather_nd_params& params) const {
    JitConstants jit = MakeBaseParamsJitConstants(params);

    jit.AddConstant(MakeJitConstant("INDICES_RANK", params.indices_rank));
    jit.AddConstant(MakeJitConstant("BATCH_DIMS", params.batch_dims));
    jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE", GetSliceSize(params)));
    jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", GetIndicesLastDim(params)));

    if (!params.fused_ops.empty()) {
        FusedOpsConfiguration conf = { "", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
        jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
    }

    return jit;
}

bool GatherNDKernelRef::Validate(const Params& p, const optional_params& o) const {
    if (p.GetType() != KernelType::GATHER_ND || o.GetType() != KernelType::GATHER_ND) {
        return false;
    }

    const gather_nd_params& params = static_cast<const gather_nd_params&>(p);
    auto input_dims = params.inputs[0].LogicalDims();
    auto indices_dims = params.inputs[1].LogicalDims();
    auto indices_rank = params.indices_rank;
    auto batch_dims = params.batch_dims;

    std::reverse(input_dims.begin(), input_dims.end());
    std::reverse(indices_dims.begin(), indices_dims.end());

    if (indices_rank < 1) {
        return false;
    }

    if (batch_dims + indices_dims[indices_rank - 1] > input_dims.size()) {
        return false;
    }

    if (batch_dims >= std::min(input_dims.size(), static_cast<size_t>(indices_rank))) {
        return false;
    }

    for (uint8_t i = 0; i < batch_dims; i++) {
        if (input_dims[i] != indices_dims[i]) {
            return false;
        }
    }

    for (auto& fused_op : params.fused_ops) {
        if (!IsFusedPrimitiveSupported(fused_op))
            return false;
    }

    return true;
}

KernelsData GatherNDKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
    if (!Validate(params, options)) {
        return {};
    }

    KernelData kd = KernelData::Default<gather_nd_params>(params);
    gather_nd_params& newParams = *static_cast<gather_nd_params*>(kd.params.get());

    auto dispatchData = SetDefault(newParams, options);
    auto cldnn_jit = GetJitConstants(newParams);

    auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
    auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
    auto& kernel = kd.kernels[0];
    FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2, GetFusedPrimitiveInputsCount(params));

    return { kd };
}

}  // namespace kernel_selector
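The WI_SLICE_SIZE and INDICES_LAST_DIM constants computed above decide how much contiguous data each work item copies: everything after the batch dims and the coordinates stored in the last indices dimension. A standalone sketch of that arithmetic follows; the shapes in it are illustrative only, not taken from this change.

// Standalone sketch of GetSliceSize()'s arithmetic (example shapes are hypothetical).
#include <cassert>
#include <cstddef>
#include <vector>

static size_t slice_size(const std::vector<size_t>& data_shape,
                         size_t indices_last_dim, size_t batch_dims) {
    // Product of the data dimensions that are not consumed by batch_dims
    // or by the coordinates in the last indices dimension.
    size_t size = 1;
    for (size_t i = batch_dims + indices_last_dim; i < data_shape.size(); ++i)
        size *= data_shape[i];
    return size;
}

int main() {
    // data {2, 3, 4, 5}, indices last dim = 1, batch_dims = 0:
    // each work item copies a 3 * 4 * 5 = 60 element slice.
    assert(slice_size({2, 3, 4, 5}, 1, 0) == 60);
    // data {2, 3, 4, 5}, indices last dim = 2, batch_dims = 1:
    // only the trailing dimension of size 5 remains, so the slice is 5 elements.
    assert(slice_size({2, 3, 4, 5}, 2, 1) == 5);
    return 0;
}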
@@ -0,0 +1,60 @@
/*
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0
*/

#pragma once

#include "kernel_base_opencl.h"

namespace kernel_selector {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// gather_nd_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct gather_nd_params : public base_params {
    gather_nd_params() : base_params(KernelType::GATHER_ND), indices_rank(0), batch_dims(0) {}

    uint8_t indices_rank;

    uint8_t batch_dims;

    virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// gather_nd_optional_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct gather_nd_optional_params : optional_params {
    gather_nd_optional_params() : optional_params(KernelType::GATHER_ND) {}
};

class GatherNDKernelRef : public KernelBaseOpenCL {
public:
    GatherNDKernelRef() : KernelBaseOpenCL("gather_nd_ref") {}
    virtual ~GatherNDKernelRef() {}
    virtual JitConstants GetJitConstants(const gather_nd_params& params) const;
    virtual CommonDispatchData SetDefault(const gather_nd_params& params, const optional_params&) const;
    KernelsData GetKernelsData(const Params& params, const optional_params& options) const;
    ParamsKey GetSupportedKey() const override;
    std::vector<FusedOpType> GetSupportedFusedOps() const override {
        return { FusedOpType::QUANTIZE,
                 FusedOpType::SCALE,
                 FusedOpType::ACTIVATION,
                 FusedOpType::ELTWISE };
    }

protected:
    bool Validate(const Params& p, const optional_params& o) const override;
};
}  // namespace kernel_selector
@@ -0,0 +1,27 @@
/*
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0
*/

#include "gather_nd_kernel_selector.h"
#include "gather_nd_kernel_ref.h"

namespace kernel_selector {

gather_nd_kernel_selector::gather_nd_kernel_selector() { Attach<GatherNDKernelRef>(); }

KernelsData gather_nd_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
    return GetNaiveBestKernel(params, options, KernelType::GATHER_ND);
}
}  // namespace kernel_selector
@@ -0,0 +1,35 @@
/*
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0
*/

#pragma once

#include "kernel_selector.h"

namespace kernel_selector {
class gather_nd_kernel_selector : public kernel_selector_base {
public:
    static gather_nd_kernel_selector& Instance() {
        static gather_nd_kernel_selector instance_;
        return instance_;
    }

    gather_nd_kernel_selector();

    virtual ~gather_nd_kernel_selector() {}

    KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
};
}  // namespace kernel_selector
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_nd_ref.cl (new file)
@@ -0,0 +1,231 @@
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0

#include "include/fetch.cl"

#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
#define GET_OUTPUT_INDEX(out_order) OUTPUT_GET_INDEX(out_order)

#if INPUT0_DIMS == 4
    #define IN_ORDER in_b,in_f,in_y,in_x
#elif INPUT0_DIMS == 5
    #define IN_ORDER in_b,in_f,in_z,in_y,in_x
#else
    #define IN_ORDER in_b,in_f,in_w,in_z,in_y,in_x
#endif

#if INPUT1_DIMS == 4
    #define IDX_ORDER idx_b,idx_f,idx_y,idx_x
#elif INPUT1_DIMS == 5
    #define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x
#else
    #define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x
#endif

#if OUTPUT_DIMS == 4
    #define OUT_ORDER out_b,out_f,out_y,out_x
#elif OUTPUT_DIMS == 5
    #define OUT_ORDER out_b,out_f,out_z,out_y,out_x
#else
    #define OUT_ORDER out_b,out_f,out_w,out_z,out_y,out_x
#endif

#define INDICES_MAX_DIM 6

KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data,
                      const __global INPUT1_TYPE* indices,
                      __global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
                      , FUSED_OPS_DECLS
#endif
)
{
    const uint dim0 = get_global_id(0);
    const uint dim1 = get_global_id(1);
    const uint dim2 = get_global_id(2);

    // Calculate indices index
    const uint F_NUM = (INDICES_RANK == 2) ? 1 : INPUT1_FEATURE_NUM;
    const uint idx_f = dim2 % F_NUM;
    const uint idx_b = dim2 / F_NUM;

#if INPUT1_DIMS == 4
    const uint idx_x = dim0;
    const uint idx_y = dim1;
    const uint idx_z = 0;
    const uint idx_w = 0;

    const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_y, idx_x, 0, 0, 0, 0};
    const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#elif INPUT1_DIMS == 5
    const uint X_NUM = (INDICES_RANK == 5) ? 1 : INPUT1_SIZE_X;

    const uint idx_x = dim0 % X_NUM;
    const uint idx_y = dim0 / X_NUM;
    const uint idx_z = dim1;
    const uint idx_w = 0;

    const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0};
    const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#else
    const uint X_NUM = (INDICES_RANK == 6) ? 1 : INPUT1_SIZE_X;
    const uint Z_NUM = (INDICES_RANK == 4) ? 1 : INPUT1_SIZE_Z;

    const uint idx_x = dim0 % X_NUM;
    const uint idx_y = dim0 / X_NUM;
    const uint idx_z = dim1 % Z_NUM;
    const uint idx_w = dim1 / Z_NUM;

    const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_w, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0, 0};
    const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#endif

    const int idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER);

    // Calculate data index
    uint indices_val[INDICES_MAX_DIM + BATCH_DIMS];
    for (int i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) {
        indices_val[i] = 0;
    }

    for (int i = 0; i < BATCH_DIMS; i++) {
        indices_val[i] = idx_arr[i];
    }

    for (int i = 0; i < INDICES_LAST_DIM; i++) {
        indices_val[i + BATCH_DIMS] = indices[idx+i];
    }

#if INPUT0_DIMS == 4
    const uint in_x = indices_val[3];
    const uint in_y = indices_val[2];
#elif INPUT0_DIMS == 5
    const uint in_x = indices_val[4];
    const uint in_y = indices_val[3];
    const uint in_z = indices_val[2];
#else
    const uint in_x = indices_val[5];
    const uint in_y = indices_val[4];
    const uint in_z = indices_val[3];
    const uint in_w = indices_val[2];
#endif
    const uint in_f = indices_val[1];
    const uint in_b = indices_val[0];

    const uint data_idx = GET_UPDATES_INDEX(INPUT0, IN_ORDER);

    // Calculate output index
#if BATCH_DIMS <= 1
    const uint out_x = idx_x;
    const uint out_y = idx_y;
    const uint out_z = idx_z;
    const uint out_w = idx_w;
    const uint out_f = idx_f;
    const uint out_b = idx_b;
#else
    uint pitch_acc = 1;
    uint output_batch_size = 0;
    for (int i = BATCH_DIMS - 1; i >= 0; i--) {
        output_batch_size += (idx_arr[i] * pitch_acc);
        pitch_acc *= idx_dim[i];
    }

    #if OUTPUT_DIMS == 4
        const uint out_x = idx_arr[BATCH_DIMS+2];
        const uint out_y = idx_arr[BATCH_DIMS+1];
    #elif OUTPUT_DIMS == 5
        const uint out_x = idx_arr[BATCH_DIMS+3];
        const uint out_y = idx_arr[BATCH_DIMS+2];
        const uint out_z = idx_arr[BATCH_DIMS+1];
    #else
        const uint out_x = idx_arr[BATCH_DIMS+4];
        const uint out_y = idx_arr[BATCH_DIMS+3];
        const uint out_z = idx_arr[BATCH_DIMS+2];
        const uint out_w = idx_arr[BATCH_DIMS+1];
    #endif
    const uint out_f = idx_arr[BATCH_DIMS+0];
    const uint out_b = output_batch_size;
#endif

    const uint output_idx = GET_OUTPUT_INDEX(OUT_ORDER);

    // Copy data to output as slice size
#if HAS_FUSED_OPS
    #if OUTPUT_DIMS == 4
        const uint y_pitch = OUTPUT_SIZE_X;
        const uint f_pitch = y_pitch * OUTPUT_SIZE_Y;
    #elif OUTPUT_DIMS == 5
        const uint y_pitch = OUTPUT_SIZE_X;
        const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
        const uint f_pitch = z_pitch * OUTPUT_SIZE_Z;
    #else
        const uint y_pitch = OUTPUT_SIZE_X;
        const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
        const uint w_pitch = z_pitch * OUTPUT_SIZE_Z;
        const uint f_pitch = w_pitch * OUTPUT_SIZE_W;
    #endif
    const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
#endif

    for (int i = 0; i < WI_SLICE_SIZE; i++) {
        uint dst_idx = output_idx + i;
        INPUT0_TYPE val = data[data_idx + i];

#if HAS_FUSED_OPS
        const uint b_remain = dst_idx % b_pitch;
        const uint f_remain = b_remain % f_pitch;
    #if OUTPUT_DIMS == 4
        const uint y_remain = f_remain % y_pitch;

        const uint y = f_remain / y_pitch;
    #elif OUTPUT_DIMS == 5
        const uint z_remain = f_remain % z_pitch;
        const uint y_remain = z_remain % y_pitch;

        const uint z = f_remain / z_pitch;
        const uint y = z_remain / y_pitch;
    #else
        const uint w_remain = f_remain % w_pitch;
        const uint z_remain = w_remain % z_pitch;
        const uint y_remain = z_remain % y_pitch;

        const uint w = f_remain / w_pitch;
        const uint z = w_remain / z_pitch;
        const uint y = z_remain / y_pitch;
    #endif
        const uint b = dst_idx / b_pitch;
        const uint f = b_remain / f_pitch;
        const uint x = y_remain;

    #if FUSED_OPS_CAN_USE_PRELOAD
        FUSED_OPS_PRELOAD;
        FUSED_OPS_CALC;
    #else
        FUSED_OPS;
    #endif

        output[dst_idx] = FUSED_OPS_RESULT;
#else
        output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
    }
}

#undef INDICES_MAX_DIM
#undef GET_UPDATES_INDEX
#undef GET_OUTPUT_INDEX
#undef OUT_ORDER
#undef IDX_ORDER
#undef IN_ORDER
inference-engine/thirdparty/clDNN/src/gather_nd.cpp (new file)
@@ -0,0 +1,114 @@
/*
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0
*/

#include "gather_nd_inst.h"

#include "primitive_type_base.h"
#include "error_handler.h"
#include "json_object.h"
#include <string>

namespace cldnn {
primitive_type_id gather_nd::type_id() {
    static primitive_type_base<gather_nd> instance;
    return &instance;
}

layout gather_nd_inst::calc_output_layout(gather_nd_node const& node) {
    auto op = node.get_primitive();

    auto input_layout_origin = node.input(0).get_output_layout();
    auto indices_layout_origin = node.input(1).get_output_layout();

    auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format);
    auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format);

    const size_t input_dims = input_layout.size();

    const auto indices_rank = op->indices_rank;
    const auto batch_dims = op->batch_dims;

    // calculate initial output shape
    std::vector<tensor::value_type> output_sizes;

    for (uint8_t x = 0; x < indices_rank - 1; x++) {
        output_sizes.push_back(indices_layout[x]);
    }

    const size_t indices_last_dim = indices_layout[indices_rank - 1];
    for (size_t x = static_cast<size_t>(batch_dims + indices_last_dim); x < input_dims; x++) {
        output_sizes.push_back(input_layout[x]);
    }

    // calculate batch_size by batch_dims
    int batch_size = 1;
    for (uint8_t x = 0; x < batch_dims; x++) {
        batch_size *= output_sizes[x];
    }

    // create final output shape by batch_dims
    std::vector<tensor::value_type> final_output_sizes;

    if (batch_dims > 0) {
        final_output_sizes.push_back(batch_size);
    }

    for (size_t x = static_cast<size_t>(batch_dims); x < output_sizes.size(); x++) {
        final_output_sizes.push_back(output_sizes[x]);
    }

    auto output_format = cldnn::format::bfyx;
    if (final_output_sizes.size() >= 6) {
        output_format = cldnn::format::bfwzyx;
    } else if (final_output_sizes.size() == 5) {
        output_format = cldnn::format::bfzyx;
    }

    auto output_sizes_tensor = tensor(tensor(final_output_sizes).sizes(output_format));
    auto padding = op->output_padding;

    if (node.has_fused_primitives()) {
        input_layout_origin.data_type = node.get_fused_output_layout().data_type;
    }

    return layout(input_layout_origin.data_type, output_format, output_sizes_tensor, padding);
}

std::string gather_nd_inst::to_string(gather_nd_node const& node) {
    auto desc = node.get_primitive();
    auto node_info = node.desc_to_json();
    auto& input = node.input();

    std::stringstream primitive_description;

    json_composite gather_nd_info;
    gather_nd_info.add("input id", input.id());
    gather_nd_info.add("input shape", node.input(0).get_output_layout().size.to_string());
    gather_nd_info.add("indices shape", node.input(1).get_output_layout().size.to_string());
    gather_nd_info.add("indices rank", desc->indices_rank);
    gather_nd_info.add("batch dims", desc->batch_dims);
    gather_nd_info.add("output shape", calc_output_layout(node).size.to_string());

    node_info->add("gather_nd info", gather_nd_info);
    node_info->dump(primitive_description);

    return primitive_description.str();
}

gather_nd_inst::typed_primitive_inst(network_impl& network, gather_nd_node const& node) : parent(network, node) {}

}  // namespace cldnn
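To make the shape handling in calc_output_layout above easier to check, the same arithmetic is sketched below as a standalone function over plain dimension vectors, exercised with one of the shapes from the single-layer tests earlier in this change; this is an illustrative sketch, not code from this commit.

// Standalone sketch of calc_output_layout()'s shape arithmetic (plain vectors, no cldnn types).
#include <cassert>
#include <cstddef>
#include <vector>

static std::vector<size_t> gather_nd_output_shape(const std::vector<size_t>& data,
                                                  const std::vector<size_t>& indices,
                                                  size_t batch_dims) {
    const size_t indices_rank = indices.size();
    const size_t indices_last_dim = indices[indices_rank - 1];

    // Leading output dims come from indices (all but its last dim), the rest from data.
    std::vector<size_t> out(indices.begin(), indices.end() - 1);
    for (size_t i = batch_dims + indices_last_dim; i < data.size(); ++i)
        out.push_back(data[i]);

    // The first batch_dims dims are collapsed into a single batch dimension.
    size_t batch_size = 1;
    for (size_t i = 0; i < batch_dims; ++i)
        batch_size *= out[i];

    std::vector<size_t> final_shape;
    if (batch_dims > 0)
        final_shape.push_back(batch_size);
    final_shape.insert(final_shape.end(), out.begin() + batch_dims, out.end());
    return final_shape;
}

int main() {
    // data {2, 3, 4}, indices {2, 1, 1}, batch_dims = 1 -> output {2, 1, 4}.
    assert(gather_nd_output_shape({2, 3, 4}, {2, 1, 1}, 1) == (std::vector<size_t>{2, 1, 4}));
    return 0;
}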
inference-engine/thirdparty/clDNN/src/gpu/gather_nd_gpu.cpp (new file)
@@ -0,0 +1,78 @@
/*
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0
*/

#include "gather_nd_inst.h"
#include "primitive_gpu_base.h"
#include "implementation_map.h"
#include "kernel_selector_helper.h"
#include "gather/gather_nd_kernel_selector.h"
#include "gather/gather_nd_kernel_ref.h"
#include "error_handler.h"

using namespace cldnn;

namespace cldnn {
namespace gpu {

struct gather_nd_gpu : typed_primitive_gpu_impl<gather_nd> {
    using parent = typed_primitive_gpu_impl<gather_nd>;
    using parent::parent;

public:
    static primitive_impl* create(const gather_nd_node& arg) {
        auto gather_nd_params = get_default_params<kernel_selector::gather_nd_params>(arg);
        auto gather_nd_optional_params =
            get_default_optional_params<kernel_selector::gather_nd_optional_params>(arg.get_program());

        gather_nd_params.indices_rank = arg.get_primitive()->indices_rank;
        gather_nd_params.batch_dims = arg.get_primitive()->batch_dims;

        gather_nd_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));

        auto& kernel_selector = kernel_selector::gather_nd_kernel_selector::Instance();
        auto best_kernels = kernel_selector.GetBestKernels(gather_nd_params, gather_nd_optional_params);

        CLDNN_ERROR_BOOL(arg.id(),
                         "Best_kernel.empty()",
                         best_kernels.empty(),
                         "Cannot find a proper kernel with these arguments");

        auto gather_nd = new gather_nd_gpu(arg, best_kernels[0]);

        return gather_nd;
    }
};

namespace detail {

attach_gather_nd_gpu::attach_gather_nd_gpu() {
    auto val_fw = gather_nd_gpu::create;
    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);

    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);

    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
    implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
}

}  // namespace detail
}  // namespace gpu
}  // namespace cldnn
@@ -32,6 +32,7 @@ void register_implementations_gpu() {
     REGISTER_GPU(eltwise);
     REGISTER_GPU(fully_connected);
     REGISTER_GPU(gather);
+    REGISTER_GPU(gather_nd);
     REGISTER_GPU(gemm);
     REGISTER_GPU(input_layout);
     REGISTER_GPU(lrn);
@@ -24,6 +24,7 @@
 #include "api/eltwise.hpp"
 #include "api/fully_connected.hpp"
 #include "api/gather.hpp"
+#include "api/gather_nd.hpp"
 #include "api/gemm.hpp"
 #include "api/input_layout.hpp"
 #include "api/lrn.hpp"
@@ -100,6 +101,7 @@ REGISTER_GPU(eltwise);
 REGISTER_GPU(embed);
 REGISTER_GPU(fully_connected);
 REGISTER_GPU(gather);
+REGISTER_GPU(gather_nd);
 REGISTER_GPU(gemm);
 REGISTER_GPU(input_layout);
 REGISTER_GPU(lookup_table);
@@ -32,6 +32,7 @@
 #include "depth_to_space_inst.h"
 #include "space_to_depth_inst.h"
 #include "gather_inst.h"
+#include "gather_nd_inst.h"
 #include "scatter_update_inst.h"
 #include "scatter_nd_update_inst.h"
 #include "scatter_elements_update_inst.h"
@@ -196,6 +197,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) {
              !input.is_type<depth_to_space>() && !input.is_type<batch_to_space>() &&
              !input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
              !input.is_type<scatter_nd_update>() &&
+             !input.is_type<gather_nd>() &&
              !input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
              !input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
              !input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
@@ -528,6 +530,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {

         should_fuse |= input_data.is_type<gather>();

+        should_fuse |= input_data.is_type<gather_nd>();
+
         should_fuse |= input_data.is_type<scatter_update>();

         should_fuse |= input_data.is_type<scatter_nd_update>();
@@ -594,6 +598,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {

         should_fuse |= input_data.is_type<gather>();

+        should_fuse |= input_data.is_type<gather_nd>();
+
         should_fuse |= input_data.is_type<scatter_update>();

         should_fuse |= input_data.is_type<scatter_nd_update>();
@@ -682,6 +688,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {

         should_fuse |= input_data.is_type<gather>() && quantize_node.get_scale_shift_opt();

+        should_fuse |= input_data.is_type<gather_nd>() && quantize_node.get_scale_shift_opt();
+
         should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();

         should_fuse |= input_data.is_type<scatter_nd_update>() && quantize_node.get_scale_shift_opt();
@@ -741,6 +749,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
                  (parents[i]->is_type<space_to_batch>()) ||
                  (parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) ||
                  (parents[i]->is_type<scale>()) ||
+                 (parents[i]->is_type<gather_nd>()) ||
                  (parents[i]->is_type<scatter_nd_update>()) ||
                  (parents[i]->is_type<scatter_elements_update>()) ||
                  (parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) ||
inference-engine/thirdparty/clDNN/src/include/gather_nd_inst.h (new file)
@@ -0,0 +1,49 @@
/*
// Copyright (c) 2021 Intel Corporation
// Licensed under the Apache License, Version 2.0 (the "License"); see http://www.apache.org/licenses/LICENSE-2.0
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "api/gather_nd.hpp"
#include "primitive_inst.h"
#include <string>

namespace cldnn {
template <>
struct typed_program_node<gather_nd> : public typed_program_node_base<gather_nd> {
    using parent = typed_program_node_base<gather_nd>;

public:
    using parent::parent;

    program_node& input(size_t index = 0) const { return get_dependency(index); }
};

using gather_nd_node = typed_program_node<gather_nd>;

template <>
class typed_primitive_inst<gather_nd> : public typed_primitive_inst_base<gather_nd> {
    using parent = typed_primitive_inst_base<gather_nd>;

public:
    static layout calc_output_layout(gather_nd_node const& node);
    static std::string to_string(gather_nd_node const& node);

public:
    typed_primitive_inst(network_impl& network, gather_nd_node const& desc);
};

using gather_nd_inst = typed_primitive_inst<gather_nd>;
}  // namespace cldnn
@@ -22,6 +22,7 @@
 #include "api/deconvolution.hpp"
 #include "api/permute.hpp"
 #include "api/gather.hpp"
+#include "api/gather_nd.hpp"
 #include "api/scatter_update.hpp"
 #include "api/scatter_nd_update.hpp"
 #include "api/scatter_elements_update.hpp"
@@ -5615,6 +5616,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_scale_activation,
         gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 4 },
 }), );

+
 /* ----------------------------------------------------------------------------------------------------- */
 /* ------------------------------------------ ScatterUpdate cases --------------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */
@ -7829,3 +7831,199 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise,
|
|||||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 },
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 },
|
||||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 },
|
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 },
|
||||||
}), );
|
}), );
|
||||||
|
|
||||||
|
|
||||||
|
/* ----------------------------------------------------------------------------------------------------- */
|
||||||
|
/* ------------------------------------------ GatherND cases ------------------------------------------- */
|
||||||
|
/* ----------------------------------------------------------------------------------------------------- */
|
||||||
|
struct gather_nd_test_params {
|
||||||
|
data_types data_type;
|
||||||
|
|
||||||
|
format input_format;
|
||||||
|
tensor input_shape;
|
||||||
|
|
||||||
|
format indices_format;
|
||||||
|
tensor indices_shape;
|
||||||
|
|
||||||
|
format output_format;
|
||||||
|
tensor output_shape;
|
||||||
|
|
||||||
|
int max_number_in_indices;
|
||||||
|
int indices_rank;
|
||||||
|
int batch_dims;
|
||||||
|
|
||||||
|
data_types default_type;
|
||||||
|
format default_format;
|
||||||
|
|
||||||
|
size_t expected_fused_primitives;
|
||||||
|
size_t expected_not_fused_primitives;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define CASE_GATHER_ND_FP16_4D_1 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 7, 9, 8}, 6, 2, 0, data_types::f16, format::bfyx
|
||||||
|
#define CASE_GATHER_ND_FP16_4D_2 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 8, 1, 9}, 6, 2, 1, data_types::f16, format::bfyx
|
||||||
|
#define CASE_GATHER_ND_FP16_4D_3 data_types::f16, format::bfyx, {5, 4, 7, 2}, format::bfyx, {5, 4, 1, 2}, format::bfyx, {40, 1, 1, 1}, 6, 4, 3, data_types::f16, format::bfyx

#define CASE_GATHER_ND_FP16_5D_1 data_types::f16, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfzyx, {5, 6, 7, 8, 5}, 5, 2, 0, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_5D_2 data_types::f16, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfyx, {5, 5, 7, 8}, 5, 2, 1, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_5D_3 data_types::f16, format::bfzyx, {5, 4, 7, 8, 5}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {20, 1, 1, 1}, 4, 3, 2, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_5D_4 data_types::f16, format::bfzyx, {5, 4, 7, 8, 3}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {60, 7, 1, 1}, 4, 4, 3, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_5D_5 data_types::f16, format::bfzyx, {5, 4, 7, 2, 3}, format::bfzyx, {5, 4, 1, 2, 3}, format::bfyx, {120, 1, 1, 1}, 4, 5, 4, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_5D_6 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {20, 3, 7, 4, 1}, 4, 5, 2, data_types::f16, format::bfyx

#define CASE_GATHER_ND_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 5}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {20, 2, 6, 7}, 5, 4, 2, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_6D_2 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {40, 6, 1, 1}, 5, 4, 3, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_6D_3 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 2, 2}, format::bfzyx, {5, 4, 1, 2, 2}, format::bfyx, {80, 6, 1, 1}, 5, 5, 4, data_types::f16, format::bfyx
#define CASE_GATHER_ND_FP16_6D_4 data_types::f16, format::bfwzyx, {5, 4, 6, 3, 2, 2}, format::bfwzyx, {5, 4, 1, 3, 2, 2}, format::bfyx, {240, 1, 1, 1}, 5, 6, 5, data_types::f16, format::bfyx

#define CASE_GATHER_ND_FP32_4D_1 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 7, 9, 8}, 6, 2, 0, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_4D_2 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 8, 1, 9}, 6, 2, 1, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_4D_3 data_types::f32, format::bfyx, {5, 4, 7, 2}, format::bfyx, {5, 4, 1, 2}, format::bfyx, {40, 1, 1, 1}, 6, 4, 3, data_types::f32, format::bfyx

#define CASE_GATHER_ND_FP32_5D_1 data_types::f32, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfzyx, {5, 6, 7, 8, 5}, 5, 2, 0, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_5D_2 data_types::f32, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfyx, {5, 5, 7, 8}, 5, 2, 1, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_5D_3 data_types::f32, format::bfzyx, {5, 4, 7, 8, 5}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {20, 1, 1, 1}, 4, 3, 2, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_5D_4 data_types::f32, format::bfzyx, {5, 4, 7, 8, 3}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {60, 7, 1, 1}, 4, 4, 3, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_5D_5 data_types::f32, format::bfzyx, {5, 4, 7, 2, 3}, format::bfzyx, {5, 4, 1, 2, 3}, format::bfyx, {120, 1, 1, 1}, 4, 5, 4, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_5D_6 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {20, 3, 7, 4, 1}, 4, 5, 2, data_types::f32, format::bfyx

#define CASE_GATHER_ND_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 5}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {20, 2, 6, 7}, 5, 4, 2, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_6D_2 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {40, 6, 1, 1}, 5, 4, 3, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_6D_3 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 2, 2}, format::bfzyx, {5, 4, 1, 2, 2}, format::bfyx, {80, 6, 1, 1}, 5, 5, 4, data_types::f32, format::bfyx
#define CASE_GATHER_ND_FP32_6D_4 data_types::f32, format::bfwzyx, {5, 4, 6, 3, 2, 2}, format::bfwzyx, {5, 4, 1, 3, 2, 2}, format::bfyx, {240, 1, 1, 1}, 5, 6, 5, data_types::f32, format::bfyx

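// Field order of the CASE_GATHER_ND_* macros above, as consumed by gather_nd_test_params in the
// fusing helpers below (inferred from usage, not spelled out in this hunk): data_type, input_format,
// input_shape, indices_format, indices_shape, output_format, output_shape, max_number_in_indices,
// indices_rank, batch_dims, default_type, default_format.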
class GatherNDPrimitiveFusingTest : public ::BaseFusingTest<gather_nd_test_params> {
public:
    void execute(gather_nd_test_params& p) {
        auto input_prim = get_mem(get_input_layout(p));
        network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
        network network_fused(this->engine, this->topology_fused, bo_fused);
        network_fused.set_input_data("input", input_prim);
        network_not_fused.set_input_data("input", input_prim);
        compare(network_not_fused, network_fused, p);
    }

    layout get_input_layout(gather_nd_test_params& p) {
        return layout{ p.data_type, p.input_format, p.input_shape };
    }

    layout get_indices_layout(gather_nd_test_params& p) {
        return layout{ p.data_type, p.indices_format, p.indices_shape };
    }

    layout get_output_layout(gather_nd_test_params& p) {
        return layout{ p.data_type, p.output_format, p.output_shape };
    }

    layout get_per_channel_layout(gather_nd_test_params& p) {
        return layout{ p.default_type, p.default_format, tensor{1, p.output_shape.feature[0], 1, 1} };
    }
};

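// gather_nd_quantize builds input -> gather_nd -> quantize -> reorder; BaseFusingTest then runs the
// fused and the non-fused variant of this topology on the same input and compares their outputs.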
class gather_nd_quantize : public GatherNDPrimitiveFusingTest {};
TEST_P(gather_nd_quantize, basic) {
    auto p = GetParam();
    create_topologies(input_layout("input", get_input_layout(p)),
        data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
        data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
        data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
        data("out_lo", get_mem(get_single_element_layout(p), -127)),
        data("out_hi", get_mem(get_single_element_layout(p), 127)),
        gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
        quantize("quantize", "gather_nd_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
        reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
    );
    tolerance = 1.f;
    execute(p);
}

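// The two trailing integers in each entry below are presumably the expected primitive counts of the
// fused and the non-fused network checked by BaseFusingTest (an assumption; the struct definition is
// not part of this hunk).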
INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_nd_quantize,
    ::testing::ValuesIn(std::vector<gather_nd_test_params>{
        gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_1, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_2, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_3, 2, 3 },

        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_1, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_2, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_3, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_4, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_5, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_6, 2, 3 },

        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_1, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_2, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_3, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_4, 2, 3 },

        gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_1, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_2, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_3, 2, 3 },

        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_1, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_2, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_3, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_4, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_5, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_6, 2, 3 },

        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_1, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_2, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 3 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 3 },
}), );

class gather_nd_activation_scale_eltwise : public GatherNDPrimitiveFusingTest {};
TEST_P(gather_nd_activation_scale_eltwise, basic) {
    auto p = GetParam();

    create_topologies(input_layout("input", get_input_layout(p)),
        data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
        data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)),
        data("eltwise_data", get_mem(get_output_layout(p))),
        gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
        activation("activation", "gather_nd_prim", activation_func::abs),
        scale("scale", "activation", "scale_data"),
        eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),
        reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
    );

    tolerance = 1e-5f;
    execute(p);
}

INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_nd_activation_scale_eltwise,
    ::testing::ValuesIn(std::vector<gather_nd_test_params>{
        gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_1, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_2, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_3, 2, 5 },

        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_1, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_2, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_3, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_4, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_5, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_6, 2, 5 },

        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_1, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_2, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_3, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_4, 2, 5 },

        gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_1, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_2, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_3, 2, 5 },

        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_1, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_2, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_3, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_4, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_5, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_6, 2, 5 },

        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_1, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_2, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 5 },
        gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 5 },
}), );

730
inference-engine/thirdparty/clDNN/tests/test_cases/gather_nd_gpu_test.cpp
vendored
Normal file
@ -0,0 +1,730 @@
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

///////////////////////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>

#include <api/input_layout.hpp>
#include <api/memory.hpp>
#include <api/gather_nd.hpp>
#include <api/topology.hpp>
#include <api/network.hpp>

#include <cstddef>
#include <tests/test_utils/test_utils.h>

using namespace cldnn;
using namespace ::tests;

inline void DoTest(const engine& engine,
                   const cldnn::memory& input0,
                   const cldnn::memory& input1,
                   const std::vector<float>& expected_results,
                   const int indices_rank,
                   const int batch_dims) {
    topology topology;
    topology.add(input_layout("InputData", input0.get_layout()));
    topology.add(input_layout("InputIndices", input1.get_layout()));
    topology.add(
        gather_nd("gather_nd", "InputData", "InputIndices", indices_rank, batch_dims)
    );

    network network(engine, topology);

    network.set_input_data("InputData", input0);
    network.set_input_data("InputIndices", input1);
    auto outputs = network.execute();
    auto output = outputs.at("gather_nd").get_memory();
    auto output_ptr = output.pointer<uint16_t>();

    for (size_t i = 0; i < expected_results.size(); ++i) {
        EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
    }
}

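// Every test below follows the same pattern: fill "input0" (data) and "input1" (indices), list the
// expected values in flattened order of the output shape noted in the comment, and call DoTest with
// the matching indices_rank / batch_dims.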
TEST(gather_nd_gpu_fp16, d23322_i231312_ir6_batch2) {
    const auto& engine = get_test_engine();

    const int indices_rank = 6;
    const int batch_dims = 2;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 2, 2, 3 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 1, 3, 1 } }); // indices
    // expected output dim: {6,1,3,1,2}

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),

        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
        FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0),
        FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1),

        FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0),
        FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1),
        FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(15), FLOAT16(16), FLOAT16(11), FLOAT16(12), FLOAT16(11), FLOAT16(12),
        FLOAT16(25), FLOAT16(26), FLOAT16(23), FLOAT16(24), FLOAT16(23), FLOAT16(24),
        FLOAT16(33), FLOAT16(34), FLOAT16(33), FLOAT16(34), FLOAT16(33), FLOAT16(34),

        FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(15), FLOAT16(16),
        FLOAT16(21), FLOAT16(22), FLOAT16(25), FLOAT16(26), FLOAT16(25), FLOAT16(26),
        FLOAT16(31), FLOAT16(32), FLOAT16(35), FLOAT16(36), FLOAT16(33), FLOAT16(34),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

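// In the test above, batch_dims = 2 treats the two leading dimensions (2 x 3) of data and indices as
// batches, which is why the expected output shape noted in the comment starts with 6 (= 2 * 3).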
TEST(gather_nd_gpu_fp16, d231322_i231321_ir6_batch5) {
    const auto& engine = get_test_engine();

    const int indices_rank = 6;
    const int batch_dims = 5;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 2, 3, 1 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 1, 2, 3, 1 } }); // indices
    // expected output dim: {36}

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(21), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(31), FLOAT16(30),

        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(17), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(29), FLOAT16(30),
    });

    set_values(input1, {
        FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
        FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
        FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),

        FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
        FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
        FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
        FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
        FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(31),

        FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
        FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
        FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(29),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d23322_i23321_ir5_batch4) {
    const auto& engine = get_test_engine();

    const int indices_rank = 5;
    const int batch_dims = 4;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 2, 2, 3 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 1, 2, 3 } }); // indices
    // expected output dim: {36}

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(21), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(31), FLOAT16(30),

        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(17), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(29), FLOAT16(30),
    });

    set_values(input1, {
        FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
        FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
        FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),

        FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
        FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
        FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
        FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
        FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(31),

        FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
        FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
        FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(29),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d23223_i2321_ir4_batch3) {
    const auto& engine = get_test_engine();

    const int indices_rank = 4;
    const int batch_dims = 3;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 3, 2, 2 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 2 } }); // indices
    // expected output dim: {2*3*2,3}

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
        FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),

        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
        FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
    });

    set_values(input1, {
        FLOAT16(1), FLOAT16(1),
        FLOAT16(1), FLOAT16(0),
        FLOAT16(1), FLOAT16(1),

        FLOAT16(0), FLOAT16(0),
        FLOAT16(0), FLOAT16(1),
        FLOAT16(0), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(16), FLOAT16(17), FLOAT16(18),
        FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(25),
        FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(34), FLOAT16(35), FLOAT16(36),

        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(17), FLOAT16(18), FLOAT16(15),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(26), FLOAT16(27), FLOAT16(28),
        FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(35), FLOAT16(36), FLOAT16(33),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d2342_i2312_ir4_batch2) {
    const auto& engine = get_test_engine();

    const int indices_rank = 4;
    const int batch_dims = 2;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 4 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 1 } }); // indices
    // expected output dim: {6,1}

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
        FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),

        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
        FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
    });

    set_values(input1, {
        FLOAT16(1), FLOAT16(1),
        FLOAT16(0), FLOAT16(0),
        FLOAT16(2), FLOAT16(1),

        FLOAT16(0), FLOAT16(0),
        FLOAT16(2), FLOAT16(1),
        FLOAT16(2), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(14),
        FLOAT16(21),
        FLOAT16(34),

        FLOAT16(11),
        FLOAT16(26),
        FLOAT16(33),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d234_i2311_ir4_batch2) {
    const auto& engine = get_test_engine();

    const int indices_rank = 4;
    const int batch_dims = 2;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 4 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 1 } }); // indices
    // expected output dim: {6,1,1}

    set_values(input0, {
        FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4),
        FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(8),
        FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),

        FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
        FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(20),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),

    });

    set_values(input1, {
        FLOAT16(1),
        FLOAT16(0),
        FLOAT16(2),

        FLOAT16(0),
        FLOAT16(2),
        FLOAT16(2),
    });

    std::vector<float> expected_results = {
        FLOAT16(2),
        FLOAT16(5),
        FLOAT16(11),

        FLOAT16(13),
        FLOAT16(19),
        FLOAT16(23),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d234_i21_ir2_batch1) {
    const auto& engine = get_test_engine();

    const int indices_rank = 2;
    const int batch_dims = 1;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 4 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices
    // expected output dim: {2,4}

    set_values(input0, {
        FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4),
        FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(8),
        FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),

        FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
        FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(20),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),

    });

    set_values(input1, {
        FLOAT16(1),
        FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(8),
        FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d22_i21_ir2_batch1) {
    const auto& engine = get_test_engine();

    const int indices_rank = 2;
    const int batch_dims = 1;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices
    // expected output dim: 2

    set_values(input0, {
        FLOAT16(1), FLOAT16(2),
        FLOAT16(3), FLOAT16(4),
    });

    set_values(input1, {
        FLOAT16(1),
        FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(2),
        FLOAT16(3),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

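// Worked example for the batch_dims = 1 test above: data is {{1, 2}, {3, 4}} with the leading
// dimension treated as batch, so index {1} picks element 1 of batch 0 (-> 2) and index {0} picks
// element 0 of batch 1 (-> 3), giving the expected output {2, 3}.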
TEST(gather_nd_gpu_fp16, d3223_i321113_ir6_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 6;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 3, 2 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 3, 2, 3, 1, 1, 1 } }); // indices
    // expected output dim: 321113

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),

        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),

        FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
        FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(1), FLOAT16(1),
        FLOAT16(1), FLOAT16(0), FLOAT16(0),

        FLOAT16(0), FLOAT16(1), FLOAT16(0),
        FLOAT16(2), FLOAT16(0), FLOAT16(1),

        FLOAT16(1), FLOAT16(1), FLOAT16(0),
        FLOAT16(0), FLOAT16(0), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(64), FLOAT16(65), FLOAT16(66),
        FLOAT16(31), FLOAT16(32), FLOAT16(33),

        FLOAT16(21), FLOAT16(22), FLOAT16(23),
        FLOAT16(54), FLOAT16(55), FLOAT16(56),

        FLOAT16(41), FLOAT16(42), FLOAT16(43),
        FLOAT16(11), FLOAT16(12), FLOAT16(13),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d3221_i32312_ir3_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 3;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 1, 3 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 2 } }); // indices
    // expected output dim: 32312

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),

        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),

        FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
        FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(1),
        FLOAT16(1), FLOAT16(0),

        FLOAT16(0), FLOAT16(1),
        FLOAT16(2), FLOAT16(0),

        FLOAT16(1), FLOAT16(1),
        FLOAT16(0), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),

        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
        FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),

        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d3231_i32312_ir3_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 3;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 1, 3 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 3 } }); // indices
    // expected output dim: {3,2,1,2}

    set_values(input0, {
        FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
        FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),

        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),

        FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
        FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(1), FLOAT16(1),
        FLOAT16(1), FLOAT16(0), FLOAT16(2),

        FLOAT16(0), FLOAT16(1), FLOAT16(0),
        FLOAT16(2), FLOAT16(0), FLOAT16(1),

        FLOAT16(1), FLOAT16(1), FLOAT16(2),
        FLOAT16(0), FLOAT16(0), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(63), FLOAT16(64),
        FLOAT16(35), FLOAT16(36),

        FLOAT16(21), FLOAT16(22),
        FLOAT16(53), FLOAT16(54),

        FLOAT16(45), FLOAT16(46),
        FLOAT16(11), FLOAT16(12),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d3112_i3221_ir4_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 4;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 1, 2, 1 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 2 } }); // indices
    // expected output dim: {3,2,2,1,1,2}

    set_values(input0, {
        FLOAT16(1), FLOAT16(2),
        FLOAT16(7), FLOAT16(8),
        FLOAT16(13), FLOAT16(14),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(1),
        FLOAT16(0), FLOAT16(1),

        FLOAT16(2), FLOAT16(1),
        FLOAT16(0), FLOAT16(1),

        FLOAT16(2), FLOAT16(1),
        FLOAT16(0), FLOAT16(1),
    });

    std::vector<float> expected_results = {
        FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
        FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),

        FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
        FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),

        FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
        FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d311211_i322111_ir4_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 4;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 3, 1, 1, 1, 2, 1 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 3, 2, 1, 1, 1, 2 } }); // indices
    // expected output dim: {3,2,2,1,1,2,1,1}

    set_values(input0, {
        FLOAT16(1), FLOAT16(2),
        FLOAT16(7), FLOAT16(8),
        FLOAT16(13), FLOAT16(14),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(1),
        FLOAT16(0), FLOAT16(1),

        FLOAT16(2), FLOAT16(1),
        FLOAT16(0), FLOAT16(1),

        FLOAT16(2), FLOAT16(1),
        FLOAT16(0), FLOAT16(1),
    });

    std::vector<float> expected_results = {
        FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
        FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),

        FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
        FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),

        FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
        FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d3332_i3223_ir4_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 4;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 3, 2 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 3, 2 } }); // indices

    set_values(input0, {
        FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(5), FLOAT16(6),
        FLOAT16(7), FLOAT16(8), FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
        FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),

        FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),
        FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(30),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),

        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
        FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
        FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0),
        FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),

        FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
        FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0),

        FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1),
        FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1),
    });

    std::vector<float> expected_results = {
        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(61), FLOAT16(62), FLOAT16(63),
        FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(25), FLOAT16(26), FLOAT16(27),

        FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(28), FLOAT16(29), FLOAT16(30),
        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(51), FLOAT16(52), FLOAT16(53),

        FLOAT16(28), FLOAT16(29), FLOAT16(30), FLOAT16(10), FLOAT16(11), FLOAT16(12),
        FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(16), FLOAT16(17), FLOAT16(18),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d3323_i322_ir3_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 3;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 3, 2 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 2 } }); // indices

    set_values(input0, {
        FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(5), FLOAT16(6),
        FLOAT16(7), FLOAT16(8), FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
        FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),

        FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),
        FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(30),
        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),

        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
        FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
        FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
    });

    set_values(input1, {
        FLOAT16(2), FLOAT16(0),
        FLOAT16(2), FLOAT16(1),

        FLOAT16(1), FLOAT16(2),
        FLOAT16(1), FLOAT16(0),

        FLOAT16(0), FLOAT16(1),
        FLOAT16(0), FLOAT16(2),
    });

    std::vector<float> expected_results = {
        FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
        FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),

        FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
        FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),

        FLOAT16(7), FLOAT16(8), FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
        FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

TEST(gather_nd_gpu_fp16, d22_i21_ir2_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 2;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices

    set_values(input0, {
        FLOAT16(1), FLOAT16(2),
        FLOAT16(3), FLOAT16(4)
    });

    set_values(input1, {
        FLOAT16(1), FLOAT16(0),
    });

    std::vector<float> expected_results = {
        FLOAT16(3), FLOAT16(4),
        FLOAT16(1), FLOAT16(2),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}

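// Worked example for the batch_dims = 0 test above: each row of the 2x1 indices tensor indexes the
// leading dimension of the 2x2 data, so {1} selects the slice {3, 4} and {0} selects {1, 2}; the
// remaining dimension is copied whole because the index tuples are shorter than the data rank.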
TEST(gather_nd_gpu_fp16, d22_i32_ir2_batch0) {
    const auto& engine = get_test_engine();

    const int indices_rank = 2;
    const int batch_dims = 0;
    auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // data
    auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices

    set_values(input0, {
        FLOAT16(1), FLOAT16(2),
        FLOAT16(3), FLOAT16(4)
    });

    set_values(input1, {
        FLOAT16(0), FLOAT16(0),
        FLOAT16(1), FLOAT16(0),
        FLOAT16(1), FLOAT16(1),
    });

    std::vector<float> expected_results = {
        FLOAT16(1),
        FLOAT16(3),
        FLOAT16(4),
    };

    DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
}