[GPU] Add new operation GatherElements to IE clDNN plugin (#6676)

This commit is contained in:
Yunji Kim 2021-07-26 23:52:27 +09:00 committed by GitHub
parent b4ad7a1755
commit f5666fb3e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 2296 additions and 6 deletions

View File

@ -204,6 +204,7 @@ REGISTER_FACTORY(v5, Loop);
// ------------------------------ Supported v6 ops ------------------------------ //
REGISTER_FACTORY(v6, CTCGreedyDecoderSeqLen);
REGISTER_FACTORY(v6, MVN);
REGISTER_FACTORY(v6, GatherElements);
// ------------------------------ Supported v7 ops ------------------------------ //
REGISTER_FACTORY(v7, Gather);

View File

@ -0,0 +1,66 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "cldnn_program.h"
#include "cldnn_common_utils.h"
#include "ngraph/op/gather_elements.hpp"
#include "ngraph/op/constant.hpp"
#include "cldnn/primitives/gather_elements.hpp"
namespace CLDNNPlugin {
static cldnn::gather_elements::gather_elements_axis GetGatherAxis(int axis, unsigned rank) {
if (axis < 0)
axis += rank;
if (axis < 0 || axis >= rank)
IE_THROW() << "GatherElements axis is not correspond to number of dimensions";
// Difference in dimension ordering between IE and clDNN,
// reverse spatial dimensions after batch and feature.
unsigned cldnn_axis = axis;
if (axis >= 2) {
auto spatial_axis = axis - 2;
// Default and minimum number of dimensions is 4
auto spatial_size = std::max(rank, 4u) - 2;
cldnn_axis = spatial_size - spatial_axis - 1 + 2;
}
switch (cldnn_axis) {
case 0: return cldnn::gather_elements::gather_elements_axis::along_b;
case 1: return cldnn::gather_elements::gather_elements_axis::along_f;
case 2: return cldnn::gather_elements::gather_elements_axis::along_x;
case 3: return cldnn::gather_elements::gather_elements_axis::along_y;
case 4: return cldnn::gather_elements::gather_elements_axis::along_z;
case 5: return cldnn::gather_elements::gather_elements_axis::along_w;
default: IE_THROW() << "Unsupported GatherElements axis: " << axis;
}
return cldnn::gather_elements::gather_elements_axis::along_f; // shouldn't get here
}
void CreateGatherElementsOp(Program& p, const std::shared_ptr<ngraph::op::v6::GatherElements>& op) {
p.ValidateInputs(op, {2});
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);
size_t rank = op->get_input_shape(0).size();
int32_t axis = static_cast<int32_t>(op->get_axis());
auto outLayout = DefaultFormatForDims(op->get_output_shape(0).size());
auto primitive = cldnn::gather_elements(layerName,
inputPrimitives[0],
inputPrimitives[1],
outLayout,
CldnnTensorFromIEDims(op->get_output_shape(0)),
GetGatherAxis(axis, rank));
p.AddPrimitive(primitive);
p.AddPrimitiveToProfiler(op);
}
REGISTER_FACTORY_IMPL(v6, GatherElements);
} // namespace CLDNNPlugin

View File

@ -4,7 +4,8 @@
#include <vector>
#include "shared_test_classes/single_layer/gather_elements.hpp"
#include "single_layer_tests/gather_elements.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
@ -16,8 +17,6 @@ const std::vector<InferenceEngine::Precision> dPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
InferenceEngine::Precision::I16,
InferenceEngine::Precision::U8,
InferenceEngine::Precision::I8
};
const std::vector<InferenceEngine::Precision> iPrecisions = {
InferenceEngine::Precision::I32,

View File

@ -0,0 +1,227 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <ngraph/opsets/opset6.hpp>
#include "single_layer_tests/gather_elements.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using namespace ngraph::opset6;
namespace {
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};
INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>({2, 2})),
::testing::Values(std::vector<size_t>({2, 2})),
::testing::ValuesIn(std::vector<int>({-1, 0, 1})),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>({2, 2, 1})),
::testing::Values(std::vector<size_t>({4, 2, 1})),
::testing::ValuesIn(std::vector<int>({0, -3})),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>({2, 2, 3, 5})),
::testing::Values(std::vector<size_t>({2, 2, 3, 7})),
::testing::Values(3, -1),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>({3, 2, 3, 8})),
::testing::Values(std::vector<size_t>({2, 2, 3, 8})),
::testing::Values(0, -4),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>({3, 2, 3, 4, 8})),
::testing::Values(std::vector<size_t>({3, 2, 3, 5, 8})),
::testing::Values(3, -2),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis0, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{7, 7, 8, 4}),
::testing::Values(std::vector<size_t>{2, 7, 8, 4}),
::testing::Values(0),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis1, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{6, 1, 8, 4}),
::testing::Values(std::vector<size_t>{6, 8, 8, 4}),
::testing::Values(1, -3),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis2, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{6, 7, 4, 4}),
::testing::Values(std::vector<size_t>{6, 7, 2, 4}),
::testing::Values(2, -2),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis3, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{6, 5, 8, 7}),
::testing::Values(std::vector<size_t>{6, 5, 8, 7}),
::testing::Values(3, -1),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis0, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{2, 3, 9, 4, 9}),
::testing::Values(std::vector<size_t>{1, 3, 9, 4, 9}),
::testing::Values(0),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis1, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{2, 3, 5, 4, 7}),
::testing::Values(std::vector<size_t>{2, 9, 5, 4, 7}),
::testing::Values(1, -4),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis2, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{1, 2, 6, 8, 9}),
::testing::Values(std::vector<size_t>{1, 2, 6, 8, 9}),
::testing::Values(2, -3),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis3, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{2, 2, 4, 7, 7}),
::testing::Values(std::vector<size_t>{2, 2, 4, 3, 7}),
::testing::Values(3, -2),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis4, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{1, 3, 9, 3, 2}),
::testing::Values(std::vector<size_t>{1, 3, 9, 3, 9}),
::testing::Values(4, -1),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis0, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{3, 3, 2, 4, 4, 3}),
::testing::Values(std::vector<size_t>{7, 3, 2, 4, 4, 3}),
::testing::Values(0),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis1, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{1, 6, 2, 3, 5, 9}),
::testing::Values(std::vector<size_t>{1, 6, 2, 3, 5, 9}),
::testing::Values(1, -5),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis2, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{2, 3, 9, 7, 2, 1}),
::testing::Values(std::vector<size_t>{2, 3, 5, 7, 2, 1}),
::testing::Values(2, -4),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis3, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{1, 3, 4, 5, 1, 3}),
::testing::Values(std::vector<size_t>{1, 3, 4, 4, 1, 3}),
::testing::Values(3, -3),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis4, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{1, 3, 2, 4, 3, 3}),
::testing::Values(std::vector<size_t>{1, 3, 2, 4, 6, 3}),
::testing::Values(4, -2),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis5, GatherElementsLayerTest,
::testing::Combine(
::testing::Values(std::vector<size_t>{2, 1, 7, 8, 1, 6}),
::testing::Values(std::vector<size_t>{2, 1, 7, 8, 1, 5}),
::testing::Values(5, -1),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
GatherElementsLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,15 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/single_layer/gather_elements.hpp"
namespace LayerTestsDefinitions {
TEST_P(GatherElementsLayerTest, CompareWithRefs) {
Run();
}
} // namespace LayerTestsDefinitions

View File

@ -48,7 +48,4 @@ void GatherElementsLayerTest::SetUp() {
function = std::make_shared<ngraph::Function>(results, params, "gatherEl");
}
TEST_P(GatherElementsLayerTest, CompareWithRefs) {
Run();
}
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,58 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "primitive.hpp"
namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{
/// @brief
/// @details
struct gather_elements : public primitive_base<gather_elements> {
CLDNN_DECLARE_PRIMITIVE(gather_elements)
enum gather_elements_axis {
along_b,
along_f,
along_x,
along_y,
along_z,
along_w
};
/// @brief Constructs gather_elements primitive.
/// @param id This primitive id.
/// @param data Input data primitive id.
/// @param indices Input indexes primitive id.
/// @param output_format Output format.
/// @param output_shape Output shape.
/// @param axis Gathering axis.
gather_elements(const primitive_id& id,
const primitive_id& data,
const primitive_id& indices,
const format& output_format,
const tensor& output_shape,
const gather_elements_axis axis,
const padding& output_padding = padding())
: primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {}
/// @brief Gather Elements output format
format output_format;
/// @brief Gather Elements output shape
tensor output_shape;
/// @brief Which axis to gather on.
gather_elements_axis axis;
};
/// @}
/// @}
/// @}
} // namespace cldnn

View File

@ -48,6 +48,7 @@ enum class KernelType {
ONE_HOT,
GATHER,
GATHER_ND,
GATHER_ELEMENTS,
SCATTER_UPDATE,
SCATTER_ND_UPDATE,
SCATTER_ELEMENTS_UPDATE,

View File

@ -0,0 +1,154 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gather_elements_kernel_ref.h"
#include "kernel_selector_utils.h"
#include <string>
#include <vector>
namespace kernel_selector {
static size_t GetGatherElementsChannelIndex(const gather_elements_params& params) {
Tensor::DataChannelName name = Tensor::DataChannelName::X;
size_t inputSize = params.inputs[0].GetDims().size();
switch (params.axis) {
case GatherAxis::X:
return inputSize - 1;
case GatherAxis::Y:
return inputSize - 2;
case GatherAxis::Z:
return inputSize - 3;
case GatherAxis::W:
return 2;
case GatherAxis::FEATURE:
return 1;
case GatherAxis::BATCH:
return 0;
default:
break;
}
return DataTensor::Channelndex(params.output.GetLayout(), name);
}
ParamsKey GatherElementsKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT32);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();
return k;
}
static inline std::vector<std::string> GetDefaultOrder(size_t size) {
std::vector<std::string> default_order;
if (size <= 4) {
default_order = { "b", "f", "y", "x" };
} else if (size == 5) {
default_order = { "b", "f", "z", "y", "x" };
} else if (size == 6) {
default_order = { "b", "f", "w", "z", "y", "x" };
}
return default_order;
}
CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_params& params, const optional_params&) const {
CommonDispatchData dispatchData;
const auto& output = params.output;
switch (params.inputs[1].GetLayout()) {
case DataLayout::bfyx:
dispatchData.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
break;
case DataLayout::bfzyx:
dispatchData.gws = {output.X().v, output.Y().v * output.Z().v, output.Feature().v * output.Batch().v};
break;
case DataLayout::bfwzyx:
dispatchData.gws = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
break;
default:
throw std::invalid_argument("Unsupported data layout for gather elements primitive");
break;
}
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
return dispatchData;
}
JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);
jit.AddConstant(MakeJitConstant("AXIS", GetGatherElementsChannelIndex(params)));
if (!params.fused_ops.empty()) {
std::vector<std::string> idx_order = GetDefaultOrder(params.inputs[0].GetDims().size());
FusedOpsConfiguration conf = { "", idx_order, "val", params.inputs[0].GetDType() };
jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
}
return jit;
}
bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o) const {
if (p.GetType() != KernelType::GATHER_ELEMENTS || o.GetType() != KernelType::GATHER_ELEMENTS) {
return false;
}
const gather_elements_params& params = static_cast<const gather_elements_params&>(p);
auto input_dims = params.inputs[0].LogicalDims();
auto indices_dims = params.inputs[1].LogicalDims();
if (input_dims.size() != indices_dims.size()) {
return false;
}
for (auto& fused_op : params.fused_ops) {
if (!IsFusedPrimitiveSupported(fused_op))
return false;
}
return true;
}
KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
if (!Validate(params, options)) {
return {};
}
KernelData kd = KernelData::Default<gather_elements_params>(params);
gather_elements_params& newParams = *static_cast<gather_elements_params*>(kd.params.get());
auto dispatchData = SetDefault(newParams, options);
auto cldnn_jit = GetJitConstants(newParams);
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, options);
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
auto& kernel = kd.kernels[0];
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 2, GetFusedPrimitiveInputsCount(params));
return { kd };
}
KernelsPriority GatherElementsKernelRef::GetKernelsPriority(const Params& /*params*/, const optional_params& /*options*/) const {
return DONT_USE_IF_HAVE_SOMETHING_ELSE;
}
} // namespace kernel_selector

View File

@ -0,0 +1,45 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "kernel_base_opencl.h"
namespace kernel_selector {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// gather_elements_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct gather_elements_params : public base_params {
gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherAxis::BATCH) {}
GatherAxis axis;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// gather_elements_optional_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct gather_elements_optional_params : optional_params {
gather_elements_optional_params() : optional_params(KernelType::GATHER_ELEMENTS) {}
};
class GatherElementsKernelRef : public KernelBaseOpenCL {
public:
GatherElementsKernelRef() : KernelBaseOpenCL("gather_elements_ref") {}
virtual ~GatherElementsKernelRef() {}
virtual JitConstants GetJitConstants(const gather_elements_params& params) const;
virtual CommonDispatchData SetDefault(const gather_elements_params& params, const optional_params&) const;
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
ParamsKey GetSupportedKey() const override;
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::QUANTIZE,
FusedOpType::SCALE,
FusedOpType::ACTIVATION,
FusedOpType::ELTWISE };
}
protected:
bool Validate(const Params& p, const optional_params& o) const override;
};
} // namespace kernel_selector

View File

@ -0,0 +1,27 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#include "gather_elements_kernel_selector.h"
#include "gather_elements_kernel_ref.h"
namespace kernel_selector {
gather_elements_kernel_selector::gather_elements_kernel_selector() { Attach<GatherElementsKernelRef>(); }
KernelsData gather_elements_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
return GetNaiveBestKernel(params, options, KernelType::GATHER_ELEMENTS);
}
} // namespace kernel_selector

View File

@ -0,0 +1,35 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
#pragma once
#include "kernel_selector.h"
namespace kernel_selector {
class gather_elements_kernel_selector : public kernel_selector_base {
public:
static gather_elements_kernel_selector& Instance() {
static gather_elements_kernel_selector instance_;
return instance_;
}
gather_elements_kernel_selector();
virtual ~gather_elements_kernel_selector() {}
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
};
} // namespace kernel_selector

View File

@ -0,0 +1,86 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "include/data_types.cl"
#include "include/fetch_data.cl"
#define GET_OUTPUT_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
KERNEL(gather_elements_ref)(const __global INPUT0_TYPE* data,
const __global INPUT1_TYPE* indices,
__global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
, FUSED_OPS_DECLS
#endif
)
{
const uint dim0 = get_global_id(0);
const uint dim1 = get_global_id(1);
const uint dim2 = get_global_id(2);
// Calculate indice index
#if INPUT1_DIMS == 4
#define ORDER b,f,y,x
const uint x = dim0;
const uint y = dim1;
#elif INPUT1_DIMS == 5
#define ORDER b,f,z,y,x
const uint x = dim0;
const uint y = dim1 % OUTPUT_SIZE_Y;
const uint z = dim1 / OUTPUT_SIZE_Y;
#else
#define ORDER b,f,w,z,y,x
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
const uint z = dim1 % OUTPUT_SIZE_Z;
const uint w = dim1 / OUTPUT_SIZE_Z;
#endif
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
const int out_idx = GET_OUTPUT_INDEX(INPUT1, ORDER);
#if INPUT1_DIMS == 4
size_t data_shape[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
size_t indices_shape[4] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#elif INPUT1_DIMS == 5
size_t data_shape[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
size_t indices_shape[5] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#else
size_t data_shape[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
size_t indices_shape[6] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#endif
size_t max_inner_sum = 1, max_outer_sum = 1, outer_sum_inc_data = 1, outer_sum_inc_indices = 1;
for (size_t i = AXIS + 1; i < INPUT1_DIMS; i++)
max_inner_sum *= indices_shape[i];
for (int i = 0; i < AXIS; i++)
max_outer_sum *= indices_shape[i];
for (size_t i = AXIS; i < INPUT1_DIMS; i++) {
outer_sum_inc_data *= data_shape[i];
}
max_outer_sum *= outer_sum_inc_data;
for (size_t i = AXIS; i < INPUT1_DIMS; i++) {
outer_sum_inc_indices *= indices_shape[i];
}
size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data;
size_t inner_sum = out_idx % max_inner_sum;
uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum;
INPUT0_TYPE val = data[idx];
#if HAS_FUSED_OPS
FUSED_OPS;
output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT);
#else
output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
}
#undef ORDER
#undef GET_OUTPUT_INDEX

View File

@ -402,6 +402,8 @@ std::string toString(GatherAxis a) {
switch (a) {
case GatherAxis::X: return "X";
case GatherAxis::Y: return "Y";
case GatherAxis::Z: return "Z";
case GatherAxis::W: return "W";
case GatherAxis::FEATURE: return "FEATURE";
case GatherAxis::BATCH: return "BATCH";
default: return "";

View File

@ -0,0 +1,62 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gather_elements_inst.h"
#include "primitive_type_base.h"
#include "cldnn/runtime/error_handler.hpp"
#include "json_object.h"
#include <string>
namespace cldnn {
primitive_type_id gather_elements::type_id() {
static primitive_type_base<gather_elements> instance;
return &instance;
}
layout gather_elements_inst::calc_output_layout(gather_elements_node const& node) {
auto op = node.get_primitive();
auto input_layout_origin = node.input(0).get_output_layout();
auto indices_layout_origin = node.input(1).get_output_layout();
auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format);
auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format);
if (node.has_fused_primitives()) {
input_layout_origin.data_type = node.get_fused_output_layout().data_type;
}
auto output_type = indices_layout_origin.data_type;
auto output_format = op->output_format;
auto output_shape = op->output_shape;
// calculate initial output shape
return layout(output_type, output_format, output_shape);
}
std::string gather_elements_inst::to_string(gather_elements_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();
auto& input = node.input();
std::stringstream primitive_description;
json_composite gather_elements_info;
gather_elements_info.add("input id", input.id());
gather_elements_info.add("input shape", node.input(0).get_output_layout().size.to_string());
gather_elements_info.add("indices shape", node.input(1).get_output_layout().size.to_string());
gather_elements_info.add("output format", calc_output_layout(node).format);
gather_elements_info.add("output shape", calc_output_layout(node).size.to_string());
gather_elements_info.add("axis", desc->axis);
node_info->add("gather_elements info", gather_elements_info);
node_info->dump(primitive_description);
return primitive_description.str();
}
gather_elements_inst::typed_primitive_inst(network_impl& network, gather_elements_node const& node) : parent(network, node) {}
} // namespace cldnn

View File

@ -32,6 +32,7 @@
#include "space_to_depth_inst.h"
#include "gather_inst.h"
#include "gather_nd_inst.h"
#include "gather_elements_inst.h"
#include "scatter_update_inst.h"
#include "scatter_nd_update_inst.h"
#include "scatter_elements_update_inst.h"
@ -200,6 +201,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) {
!input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
!input.is_type<scatter_nd_update>() &&
!input.is_type<gather_nd>() &&
!input.is_type<gather_elements>() &&
!input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
!input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
!input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
@ -609,6 +611,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
should_fuse |= input_data.is_type<gather_nd>();
should_fuse |= input_data.is_type<gather_elements>();
should_fuse |= input_data.is_type<scatter_update>();
should_fuse |= input_data.is_type<scatter_nd_update>();
@ -677,6 +681,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
should_fuse |= input_data.is_type<gather_nd>();
should_fuse |= input_data.is_type<gather_elements>();
should_fuse |= input_data.is_type<scatter_update>();
should_fuse |= input_data.is_type<scatter_nd_update>();
@ -767,6 +773,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
should_fuse |= input_data.is_type<gather_nd>() && quantize_node.get_scale_shift_opt();
should_fuse |= input_data.is_type<gather_elements>() && quantize_node.get_scale_shift_opt();
should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();
should_fuse |= input_data.is_type<scatter_nd_update>() && quantize_node.get_scale_shift_opt();
@ -829,6 +837,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
(parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) ||
(parents[i]->is_type<scale>()) ||
(parents[i]->is_type<gather_nd>()) ||
(parents[i]->is_type<gather_elements>()) ||
(parents[i]->is_type<scatter_nd_update>()) ||
(parents[i]->is_type<scatter_elements_update>()) ||
(parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) ||

View File

@ -0,0 +1,86 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gather_elements_inst.h"
#include "primitive_base.hpp"
#include "impls/implementation_map.hpp"
#include "kernel_selector_helper.h"
#include "gather/gather_elements_kernel_selector.h"
#include "gather/gather_elements_kernel_ref.h"
#include "cldnn/runtime/error_handler.hpp"
using namespace cldnn;
namespace cldnn {
namespace ocl {
kernel_selector::gather_elements_axis convert_axis(gather_elements::gather_elements_axis axis) {
switch (axis) {
case gather_elements::along_x:
return kernel_selector::gather_elements_axis::X;
case gather_elements::along_y:
return kernel_selector::gather_elements_axis::Y;
case gather_elements::along_z:
return kernel_selector::gather_elements_axis::Z;
case gather_elements::along_w:
return kernel_selector::gather_elements_axis::W;
case gather_elements::along_f:
return kernel_selector::gather_elements_axis::FEATURE;
case gather_elements::along_b:
return kernel_selector::gather_elements_axis::BATCH;
default:
return kernel_selector::gather_elements_axis::BATCH;
}
}
struct gather_elements_impl : typed_primitive_impl_ocl<gather_elements> {
using parent = typed_primitive_impl_ocl<gather_elements>;
using parent::parent;
std::unique_ptr<primitive_impl> clone() const override {
return make_unique<gather_elements_impl>(*this);
}
public:
static primitive_impl* create(const gather_elements_node& arg) {
auto gather_elements_params = get_default_params<kernel_selector::gather_elements_params>(arg);
auto gather_elements_optional_params =
get_default_optional_params<kernel_selector::gather_elements_optional_params>(arg.get_program());
gather_elements_params.axis = convert_axis(arg.get_primitive()->axis);
gather_elements_params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));
auto& kernel_selector = kernel_selector::gather_elements_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(gather_elements_params, gather_elements_optional_params);
CLDNN_ERROR_BOOL(arg.id(),
"Best_kernel.empty()",
best_kernels.empty(),
"Cannot find a proper kernel with this arguments");
auto gather_elements = new gather_elements_impl(arg, best_kernels[0]);
return gather_elements;
}
};
namespace detail {
attach_gather_elements_impl::attach_gather_elements_impl() {
implementation_map<gather_elements>::add(impl_types::ocl, gather_elements_impl::create, {
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::i32, format::bfyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
std::make_tuple(data_types::i32, format::bfzyx),
std::make_tuple(data_types::f32, format::bfwzyx),
std::make_tuple(data_types::f16, format::bfwzyx),
std::make_tuple(data_types::i32, format::bfwzyx),
});
}
} // namespace detail
} // namespace ocl
} // namespace cldnn

View File

@ -30,6 +30,7 @@ void register_implementations() {
REGISTER_OCL(eltwise);
REGISTER_OCL(fully_connected);
REGISTER_OCL(gather);
REGISTER_OCL(gather_elements);
REGISTER_OCL(gather_nd);
REGISTER_OCL(gemm);
REGISTER_OCL(lrn);

View File

@ -22,6 +22,7 @@
#include "cldnn/primitives/fully_connected.hpp"
#include "cldnn/primitives/gather.hpp"
#include "cldnn/primitives/gather_nd.hpp"
#include "cldnn/primitives/gather_elements.hpp"
#include "cldnn/primitives/gemm.hpp"
#include "cldnn/primitives/lrn.hpp"
#include "cldnn/primitives/lstm.hpp"
@ -94,6 +95,7 @@ REGISTER_OCL(embed);
REGISTER_OCL(fully_connected);
REGISTER_OCL(gather);
REGISTER_OCL(gather_nd);
REGISTER_OCL(gather_elements);
REGISTER_OCL(gemm);
REGISTER_OCL(lrn);
REGISTER_OCL(lstm_gemm);

View File

@ -0,0 +1,49 @@
/*
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "cldnn/primitives/gather_elements.hpp"
#include "primitive_inst.h"
#include <string>
namespace cldnn {
template <>
struct typed_program_node<gather_elements> : public typed_program_node_base<gather_elements> {
using parent = typed_program_node_base<gather_elements>;
public:
using parent::parent;
program_node& input(size_t index = 0) const { return get_dependency(index); }
};
using gather_elements_node = typed_program_node<gather_elements>;
template <>
class typed_primitive_inst<gather_elements> : public typed_primitive_inst_base<gather_elements> {
using parent = typed_primitive_inst_base<gather_elements>;
public:
static layout calc_output_layout(gather_elements_node const& node);
static std::string to_string(gather_elements_node const& node);
public:
typed_primitive_inst(network_impl& network, gather_elements_node const& desc);
};
using gather_elements_inst = typed_primitive_inst<gather_elements>;
} // namespace cldnn

View File

@ -72,6 +72,7 @@ using shape_calculation_mode = kernel_selector::ShapeCalculationMode;
using interpolate_axis = kernel_selector::InterpolateAxis;
using border_type = kernel_selector::BorderType;
using gather_axis = kernel_selector::GatherAxis;
using gather_elements_axis = kernel_selector::GatherAxis;
using scatter_update_axis = kernel_selector::ScatterUpdateAxis;
using reduce_mode = kernel_selector::ReduceMode;
using cum_sum_axis = kernel_selector::CumSumAxis;

View File

@ -18,6 +18,7 @@
#include <cldnn/primitives/permute.hpp>
#include <cldnn/primitives/gather.hpp>
#include <cldnn/primitives/gather_nd.hpp>
#include <cldnn/primitives/gather_elements.hpp>
#include <cldnn/primitives/scatter_update.hpp>
#include <cldnn/primitives/scatter_nd_update.hpp>
#include <cldnn/primitives/scatter_elements_update.hpp>
@ -8413,3 +8414,228 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_nd_activation_scale_eltwise,
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 5 },
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 5 },
}));
/* ----------------------------------------------------------------------------------------------------- */
/* ------------------------------------------ GatherElements cases ------------------------------------- */
/* ----------------------------------------------------------------------------------------------------- */
struct gather_elements_test_params {
data_types data_type;
format input_format;
tensor input_shape;
format indices_format;
tensor indices_shape;
format output_format;
tensor output_shape;
cldnn::gather_elements::gather_elements_axis axis;
data_types default_type;
format default_format;
size_t expected_fused_primitives;
size_t expected_not_fused_primitives;
};
#define CASE_GATHER_ELEMENTS_FP16_4D_1 data_types::f16, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, cldnn::gather_elements::gather_elements_axis::along_y, data_types::f16, format::bfyx
#define CASE_GATHER_ELEMENTS_FP16_4D_2 data_types::f16, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, cldnn::gather_elements::gather_elements_axis::along_b, data_types::f16, format::bfyx
#define CASE_GATHER_ELEMENTS_FP16_4D_3 data_types::f16, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfyx
#define CASE_GATHER_ELEMENTS_FP16_5D_1 data_types::f16, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfzyx
#define CASE_GATHER_ELEMENTS_FP16_5D_2 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, cldnn::gather_elements::gather_elements_axis::along_z, data_types::f16, format::bfzyx
#define CASE_GATHER_ELEMENTS_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, cldnn::gather_elements::gather_elements_axis::along_f, data_types::f16, format::bfwzyx
#define CASE_GATHER_ELEMENTS_FP16_6D_2 data_types::f16, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_w, data_types::f16, format::bfwzyx
#define CASE_GATHER_ELEMENTS_FP16_6D_3 data_types::f16, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfwzyx
#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, cldnn::gather_elements::gather_elements_axis::along_y, data_types::f32, format::bfyx
#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, cldnn::gather_elements::gather_elements_axis::along_b, data_types::f32, format::bfyx
#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfyx
#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfzyx
#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, cldnn::gather_elements::gather_elements_axis::along_z, data_types::f32, format::bfzyx
#define CASE_GATHER_ELEMENTS_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, cldnn::gather_elements::gather_elements_axis::along_f, data_types::f32, format::bfwzyx
#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_w, data_types::f32, format::bfwzyx
#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfwzyx
class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest<gather_elements_test_params> {
public:
void execute(gather_elements_test_params& p) {
auto input_prim = get_mem(get_input_layout(p));
network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
network network_fused(this->engine, this->topology_fused, bo_fused);
network_fused.set_input_data("input", input_prim);
network_not_fused.set_input_data("input", input_prim);
compare(network_not_fused, network_fused, p);
}
size_t get_axis_dim(gather_elements_test_params& p) {
switch (p.axis) {
case cldnn::gather_elements::gather_elements_axis::along_x:
return p.input_shape.spatial[0];
case cldnn::gather_elements::gather_elements_axis::along_y:
return p.input_shape.spatial[1];
case cldnn::gather_elements::gather_elements_axis::along_z:
return p.input_shape.spatial[2];
case cldnn::gather_elements::gather_elements_axis::along_w:
return p.input_shape.spatial[3];
case cldnn::gather_elements::gather_elements_axis::along_f:
return p.input_shape.feature[0];
case cldnn::gather_elements::gather_elements_axis::along_b:
return p.input_shape.batch[0];
default:
return 1;
}
}
layout get_input_layout(gather_elements_test_params& p) {
return layout{ p.data_type, p.input_format, p.input_shape };
}
layout get_indices_layout(gather_elements_test_params& p) {
return layout{ p.data_type, p.indices_format, p.indices_shape };
}
layout get_output_layout(gather_elements_test_params& p) {
return layout{ p.data_type, p.output_format, p.output_shape };
}
layout get_per_channel_layout(gather_elements_test_params& p) {
return layout{ p.default_type, p.default_format, tensor{1, p.output_shape.feature[0], 1, 1} };
}
};
class gather_elements_quantize : public GatherElementsPrimitiveFusingTest {};
TEST_P(gather_elements_quantize, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p))-1)),
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), -127)),
data("out_hi", get_mem(get_single_element_layout(p), 127)),
gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis),
quantize("quantize", "gather_elements_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
tolerance = 1.f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_quantize,
::testing::ValuesIn(std::vector<gather_elements_test_params>{
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 3 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 3 },
}));
class gather_elements_scale_activation : public GatherElementsPrimitiveFusingTest {};
TEST_P(gather_elements_scale_activation, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p))-1)),
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis),
activation("activation", "gather_elements_prim", activation_func::abs),
scale("scale", "activation", "scale_data"),
reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
);
tolerance = 1e-5f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_scale_activation,
::testing::ValuesIn(std::vector<gather_elements_test_params>{
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 4 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 4 },
}));
class gather_elements_activation_scale_eltwise : public GatherElementsPrimitiveFusingTest {};
TEST_P(gather_elements_activation_scale_eltwise, basic) {
auto p = GetParam();
create_topologies(input_layout("input", get_input_layout(p)),
data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p))-1)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)),
data("eltwise_data", get_mem(get_output_layout(p))),
gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis),
activation("activation", "gather_elements_prim", activation_func::abs),
scale("scale", "activation", "scale_data"),
eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
);
tolerance = 1e-5f;
execute(p);
}
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_activation_scale_eltwise,
::testing::ValuesIn(std::vector<gather_elements_test_params>{
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 5 },
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 5 },
}));

File diff suppressed because it is too large Load Diff