[IE CLDNN] Add new operation GatherND-5 to IE clDNN plugin (#4857)
This commit is contained in:
parent
9a46851a33
commit
9f93509b1c
@ -193,10 +193,10 @@ REGISTER_FACTORY(v5, LogSoftmax);
|
||||
REGISTER_FACTORY(v5, LSTMSequence);
|
||||
//REGISTER_FACTORY(v5, NonMaxSuppression); Supported via v5 -> v5 internal conversion
|
||||
REGISTER_FACTORY(v5, Round);
|
||||
REGISTER_FACTORY(v5, GatherND);
|
||||
|
||||
// ----------------------------- Unsupported v5 ops ----------------------------- //
|
||||
// REGISTER_FACTORY(v5, BatchNormInference);
|
||||
// REGISTER_FACTORY(v5, GatherND);
|
||||
// REGISTER_FACTORY(v5, GRUSequence);
|
||||
// REGISTER_FACTORY(v5, Loop);
|
||||
// REGISTER_FACTORY(v5, RNNSequence);
|
||||
|
36
inference-engine/src/cldnn_engine/ops/gather_nd.cpp
Normal file
36
inference-engine/src/cldnn_engine/ops/gather_nd.cpp
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "cldnn_program.h"
|
||||
#include "cldnn_common_utils.h"
|
||||
|
||||
#include "ngraph/op/gather_nd.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
|
||||
#include "api/gather_nd.hpp"
|
||||
|
||||
namespace CLDNNPlugin {
|
||||
|
||||
void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v5::GatherND>& op) {
|
||||
p.ValidateInputs(op, {2});
|
||||
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
|
||||
std::string layerName = layer_type_name_ID(op);
|
||||
|
||||
int32_t indices_rank = static_cast<int32_t>(op->get_input_shape(1).size());
|
||||
|
||||
auto batch_dims = op->get_batch_dims();
|
||||
|
||||
auto primitive = cldnn::gather_nd(layerName,
|
||||
inputPrimitives[0],
|
||||
inputPrimitives[1],
|
||||
indices_rank,
|
||||
batch_dims);
|
||||
|
||||
p.AddPrimitive(primitive);
|
||||
p.AddPrimitiveToProfiler(op);
|
||||
}
|
||||
|
||||
REGISTER_FACTORY_IMPL(v5, GatherND);
|
||||
|
||||
} // namespace CLDNNPlugin
|
@ -0,0 +1,81 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
|
||||
#include "single_layer_tests/gather_nd.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
using namespace ngraph::opset5;
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<InferenceEngine::Precision> inputPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16,
|
||||
InferenceEngine::Precision::I32,
|
||||
};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> idxPrecisions = {
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64,
|
||||
};
|
||||
|
||||
// set1
|
||||
const auto gatherNDArgsSubset1 = ::testing::Combine(
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{ {2, 2}, {2, 3, 4} })), // Data shape
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{ {2, 1}, {2, 1, 1} })), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({ 0, 1 })) // Batch dims
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherND_set1, GatherNDLayerTest,
|
||||
::testing::Combine(
|
||||
gatherNDArgsSubset1,
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::Values<Config>({})),
|
||||
GatherNDLayerTest::getTestCaseName);
|
||||
|
||||
// set2
|
||||
const auto gatherNDArgsSubset2 = ::testing::Combine(
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{ {15, 12, 20, 15, 2}, {15, 12, 18, 7, 17} })), // Data shape
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{ {15, 12, 2}, {15, 12, 5, 9, 1, 3} })), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({ 1, 2 })) // Batch dims
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherND_set2, GatherNDLayerTest,
|
||||
::testing::Combine(
|
||||
gatherNDArgsSubset2,
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::Values<Config>({})),
|
||||
GatherNDLayerTest::getTestCaseName);
|
||||
|
||||
// set3
|
||||
const auto gatherNDArgsSubset3 = ::testing::Combine(
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{ {4, 3, 2, 5, 5, 2}, {4, 3, 2, 5, 7, 2} })), // Data shape
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{ {4, 3, 2, 5, 1}, {4, 3, 2, 5, 6, 2} })), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({ 3, 4 })) // Batch dims
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherND_set3, GatherNDLayerTest,
|
||||
::testing::Combine(
|
||||
gatherNDArgsSubset3,
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::Values<Config>({})),
|
||||
GatherNDLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
57
inference-engine/thirdparty/clDNN/api/gather_nd.hpp
vendored
Normal file
57
inference-engine/thirdparty/clDNN/api/gather_nd.hpp
vendored
Normal file
@ -0,0 +1,57 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#pragma once
|
||||
#include "primitive.hpp"
|
||||
|
||||
namespace cldnn {
|
||||
/// @addtogroup cpp_api C++ API
|
||||
/// @{
|
||||
/// @addtogroup cpp_topology Network Topology
|
||||
/// @{
|
||||
/// @addtogroup cpp_primitives Primitives
|
||||
/// @{
|
||||
|
||||
/// @brief
|
||||
/// @details
|
||||
struct gather_nd : public primitive_base<gather_nd> {
|
||||
CLDNN_DECLARE_PRIMITIVE(gather_nd)
|
||||
|
||||
/// @brief Constructs gather_nd primitive.
|
||||
/// @param id This primitive id.
|
||||
/// @param data Input data primitive id.
|
||||
/// @param indices Input indexes primitive id.
|
||||
/// @param indices_rank Rank of indices.
|
||||
/// @param batch_dims batch_dims as an attribute of GatherND. Optional.
|
||||
gather_nd(const primitive_id& id,
|
||||
const primitive_id& data,
|
||||
const primitive_id& indices,
|
||||
const uint8_t indices_rank,
|
||||
const uint8_t batch_dims = 0,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {data, indices}, output_padding), indices_rank(indices_rank), batch_dims(batch_dims) {}
|
||||
|
||||
/// @brief GatherND indices_rank
|
||||
uint8_t indices_rank;
|
||||
|
||||
/// @brief GatherND batch_dims
|
||||
uint8_t batch_dims;
|
||||
};
|
||||
/// @}
|
||||
/// @}
|
||||
/// @}
|
||||
} // namespace cldnn
|
@ -47,6 +47,7 @@ enum class KernelType {
|
||||
CONTRACT,
|
||||
ONE_HOT,
|
||||
GATHER,
|
||||
GATHER_ND,
|
||||
SCATTER_UPDATE,
|
||||
SCATTER_ND_UPDATE,
|
||||
SCATTER_ELEMENTS_UPDATE,
|
||||
|
@ -0,0 +1,210 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gather_nd_kernel_ref.h"
|
||||
#include "kernel_selector_utils.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
ParamsKey GatherNDKernelRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableInputDataType(Datatype::INT32);
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT32);
|
||||
k.EnableOutputDataType(Datatype::INT8);
|
||||
k.EnableOutputDataType(Datatype::UINT8);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableInputLayout(DataLayout::bfzyx);
|
||||
k.EnableOutputLayout(DataLayout::bfzyx);
|
||||
k.EnableInputLayout(DataLayout::bfwzyx);
|
||||
k.EnableOutputLayout(DataLayout::bfwzyx);
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableBatching();
|
||||
k.EnableDifferentTypes();
|
||||
return k;
|
||||
}
|
||||
|
||||
static inline std::string GetOrderString(std::vector<std::string>& order) {
|
||||
std::string order_str = order[0];
|
||||
for (size_t i = 1; i < order.size(); i++)
|
||||
order_str += ", " + order[i];
|
||||
|
||||
return order_str;
|
||||
}
|
||||
|
||||
static inline std::vector<std::string> GetDefaultOrder(size_t size) {
|
||||
std::vector<std::string> default_order;
|
||||
if (size <= 4) {
|
||||
default_order = { "b", "f", "y", "x" };
|
||||
} else if (size == 5) {
|
||||
default_order = { "b", "f", "z", "y", "x" };
|
||||
} else if (size == 6) {
|
||||
default_order = { "b", "f", "w", "z", "y", "x" };
|
||||
}
|
||||
|
||||
return default_order;
|
||||
}
|
||||
|
||||
CommonDispatchData GatherNDKernelRef::SetDefault(const gather_nd_params& params, const optional_params&) const {
|
||||
CommonDispatchData dispatchData;
|
||||
|
||||
auto indices_dims = params.inputs[1].LogicalDims();
|
||||
|
||||
if (indices_dims.size() > 1) {
|
||||
std::reverse(indices_dims.begin(), indices_dims.end());
|
||||
}
|
||||
|
||||
indices_dims[params.indices_rank - 1] = 1; // set last dim of indices to 1
|
||||
|
||||
switch (params.inputs[1].GetLayout()) {
|
||||
case DataLayout::bfyx:
|
||||
dispatchData.gws = { indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] };
|
||||
break;
|
||||
|
||||
case DataLayout::bfzyx:
|
||||
dispatchData.gws = { indices_dims[4] * indices_dims[3], indices_dims[2], indices_dims[1] * indices_dims[0] };
|
||||
break;
|
||||
|
||||
case DataLayout::bfwzyx:
|
||||
dispatchData.gws = { indices_dims[5] * indices_dims[4], indices_dims[3] * indices_dims[2], indices_dims[1] * indices_dims[0] };
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::invalid_argument("Unsupported data layout for scatter elements update primitive");
|
||||
break;
|
||||
}
|
||||
|
||||
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
|
||||
|
||||
return dispatchData;
|
||||
}
|
||||
|
||||
static size_t GetIndicesLastDim(const gather_nd_params& params) {
|
||||
// get indices dims
|
||||
auto indices_dims = params.inputs[1].LogicalDims();
|
||||
|
||||
if (indices_dims.size() > 1) {
|
||||
std::reverse(indices_dims.begin(), indices_dims.end());
|
||||
}
|
||||
|
||||
auto indices_last_dim = indices_dims[params.indices_rank - 1];
|
||||
|
||||
return indices_last_dim;
|
||||
}
|
||||
|
||||
static size_t GetSliceSize(const gather_nd_params& params) {
|
||||
// get input dims
|
||||
auto input_dims = params.inputs[0].LogicalDims();
|
||||
|
||||
if (input_dims.size() > 1) {
|
||||
std::reverse(input_dims.begin(), input_dims.end());
|
||||
}
|
||||
|
||||
// get last dim of indices
|
||||
auto indices_last_dim = GetIndicesLastDim(params);
|
||||
|
||||
// calculate slize size which is used in kernel to copy
|
||||
size_t wi_slice_size = 1;
|
||||
for (size_t i = params.batch_dims + indices_last_dim; i < input_dims.size(); i++) {
|
||||
wi_slice_size *= input_dims[i];
|
||||
}
|
||||
|
||||
return wi_slice_size;
|
||||
}
|
||||
|
||||
JitConstants GatherNDKernelRef::GetJitConstants(const gather_nd_params& params) const {
|
||||
JitConstants jit = MakeBaseParamsJitConstants(params);
|
||||
|
||||
jit.AddConstant(MakeJitConstant("INDICES_RANK", params.indices_rank));
|
||||
jit.AddConstant(MakeJitConstant("BATCH_DIMS", params.batch_dims));
|
||||
jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE", GetSliceSize(params)));
|
||||
jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", GetIndicesLastDim(params)));
|
||||
|
||||
if (!params.fused_ops.empty()) {
|
||||
FusedOpsConfiguration conf = { "", GetDefaultOrder(params.output.GetDims().size()), "val", params.inputs[0].GetDType() };
|
||||
jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
|
||||
}
|
||||
|
||||
return jit;
|
||||
}
|
||||
|
||||
bool GatherNDKernelRef::Validate(const Params& p, const optional_params& o) const {
|
||||
if (p.GetType() != KernelType:: GATHER_ND || o.GetType() != KernelType::GATHER_ND) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const gather_nd_params& params = static_cast<const gather_nd_params&>(p);
|
||||
auto input_dims = params.inputs[0].LogicalDims();
|
||||
auto indices_dims = params.inputs[1].LogicalDims();
|
||||
auto indices_rank = params.indices_rank;
|
||||
auto batch_dims = params.batch_dims;
|
||||
|
||||
std::reverse(input_dims.begin(), input_dims.end());
|
||||
std::reverse(indices_dims.begin(), indices_dims.end());
|
||||
|
||||
if (indices_rank < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (batch_dims + indices_dims[indices_rank - 1] > input_dims.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (batch_dims >= std::min(input_dims.size(), static_cast<size_t>(indices_rank))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (uint8_t i = 0; i < batch_dims; i++) {
|
||||
if (input_dims[i] != indices_dims[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& fused_op : params.fused_ops) {
|
||||
if (!IsFusedPrimitiveSupported(fused_op))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
KernelsData GatherNDKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
|
||||
if (!Validate(params, options)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
KernelData kd = KernelData::Default<gather_nd_params>(params);
|
||||
gather_nd_params& newParams = *static_cast<gather_nd_params*>(kd.params.get());
|
||||
|
||||
auto dispatchData = SetDefault(newParams, options);
|
||||
auto cldnn_jit = GetJitConstants(newParams);
|
||||
|
||||
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
|
||||
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
|
||||
auto& kernel = kd.kernels[0];
|
||||
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2, GetFusedPrimitiveInputsCount(params));
|
||||
|
||||
return { kd };
|
||||
}
|
||||
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,60 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernel_base_opencl.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// gather_nd_params
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct gather_nd_params : public base_params {
|
||||
gather_nd_params() : base_params(KernelType::GATHER_ND), indices_rank(0), batch_dims(0) {}
|
||||
|
||||
uint8_t indices_rank;
|
||||
|
||||
uint8_t batch_dims;
|
||||
|
||||
virtual ParamsKey GetParamsKey() const { return base_params::GetParamsKey(); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// gather_nd_optional_params
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct gather_nd_optional_params : optional_params {
|
||||
gather_nd_optional_params() : optional_params(KernelType::GATHER_ND) {}
|
||||
};
|
||||
|
||||
class GatherNDKernelRef : public KernelBaseOpenCL {
|
||||
public:
|
||||
GatherNDKernelRef() : KernelBaseOpenCL("gather_nd_ref") {}
|
||||
virtual ~GatherNDKernelRef() {}
|
||||
virtual JitConstants GetJitConstants(const gather_nd_params& params) const;
|
||||
virtual CommonDispatchData SetDefault(const gather_nd_params& params, const optional_params&) const;
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ACTIVATION,
|
||||
FusedOpType::ELTWISE };
|
||||
}
|
||||
|
||||
protected:
|
||||
bool Validate(const Params& p, const optional_params& o) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,27 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gather_nd_kernel_selector.h"
|
||||
#include "gather_nd_kernel_ref.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
gather_nd_kernel_selector::gather_nd_kernel_selector() { Attach<GatherNDKernelRef>(); }
|
||||
|
||||
KernelsData gather_nd_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
|
||||
return GetNaiveBestKernel(params, options, KernelType::GATHER_ND);
|
||||
}
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,35 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernel_selector.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
class gather_nd_kernel_selector : public kernel_selector_base {
|
||||
public:
|
||||
static gather_nd_kernel_selector& Instance() {
|
||||
static gather_nd_kernel_selector instance_;
|
||||
return instance_;
|
||||
}
|
||||
|
||||
gather_nd_kernel_selector();
|
||||
|
||||
virtual ~gather_nd_kernel_selector() {}
|
||||
|
||||
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
231
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_nd_ref.cl
vendored
Normal file
231
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_nd_ref.cl
vendored
Normal file
@ -0,0 +1,231 @@
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "include/fetch.cl"
|
||||
|
||||
#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
|
||||
#define GET_OUTPUT_INDEX(out_order) OUTPUT_GET_INDEX(out_order)
|
||||
|
||||
#if INPUT0_DIMS == 4
|
||||
#define IN_ORDER in_b,in_f,in_y,in_x
|
||||
#elif INPUT0_DIMS == 5
|
||||
#define IN_ORDER in_b,in_f,in_z,in_y,in_x
|
||||
#else
|
||||
#define IN_ORDER in_b,in_f,in_w,in_z,in_y,in_x
|
||||
#endif
|
||||
|
||||
#if INPUT1_DIMS == 4
|
||||
#define IDX_ORDER idx_b,idx_f,idx_y,idx_x
|
||||
#elif INPUT1_DIMS == 5
|
||||
#define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x
|
||||
#else
|
||||
#define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x
|
||||
#endif
|
||||
|
||||
#if OUTPUT_DIMS == 4
|
||||
#define OUT_ORDER out_b,out_f,out_y,out_x
|
||||
#elif OUTPUT_DIMS == 5
|
||||
#define OUT_ORDER out_b,out_f,out_z,out_y,out_x
|
||||
#else
|
||||
#define OUT_ORDER out_b,out_f,out_w,out_z,out_y,out_x
|
||||
#endif
|
||||
|
||||
#define INDICES_MAX_DIM 6
|
||||
|
||||
KERNEL(gather_nd_ref)(const __global INPUT0_TYPE* data,
|
||||
const __global INPUT1_TYPE* indices,
|
||||
__global OUTPUT_TYPE* output
|
||||
#if HAS_FUSED_OPS_DECLS
|
||||
, FUSED_OPS_DECLS
|
||||
#endif
|
||||
)
|
||||
{
|
||||
|
||||
const uint dim0 = get_global_id(0);
|
||||
const uint dim1 = get_global_id(1);
|
||||
const uint dim2 = get_global_id(2);
|
||||
|
||||
// Calculate indice index
|
||||
const uint F_NUM = (INDICES_RANK == 2) ? 1 : INPUT1_FEATURE_NUM;
|
||||
const uint idx_f = dim2 % F_NUM;
|
||||
const uint idx_b = dim2 / F_NUM;
|
||||
|
||||
#if INPUT1_DIMS == 4
|
||||
const uint idx_x = dim0;
|
||||
const uint idx_y = dim1;
|
||||
const uint idx_z = 0;
|
||||
const uint idx_w = 0;
|
||||
|
||||
const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_y, idx_x, 0, 0, 0, 0};
|
||||
const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
|
||||
#elif INPUT1_DIMS == 5
|
||||
const uint X_NUM = (INDICES_RANK == 5) ? 1 : INPUT1_SIZE_X;
|
||||
|
||||
const uint idx_x = dim0 % X_NUM;
|
||||
const uint idx_y = dim0 / X_NUM;
|
||||
const uint idx_z = dim1;
|
||||
const uint idx_w = 0;
|
||||
|
||||
const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0};
|
||||
const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
|
||||
#else
|
||||
const uint X_NUM = (INDICES_RANK == 6) ? 1 : INPUT1_SIZE_X;
|
||||
const uint Z_NUM = (INDICES_RANK == 4) ? 1 : INPUT1_SIZE_Z;
|
||||
|
||||
const uint idx_x = dim0 % X_NUM;
|
||||
const uint idx_y = dim0 / X_NUM;
|
||||
const uint idx_z = dim1 % Z_NUM;
|
||||
const uint idx_w = dim1 / Z_NUM;
|
||||
|
||||
const uint idx_arr[INPUT1_DIMS*2] = {idx_b, idx_f, idx_w, idx_z, idx_y, idx_x, 0, 0, 0, 0, 0, 0};
|
||||
const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
|
||||
#endif
|
||||
|
||||
const int idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER);
|
||||
|
||||
// Calculate data index
|
||||
uint indices_val[INDICES_MAX_DIM + BATCH_DIMS];
|
||||
for (int i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) {
|
||||
indices_val[i] = 0;
|
||||
}
|
||||
|
||||
for (int i = 0; i < BATCH_DIMS; i++) {
|
||||
indices_val[i] = idx_arr[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < INDICES_LAST_DIM; i++) {
|
||||
indices_val[i + BATCH_DIMS] = indices[idx+i];
|
||||
}
|
||||
|
||||
#if INPUT0_DIMS == 4
|
||||
const uint in_x = indices_val[3];
|
||||
const uint in_y = indices_val[2];
|
||||
#elif INPUT0_DIMS == 5
|
||||
const uint in_x = indices_val[4];
|
||||
const uint in_y = indices_val[3];
|
||||
const uint in_z = indices_val[2];
|
||||
#else
|
||||
const uint in_x = indices_val[5];
|
||||
const uint in_y = indices_val[4];
|
||||
const uint in_z = indices_val[3];
|
||||
const uint in_w = indices_val[2];
|
||||
#endif
|
||||
const uint in_f = indices_val[1];
|
||||
const uint in_b = indices_val[0];
|
||||
|
||||
const uint data_idx = GET_UPDATES_INDEX(INPUT0, IN_ORDER);
|
||||
|
||||
// Calculate output index
|
||||
#if BATCH_DIMS <= 1
|
||||
const uint out_x = idx_x;
|
||||
const uint out_y = idx_y;
|
||||
const uint out_z = idx_z;
|
||||
const uint out_w = idx_w;
|
||||
const uint out_f = idx_f;
|
||||
const uint out_b = idx_b;
|
||||
#else
|
||||
uint pitch_acc = 1;
|
||||
uint output_batch_size = 0;
|
||||
for (int i = BATCH_DIMS - 1; i >= 0; i--) {
|
||||
output_batch_size += (idx_arr[i] * pitch_acc);
|
||||
pitch_acc *= idx_dim[i];
|
||||
}
|
||||
|
||||
#if OUTPUT_DIMS == 4
|
||||
const uint out_x = idx_arr[BATCH_DIMS+2];
|
||||
const uint out_y = idx_arr[BATCH_DIMS+1];
|
||||
#elif OUTPUT_DIMS == 5
|
||||
const uint out_x = idx_arr[BATCH_DIMS+3];
|
||||
const uint out_y = idx_arr[BATCH_DIMS+2];
|
||||
const uint out_z = idx_arr[BATCH_DIMS+1];
|
||||
#else
|
||||
const uint out_x = idx_arr[BATCH_DIMS+4];
|
||||
const uint out_y = idx_arr[BATCH_DIMS+3];
|
||||
const uint out_z = idx_arr[BATCH_DIMS+2];
|
||||
const uint out_w = idx_arr[BATCH_DIMS+1];
|
||||
#endif
|
||||
const uint out_f = idx_arr[BATCH_DIMS+0];
|
||||
const uint out_b = output_batch_size;
|
||||
#endif
|
||||
|
||||
const uint output_idx = GET_OUTPUT_INDEX(OUT_ORDER);
|
||||
|
||||
// Copy data to output as slice size
|
||||
#if HAS_FUSED_OPS
|
||||
#if OUTPUT_DIMS == 4
|
||||
const uint y_pitch = OUTPUT_SIZE_X;
|
||||
const uint f_pitch = y_pitch * OUTPUT_SIZE_Y;
|
||||
#elif OUTPUT_DIMS == 5
|
||||
const uint y_pitch = OUTPUT_SIZE_X;
|
||||
const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
|
||||
const uint f_pitch = z_pitch * OUTPUT_SIZE_Z;
|
||||
#else
|
||||
const uint y_pitch = OUTPUT_SIZE_X;
|
||||
const uint z_pitch = y_pitch * OUTPUT_SIZE_Y;
|
||||
const uint w_pitch = z_pitch * OUTPUT_SIZE_Z;
|
||||
const uint f_pitch = w_pitch * OUTPUT_SIZE_W;
|
||||
#endif
|
||||
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < WI_SLICE_SIZE; i++) {
|
||||
uint dst_idx = output_idx + i;
|
||||
INPUT0_TYPE val = data[data_idx + i];
|
||||
|
||||
#if HAS_FUSED_OPS
|
||||
const uint b_remain = dst_idx % b_pitch;
|
||||
const uint f_remain = b_remain % f_pitch;
|
||||
#if OUTPUT_DIMS == 4
|
||||
const uint y_remain = f_remain % y_pitch;
|
||||
|
||||
const uint y = f_remain / y_pitch;
|
||||
#elif OUTPUT_DIMS == 5
|
||||
const uint z_remain = f_remain % z_pitch;
|
||||
const uint y_remain = z_remain % y_pitch;
|
||||
|
||||
const uint z = f_remain / z_pitch;
|
||||
const uint y = z_remain / y_pitch;
|
||||
#else
|
||||
const uint w_remain = f_remain % w_pitch;
|
||||
const uint z_remain = w_remain % z_pitch;
|
||||
const uint y_remain = z_remain % y_pitch;
|
||||
|
||||
const uint w = f_remain / w_pitch;
|
||||
const uint z = w_remain / z_pitch;
|
||||
const uint y = z_remain / y_pitch;
|
||||
#endif
|
||||
const uint b = dst_idx / b_pitch;
|
||||
const uint f = b_remain / f_pitch;
|
||||
const uint x = y_remain;
|
||||
|
||||
#if FUSED_OPS_CAN_USE_PRELOAD
|
||||
FUSED_OPS_PRELOAD;
|
||||
FUSED_OPS_CALC;
|
||||
#else
|
||||
FUSED_OPS;
|
||||
#endif
|
||||
|
||||
output[dst_idx] = FUSED_OPS_RESULT;
|
||||
#else
|
||||
output[dst_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#undef INDICES_MAX_DIM
|
||||
#undef GET_UPDATES_INDEX
|
||||
#undef GET_OUTPUT_INDEX
|
||||
#undef OUT_ORDER
|
||||
#undef IDX_ORDER
|
||||
#undef IN_ORDER
|
114
inference-engine/thirdparty/clDNN/src/gather_nd.cpp
vendored
Normal file
114
inference-engine/thirdparty/clDNN/src/gather_nd.cpp
vendored
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gather_nd_inst.h"
|
||||
|
||||
#include "primitive_type_base.h"
|
||||
#include "error_handler.h"
|
||||
#include "json_object.h"
|
||||
#include <string>
|
||||
|
||||
namespace cldnn {
|
||||
primitive_type_id gather_nd::type_id() {
|
||||
static primitive_type_base<gather_nd> instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
layout gather_nd_inst::calc_output_layout(gather_nd_node const& node) {
|
||||
auto op = node.get_primitive();
|
||||
|
||||
auto input_layout_origin = node.input(0).get_output_layout();
|
||||
auto indices_layout_origin = node.input(1).get_output_layout();
|
||||
|
||||
auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format);
|
||||
auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format);
|
||||
|
||||
const size_t input_dims = input_layout.size();
|
||||
|
||||
const auto indices_rank = op->indices_rank;
|
||||
const auto batch_dims = op->batch_dims;
|
||||
|
||||
// calculate initial output shape
|
||||
std::vector<tensor::value_type> output_sizes;
|
||||
|
||||
for (uint8_t x = 0; x < indices_rank - 1; x++) {
|
||||
output_sizes.push_back(indices_layout[x]);
|
||||
}
|
||||
|
||||
const size_t indices_last_dim = indices_layout[indices_rank - 1];
|
||||
for (size_t x = static_cast<size_t>(batch_dims + indices_last_dim); x < input_dims; x++) {
|
||||
output_sizes.push_back(input_layout[x]);
|
||||
}
|
||||
|
||||
// calculate batch_size by batch_dims
|
||||
int batch_size = 1;
|
||||
for (uint8_t x = 0; x < batch_dims; x++) {
|
||||
batch_size *= output_sizes[x];
|
||||
}
|
||||
|
||||
// create final output shape by batch_dims
|
||||
std::vector<tensor::value_type> final_output_sizes;
|
||||
|
||||
if (batch_dims > 0) {
|
||||
final_output_sizes.push_back(batch_size);
|
||||
}
|
||||
|
||||
for (size_t x = static_cast<size_t>(batch_dims); x < output_sizes.size(); x++) {
|
||||
final_output_sizes.push_back(output_sizes[x]);
|
||||
}
|
||||
|
||||
auto output_format = cldnn::format::bfyx;
|
||||
if (final_output_sizes.size() >= 6) {
|
||||
output_format = cldnn::format::bfwzyx;
|
||||
} else if (final_output_sizes.size() == 5) {
|
||||
output_format = cldnn::format::bfzyx;
|
||||
}
|
||||
|
||||
auto output_sizes_tensor = tensor(tensor(final_output_sizes).sizes(output_format));
|
||||
auto padding = op->output_padding;
|
||||
|
||||
|
||||
if (node.has_fused_primitives()) {
|
||||
input_layout_origin.data_type = node.get_fused_output_layout().data_type;
|
||||
}
|
||||
|
||||
return layout(input_layout_origin.data_type, output_format, output_sizes_tensor, padding);
|
||||
}
|
||||
|
||||
std::string gather_nd_inst::to_string(gather_nd_node const& node) {
|
||||
auto desc = node.get_primitive();
|
||||
auto node_info = node.desc_to_json();
|
||||
auto& input = node.input();
|
||||
|
||||
std::stringstream primitive_description;
|
||||
|
||||
json_composite gather_nd_info;
|
||||
gather_nd_info.add("input id", input.id());
|
||||
gather_nd_info.add("input shape", node.input(0).get_output_layout().size.to_string());
|
||||
gather_nd_info.add("indices shape", node.input(1).get_output_layout().size.to_string());
|
||||
gather_nd_info.add("indices rank", desc->indices_rank);
|
||||
gather_nd_info.add("batch dims", desc->batch_dims);
|
||||
gather_nd_info.add("output shape", calc_output_layout(node).size.to_string());
|
||||
|
||||
node_info->add("gather_nd info", gather_nd_info);
|
||||
node_info->dump(primitive_description);
|
||||
|
||||
return primitive_description.str();
|
||||
}
|
||||
|
||||
gather_nd_inst::typed_primitive_inst(network_impl& network, gather_nd_node const& node) : parent(network, node) {}
|
||||
|
||||
} // namespace cldnn
|
78
inference-engine/thirdparty/clDNN/src/gpu/gather_nd_gpu.cpp
vendored
Normal file
78
inference-engine/thirdparty/clDNN/src/gpu/gather_nd_gpu.cpp
vendored
Normal file
@ -0,0 +1,78 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gather_nd_inst.h"
|
||||
#include "primitive_gpu_base.h"
|
||||
#include "implementation_map.h"
|
||||
#include "kernel_selector_helper.h"
|
||||
#include "gather/gather_nd_kernel_selector.h"
|
||||
#include "gather/gather_nd_kernel_ref.h"
|
||||
#include "error_handler.h"
|
||||
|
||||
using namespace cldnn;
|
||||
|
||||
namespace cldnn {
|
||||
namespace gpu {
|
||||
|
||||
struct gather_nd_gpu : typed_primitive_gpu_impl<gather_nd> {
|
||||
using parent = typed_primitive_gpu_impl<gather_nd>;
|
||||
using parent::parent;
|
||||
|
||||
public:
|
||||
static primitive_impl* create(const gather_nd_node& arg) {
|
||||
auto gather_nd_params = get_default_params<kernel_selector::gather_nd_params>(arg);
|
||||
auto gather_nd_optional_params =
|
||||
get_default_optional_params<kernel_selector::gather_nd_optional_params>(arg.get_program());
|
||||
|
||||
gather_nd_params.indices_rank = arg.get_primitive()->indices_rank;
|
||||
gather_nd_params.batch_dims = arg.get_primitive()->batch_dims;
|
||||
|
||||
gather_nd_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
|
||||
|
||||
auto& kernel_selector = kernel_selector::gather_nd_kernel_selector::Instance();
|
||||
auto best_kernels = kernel_selector.GetBestKernels(gather_nd_params, gather_nd_optional_params);
|
||||
|
||||
CLDNN_ERROR_BOOL(arg.id(),
|
||||
"Best_kernel.empty()",
|
||||
best_kernels.empty(),
|
||||
"Cannot find a proper kernel with this arguments");
|
||||
|
||||
auto gather_nd = new gather_nd_gpu(arg, best_kernels[0]);
|
||||
|
||||
return gather_nd;
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
attach_gather_nd_gpu::attach_gather_nd_gpu() {
|
||||
auto val_fw = gather_nd_gpu::create;
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
|
||||
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
|
||||
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
|
||||
implementation_map<gather_nd>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace gpu
|
||||
} // namespace cldnn
|
@ -32,6 +32,7 @@ void register_implementations_gpu() {
|
||||
REGISTER_GPU(eltwise);
|
||||
REGISTER_GPU(fully_connected);
|
||||
REGISTER_GPU(gather);
|
||||
REGISTER_GPU(gather_nd);
|
||||
REGISTER_GPU(gemm);
|
||||
REGISTER_GPU(input_layout);
|
||||
REGISTER_GPU(lrn);
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include "api/eltwise.hpp"
|
||||
#include "api/fully_connected.hpp"
|
||||
#include "api/gather.hpp"
|
||||
#include "api/gather_nd.hpp"
|
||||
#include "api/gemm.hpp"
|
||||
#include "api/input_layout.hpp"
|
||||
#include "api/lrn.hpp"
|
||||
@ -100,6 +101,7 @@ REGISTER_GPU(eltwise);
|
||||
REGISTER_GPU(embed);
|
||||
REGISTER_GPU(fully_connected);
|
||||
REGISTER_GPU(gather);
|
||||
REGISTER_GPU(gather_nd);
|
||||
REGISTER_GPU(gemm);
|
||||
REGISTER_GPU(input_layout);
|
||||
REGISTER_GPU(lookup_table);
|
||||
|
@ -32,6 +32,7 @@
|
||||
#include "depth_to_space_inst.h"
|
||||
#include "space_to_depth_inst.h"
|
||||
#include "gather_inst.h"
|
||||
#include "gather_nd_inst.h"
|
||||
#include "scatter_update_inst.h"
|
||||
#include "scatter_nd_update_inst.h"
|
||||
#include "scatter_elements_update_inst.h"
|
||||
@ -196,6 +197,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) {
|
||||
!input.is_type<depth_to_space>() && !input.is_type<batch_to_space>() &&
|
||||
!input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
|
||||
!input.is_type<scatter_nd_update>() &&
|
||||
!input.is_type<gather_nd>() &&
|
||||
!input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
|
||||
!input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
|
||||
!input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
|
||||
@ -528,6 +530,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather>();
|
||||
|
||||
should_fuse |= input_data.is_type<gather_nd>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_update>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_nd_update>();
|
||||
@ -594,6 +598,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather>();
|
||||
|
||||
should_fuse |= input_data.is_type<gather_nd>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_update>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_nd_update>();
|
||||
@ -682,6 +688,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<gather_nd>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_nd_update>() && quantize_node.get_scale_shift_opt();
|
||||
@ -741,6 +749,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
(parents[i]->is_type<space_to_batch>()) ||
|
||||
(parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) ||
|
||||
(parents[i]->is_type<scale>()) ||
|
||||
(parents[i]->is_type<gather_nd>()) ||
|
||||
(parents[i]->is_type<scatter_nd_update>()) ||
|
||||
(parents[i]->is_type<scatter_elements_update>()) ||
|
||||
(parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) ||
|
||||
|
49
inference-engine/thirdparty/clDNN/src/include/gather_nd_inst.h
vendored
Normal file
49
inference-engine/thirdparty/clDNN/src/include/gather_nd_inst.h
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#pragma once
|
||||
#include "api/gather_nd.hpp"
|
||||
#include "primitive_inst.h"
|
||||
#include <string>
|
||||
|
||||
namespace cldnn {
|
||||
template <>
|
||||
struct typed_program_node<gather_nd> : public typed_program_node_base<gather_nd> {
|
||||
using parent = typed_program_node_base<gather_nd>;
|
||||
|
||||
public:
|
||||
using parent::parent;
|
||||
|
||||
program_node& input(size_t index = 0) const { return get_dependency(index); }
|
||||
};
|
||||
|
||||
using gather_nd_node = typed_program_node<gather_nd>;
|
||||
|
||||
template <>
|
||||
class typed_primitive_inst<gather_nd> : public typed_primitive_inst_base<gather_nd> {
|
||||
using parent = typed_primitive_inst_base<gather_nd>;
|
||||
|
||||
public:
|
||||
static layout calc_output_layout(gather_nd_node const& node);
|
||||
static std::string to_string(gather_nd_node const& node);
|
||||
|
||||
public:
|
||||
typed_primitive_inst(network_impl& network, gather_nd_node const& desc);
|
||||
};
|
||||
|
||||
using gather_nd_inst = typed_primitive_inst<gather_nd>;
|
||||
} // namespace cldnn
|
@ -22,6 +22,7 @@
|
||||
#include "api/deconvolution.hpp"
|
||||
#include "api/permute.hpp"
|
||||
#include "api/gather.hpp"
|
||||
#include "api/gather_nd.hpp"
|
||||
#include "api/scatter_update.hpp"
|
||||
#include "api/scatter_nd_update.hpp"
|
||||
#include "api/scatter_elements_update.hpp"
|
||||
@ -5615,6 +5616,7 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_scale_activation,
|
||||
gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 4 },
|
||||
}), );
|
||||
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* ------------------------------------------ ScatterUpdate cases --------------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
@ -7829,3 +7831,199 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise,
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_5, 2, 5 },
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP32_6D_6, 2, 5 },
|
||||
}), );
|
||||
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* ------------------------------------------ GatherND cases ------------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
struct gather_nd_test_params {
|
||||
data_types data_type;
|
||||
|
||||
format input_format;
|
||||
tensor input_shape;
|
||||
|
||||
format indices_format;
|
||||
tensor indices_shape;
|
||||
|
||||
format output_format;
|
||||
tensor output_shape;
|
||||
|
||||
int max_number_in_indices;
|
||||
int indices_rank;
|
||||
int batch_dims;
|
||||
|
||||
data_types default_type;
|
||||
format default_format;
|
||||
|
||||
size_t expected_fused_primitives;
|
||||
size_t expected_not_fused_primitives;
|
||||
};
|
||||
|
||||
#define CASE_GATHER_ND_FP16_4D_1 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 7, 9, 8}, 6, 2, 0, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_4D_2 data_types::f16, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 8, 1, 9}, 6, 2, 1, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_4D_3 data_types::f16, format::bfyx, {5, 4, 7, 2}, format::bfyx, {5, 4, 1, 2}, format::bfyx, {40, 1, 1, 1}, 6, 4, 3, data_types::f16, format::bfyx
|
||||
|
||||
#define CASE_GATHER_ND_FP16_5D_1 data_types::f16, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfzyx, {5, 6, 7, 8, 5}, 5, 2, 0, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_5D_2 data_types::f16, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfyx, {5, 5, 7, 8}, 5, 2, 1, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_5D_3 data_types::f16, format::bfzyx, {5, 4, 7, 8, 5}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {20, 1, 1, 1}, 4, 3, 2, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_5D_4 data_types::f16, format::bfzyx, {5, 4, 7, 8, 3}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {60, 7, 1, 1}, 4, 4, 3, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_5D_5 data_types::f16, format::bfzyx, {5, 4, 7, 2, 3}, format::bfzyx, {5, 4, 1, 2, 3}, format::bfyx, {120, 1, 1, 1}, 4, 5, 4, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_5D_6 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {20, 3, 7, 4, 1}, 4, 5, 2, data_types::f16, format::bfyx
|
||||
|
||||
#define CASE_GATHER_ND_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 5}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {20, 2, 6, 7}, 5, 4, 2, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_6D_2 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {40, 6, 1, 1}, 5, 4, 3, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_6D_3 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 2, 2}, format::bfzyx, {5, 4, 1, 2, 2}, format::bfyx, {80, 6, 1, 1}, 5, 5, 4, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ND_FP16_6D_4 data_types::f16, format::bfwzyx, {5, 4, 6, 3, 2, 2}, format::bfwzyx, {5, 4, 1, 3, 2, 2}, format::bfyx, {240, 1, 1, 1}, 5, 6, 5, data_types::f16, format::bfyx
|
||||
|
||||
#define CASE_GATHER_ND_FP32_4D_1 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {3, 1, 1, 1}, format::bfyx, {3, 7, 9, 8}, 6, 2, 0, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_4D_2 data_types::f32, format::bfyx, {6, 7, 9, 8}, format::bfyx, {6, 1, 1, 1}, format::bfyx, {6, 8, 1, 9}, 6, 2, 1, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_4D_3 data_types::f32, format::bfyx, {5, 4, 7, 2}, format::bfyx, {5, 4, 1, 2}, format::bfyx, {40, 1, 1, 1}, 6, 4, 3, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_GATHER_ND_FP32_5D_1 data_types::f32, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfzyx, {5, 6, 7, 8, 5}, 5, 2, 0, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_5D_2 data_types::f32, format::bfzyx, {5, 6, 7, 8, 5}, format::bfyx, {5, 1, 1, 1}, format::bfyx, {5, 5, 7, 8}, 5, 2, 1, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_5D_3 data_types::f32, format::bfzyx, {5, 4, 7, 8, 5}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {20, 1, 1, 1}, 4, 3, 2, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_5D_4 data_types::f32, format::bfzyx, {5, 4, 7, 8, 3}, format::bfyx, {5, 4, 1, 3}, format::bfyx, {60, 7, 1, 1}, 4, 4, 3, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_5D_5 data_types::f32, format::bfzyx, {5, 4, 7, 2, 3}, format::bfzyx, {5, 4, 1, 2, 3}, format::bfyx, {120, 1, 1, 1}, 4, 5, 4, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_5D_6 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 1, 1, 3}, format::bfzyx, {20, 3, 7, 4, 1}, 4, 5, 2, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_GATHER_ND_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 5}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {20, 2, 6, 7}, 5, 4, 2, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_6D_2 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfyx, {5, 4, 2, 2}, format::bfyx, {40, 6, 1, 1}, 5, 4, 3, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_6D_3 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 2, 2}, format::bfzyx, {5, 4, 1, 2, 2}, format::bfyx, {80, 6, 1, 1}, 5, 5, 4, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ND_FP32_6D_4 data_types::f32, format::bfwzyx, {5, 4, 6, 3, 2, 2}, format::bfwzyx, {5, 4, 1, 3, 2, 2}, format::bfyx, {240, 1, 1, 1}, 5, 6, 5, data_types::f32, format::bfyx
|
||||
|
||||
|
||||
|
||||
class GatherNDPrimitiveFusingTest : public ::BaseFusingTest<gather_nd_test_params> {
|
||||
public:
|
||||
void execute(gather_nd_test_params& p) {
|
||||
auto input_prim = get_mem(get_input_layout(p));
|
||||
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
|
||||
network network_fused(this->engine, this->topology_fused, bo_fused);
|
||||
network_fused.set_input_data("input", input_prim);
|
||||
network_not_fused.set_input_data("input", input_prim);
|
||||
compare(network_not_fused, network_fused, p);
|
||||
}
|
||||
|
||||
layout get_input_layout(gather_nd_test_params& p) {
|
||||
return layout{ p.data_type, p.input_format, p.input_shape };
|
||||
}
|
||||
|
||||
layout get_indices_layout(gather_nd_test_params& p) {
|
||||
return layout{ p.data_type, p.indices_format, p.indices_shape };
|
||||
}
|
||||
|
||||
layout get_output_layout(gather_nd_test_params& p) {
|
||||
return layout{ p.data_type, p.output_format, p.output_shape };
|
||||
}
|
||||
|
||||
layout get_per_channel_layout(gather_nd_test_params& p) {
|
||||
return layout{ p.default_type, p.default_format, tensor{1, p.output_shape.feature[0], 1, 1} };
|
||||
}
|
||||
};
|
||||
|
||||
class gather_nd_quantize : public GatherNDPrimitiveFusingTest {};
|
||||
TEST_P(gather_nd_quantize, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
|
||||
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
||||
gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
|
||||
quantize("quantize", "gather_nd_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
|
||||
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
|
||||
);
|
||||
tolerance = 1.f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_nd_quantize,
|
||||
::testing::ValuesIn(std::vector<gather_nd_test_params>{
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_1, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_2, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_3, 2, 3 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_1, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_2, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_3, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_4, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_5, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_6, 2, 3 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_1, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_2, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_3, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_4, 2, 3 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_1, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_2, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_3, 2, 3 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_1, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_2, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_3, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_4, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_5, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_6, 2, 3 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_1, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_2, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 3 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 3 },
|
||||
}), );
|
||||
|
||||
class gather_nd_activation_scale_eltwise : public GatherNDPrimitiveFusingTest {};
|
||||
TEST_P(gather_nd_activation_scale_eltwise, basic) {
|
||||
auto p = GetParam();
|
||||
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)),
|
||||
data("eltwise_data", get_mem(get_output_layout(p))),
|
||||
gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
|
||||
activation("activation", "gather_nd_prim", activation_func::abs),
|
||||
scale("scale", "activation", "scale_data"),
|
||||
eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),
|
||||
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
tolerance = 1e-5f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_nd_activation_scale_eltwise,
|
||||
::testing::ValuesIn(std::vector<gather_nd_test_params>{
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_1, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_2, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_4D_3, 2, 5 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_1, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_2, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_3, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_4, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_5, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_5D_6, 2, 5 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_1, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_2, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_3, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP16_6D_4, 2, 5 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_1, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_2, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_4D_3, 2, 5 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_1, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_2, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_3, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_4, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_5, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_5D_6, 2, 5 },
|
||||
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_1, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_2, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 5 },
|
||||
}), );
|
||||
|
||||
|
730
inference-engine/thirdparty/clDNN/tests/test_cases/gather_nd_gpu_test.cpp
vendored
Normal file
730
inference-engine/thirdparty/clDNN/tests/test_cases/gather_nd_gpu_test.cpp
vendored
Normal file
@ -0,0 +1,730 @@
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <api/input_layout.hpp>
|
||||
#include <api/memory.hpp>
|
||||
#include <api/gather_nd.hpp>
|
||||
#include <api/topology.hpp>
|
||||
#include <api/network.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
#include <tests/test_utils/test_utils.h>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
inline void DoTest(const engine& engine,
|
||||
const cldnn::memory& input0,
|
||||
const cldnn::memory& input1,
|
||||
const std::vector<float>& expected_results,
|
||||
const int indices_rank,
|
||||
const int batch_dims) {
|
||||
topology topology;
|
||||
topology.add(input_layout("InputData", input0.get_layout()));
|
||||
topology.add(input_layout("InputIndices", input1.get_layout()));
|
||||
topology.add(
|
||||
gather_nd("gather_nd", "InputData", "InputIndices", indices_rank, batch_dims)
|
||||
);
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("InputData", input0);
|
||||
network.set_input_data("InputIndices", input1);
|
||||
auto outputs = network.execute();
|
||||
auto output = outputs.at("gather_nd").get_memory();
|
||||
auto output_ptr = output.pointer<uint16_t>();
|
||||
|
||||
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||
EXPECT_EQ(expected_results[i], float16_to_float32(output_ptr[i]));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d23322_i231312_ir6_batch2) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 6;
|
||||
const int batch_dims = 2;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 2, 2, 3 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 1, 3, 1 } }); // indices
|
||||
// expected output dim: {6,1,3,1,2}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0), FLOAT16(2), FLOAT16(0),
|
||||
FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1),
|
||||
|
||||
FLOAT16(2), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(15), FLOAT16(16), FLOAT16(11), FLOAT16(12), FLOAT16(11), FLOAT16(12),
|
||||
FLOAT16(25), FLOAT16(26), FLOAT16(23), FLOAT16(24), FLOAT16(23), FLOAT16(24),
|
||||
FLOAT16(33), FLOAT16(34), FLOAT16(33), FLOAT16(34), FLOAT16(33), FLOAT16(34),
|
||||
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(25), FLOAT16(26), FLOAT16(25), FLOAT16(26),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(35), FLOAT16(36), FLOAT16(33), FLOAT16(34),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d231322_i231321_ir6_batch5) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 6;
|
||||
const int batch_dims = 5;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 2, 2, 3, 1 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 2, 3, 1, 2, 3, 1 } }); // indices
|
||||
// expected output dim: {36}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(21), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(31), FLOAT16(30),
|
||||
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(17), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(29), FLOAT16(30),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),
|
||||
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
|
||||
FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(31),
|
||||
|
||||
FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
|
||||
FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(29),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d23322_i23321_ir5_batch4) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 5;
|
||||
const int batch_dims = 4;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 2, 2, 3 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 1, 2, 3 } }); // indices
|
||||
// expected output dim: {36}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(21), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(31), FLOAT16(30),
|
||||
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(10), FLOAT16(17), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(20), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(37), FLOAT16(38), FLOAT16(39), FLOAT16(30), FLOAT16(29), FLOAT16(30),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),
|
||||
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
|
||||
FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(31),
|
||||
|
||||
FLOAT16(12), FLOAT16(14), FLOAT16(16), FLOAT16(18), FLOAT16(10), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(23), FLOAT16(25), FLOAT16(27), FLOAT16(29), FLOAT16(27),
|
||||
FLOAT16(32), FLOAT16(33), FLOAT16(35), FLOAT16(38), FLOAT16(30), FLOAT16(29),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d23223_i2321_ir4_batch3) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 4;
|
||||
const int batch_dims = 3;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 2, 3, 3, 2, 2 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 2 } }); // indices
|
||||
// expected output dim: {2*3*2,3}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(1), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(1),
|
||||
|
||||
FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(25),
|
||||
FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(17), FLOAT16(18), FLOAT16(15),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(26), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(35), FLOAT16(36), FLOAT16(33),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d2342_i2312_ir4_batch2) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 4;
|
||||
const int batch_dims = 2;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 4 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 2, 1 } }); // indices
|
||||
// expected output dim: {6,1}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28),
|
||||
FLOAT16(29), FLOAT16(30), FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
|
||||
FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(2), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(14),
|
||||
FLOAT16(21),
|
||||
FLOAT16(34),
|
||||
|
||||
FLOAT16(11),
|
||||
FLOAT16(26),
|
||||
FLOAT16(33),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d234_i2311_ir4_batch2) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 4;
|
||||
const int batch_dims = 2;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 4 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 1 } }); // indices
|
||||
// expected output dim: {6,1,1}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4),
|
||||
FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
|
||||
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(20),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),
|
||||
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1),
|
||||
FLOAT16(0),
|
||||
FLOAT16(2),
|
||||
|
||||
FLOAT16(0),
|
||||
FLOAT16(2),
|
||||
FLOAT16(2),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(2),
|
||||
FLOAT16(5),
|
||||
FLOAT16(11),
|
||||
|
||||
FLOAT16(13),
|
||||
FLOAT16(19),
|
||||
FLOAT16(23),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d234_i21_ir2_batch1) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 2;
|
||||
const int batch_dims = 1;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 3, 1, 4 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices
|
||||
// expected output dim: {2,4}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4),
|
||||
FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
|
||||
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(17), FLOAT16(18), FLOAT16(19), FLOAT16(20),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),
|
||||
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1),
|
||||
FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(5), FLOAT16(6), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d22_i21_ir2_batch1) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 2;
|
||||
const int batch_dims = 1;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices
|
||||
// expected output dim: 2
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2),
|
||||
FLOAT16(3), FLOAT16(4),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1),
|
||||
FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(2),
|
||||
FLOAT16(3),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d3223_i321113_ir6_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 6;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 3, 2 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 3, 2, 3, 1, 1, 1 } }); // indices
|
||||
// expected output dim: 321113
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
|
||||
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
|
||||
|
||||
FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(0),
|
||||
|
||||
FLOAT16(0), FLOAT16(1), FLOAT16(0),
|
||||
FLOAT16(2), FLOAT16(0), FLOAT16(1),
|
||||
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(0),
|
||||
FLOAT16(0), FLOAT16(0), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(64), FLOAT16(65), FLOAT16(66),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33),
|
||||
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23),
|
||||
FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43),
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d3221_i32312_ir3_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 3;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 1, 3 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 2 } }); // indices
|
||||
// expected output dim: 32312
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
|
||||
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
|
||||
|
||||
FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(1), FLOAT16(0),
|
||||
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
FLOAT16(2), FLOAT16(0),
|
||||
|
||||
FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
|
||||
FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d3231_i32312_ir3_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 3;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfzyx, { 3, 2, 2, 1, 3 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 3 } }); // indices
|
||||
// expected output dim: {3,2,1,2}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(11), FLOAT16(12), FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16),
|
||||
FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(25), FLOAT16(26),
|
||||
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
|
||||
|
||||
FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(2),
|
||||
|
||||
FLOAT16(0), FLOAT16(1), FLOAT16(0),
|
||||
FLOAT16(2), FLOAT16(0), FLOAT16(1),
|
||||
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(2),
|
||||
FLOAT16(0), FLOAT16(0), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(63), FLOAT16(64),
|
||||
FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(21), FLOAT16(22),
|
||||
FLOAT16(53), FLOAT16(54),
|
||||
|
||||
FLOAT16(45), FLOAT16(46),
|
||||
FLOAT16(11), FLOAT16(12),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d3112_i3221_ir4_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 4;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 1, 2, 1 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 2 } }); // indices
|
||||
// expected output dim: {3,2,2,1,1,2}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2),
|
||||
FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(13), FLOAT16(14),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
|
||||
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
|
||||
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d311211_i322111_ir4_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 4;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 3, 1, 1, 1, 2, 1 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfwzyx, { 3, 2, 1, 1, 1, 2 } }); // indices
|
||||
// expected output dim: {3,2,2,1,1,2,1,1}
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2),
|
||||
FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(13), FLOAT16(14),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
|
||||
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
|
||||
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(7), FLOAT16(8),
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(7), FLOAT16(8),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d3332_i3223_ir4_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 4;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 3, 2 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 3, 2 } }); // indices
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(5), FLOAT16(6),
|
||||
FLOAT16(7), FLOAT16(8), FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
|
||||
FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),
|
||||
FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(30),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
|
||||
FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(2), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(0),
|
||||
|
||||
FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(2), FLOAT16(0), FLOAT16(0), FLOAT16(2), FLOAT16(1), FLOAT16(0),
|
||||
|
||||
FLOAT16(1), FLOAT16(1), FLOAT16(1), FLOAT16(0), FLOAT16(1), FLOAT16(1),
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(1), FLOAT16(0), FLOAT16(2), FLOAT16(1),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(61), FLOAT16(62), FLOAT16(63),
|
||||
FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(25), FLOAT16(26), FLOAT16(27),
|
||||
|
||||
FLOAT16(22), FLOAT16(23), FLOAT16(24), FLOAT16(28), FLOAT16(29), FLOAT16(30),
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(51), FLOAT16(52), FLOAT16(53),
|
||||
|
||||
FLOAT16(28), FLOAT16(29), FLOAT16(30), FLOAT16(10), FLOAT16(11), FLOAT16(12),
|
||||
FLOAT16(34), FLOAT16(35), FLOAT16(36), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d3323_i322_ir3_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 3;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 3, 3, 2 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 2 } }); // indices
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2), FLOAT16(3), FLOAT16(4), FLOAT16(5), FLOAT16(6),
|
||||
FLOAT16(7), FLOAT16(8), FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
|
||||
FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),
|
||||
FLOAT16(25), FLOAT16(26), FLOAT16(27), FLOAT16(28), FLOAT16(29), FLOAT16(30),
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
|
||||
FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
FLOAT16(61), FLOAT16(62), FLOAT16(63), FLOAT16(64), FLOAT16(65), FLOAT16(66),
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(2), FLOAT16(0),
|
||||
FLOAT16(2), FLOAT16(1),
|
||||
|
||||
FLOAT16(1), FLOAT16(2),
|
||||
FLOAT16(1), FLOAT16(0),
|
||||
|
||||
FLOAT16(0), FLOAT16(1),
|
||||
FLOAT16(0), FLOAT16(2),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(41), FLOAT16(42), FLOAT16(43), FLOAT16(44), FLOAT16(45), FLOAT16(46),
|
||||
FLOAT16(51), FLOAT16(52), FLOAT16(53), FLOAT16(54), FLOAT16(55), FLOAT16(56),
|
||||
|
||||
FLOAT16(31), FLOAT16(32), FLOAT16(33), FLOAT16(34), FLOAT16(35), FLOAT16(36),
|
||||
FLOAT16(19), FLOAT16(20), FLOAT16(21), FLOAT16(22), FLOAT16(23), FLOAT16(24),
|
||||
|
||||
FLOAT16(7), FLOAT16(8), FLOAT16(9), FLOAT16(10), FLOAT16(11), FLOAT16(12),
|
||||
FLOAT16(13), FLOAT16(14), FLOAT16(15), FLOAT16(16), FLOAT16(17), FLOAT16(18),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d22_i21_ir2_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 2;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 1, 1, 1 } }); // indices
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2),
|
||||
FLOAT16(3), FLOAT16(4)
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(1), FLOAT16(0),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(3), FLOAT16(4),
|
||||
FLOAT16(1), FLOAT16(2),
|
||||
};
|
||||
|
||||
DoTest(engine, input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
||||
TEST(gather_nd_gpu_fp16, d22_i32_ir2_batch0) {
|
||||
const auto& engine = get_test_engine();
|
||||
|
||||
const int indices_rank = 2;
|
||||
const int batch_dims = 0;
|
||||
auto input0 = memory::allocate(engine, { data_types::f16, format::bfyx, { 2, 2, 1, 1 } }); // data
|
||||
auto input1 = memory::allocate(engine, { data_types::f16, format::bfyx, { 3, 2, 1, 1 } }); // indices
|
||||
|
||||
set_values(input0, {
|
||||
FLOAT16(1), FLOAT16(2),
|
||||
FLOAT16(3), FLOAT16(4)
|
||||
});
|
||||
|
||||
set_values(input1, {
|
||||
FLOAT16(0), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(0),
|
||||
FLOAT16(1), FLOAT16(1),
|
||||
});
|
||||
|
||||
std::vector<float> expected_results = {
|
||||
FLOAT16(1),
|
||||
FLOAT16(3),
|
||||
FLOAT16(4),
|
||||
};
|
||||
|
||||
DoTest(engine,input0, input1, expected_results, indices_rank, batch_dims);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user