[GPU] Add new operation GatherElements to IE clDNN plugin (#6676)
This commit is contained in:
parent
b4ad7a1755
commit
f5666fb3e1
@ -204,6 +204,7 @@ REGISTER_FACTORY(v5, Loop);
|
||||
// ------------------------------ Supported v6 ops ------------------------------ //
|
||||
REGISTER_FACTORY(v6, CTCGreedyDecoderSeqLen);
|
||||
REGISTER_FACTORY(v6, MVN);
|
||||
REGISTER_FACTORY(v6, GatherElements);
|
||||
|
||||
// ------------------------------ Supported v7 ops ------------------------------ //
|
||||
REGISTER_FACTORY(v7, Gather);
|
||||
|
66
inference-engine/src/cldnn_engine/ops/gather_elements.cpp
Normal file
66
inference-engine/src/cldnn_engine/ops/gather_elements.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "cldnn_program.h"
|
||||
#include "cldnn_common_utils.h"
|
||||
|
||||
#include "ngraph/op/gather_elements.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
|
||||
#include "cldnn/primitives/gather_elements.hpp"
|
||||
|
||||
namespace CLDNNPlugin {
|
||||
|
||||
// Converts an ngraph GatherElements axis into the corresponding clDNN axis enumerator.
//
// @param axis  Axis as given by the operation; negative values count from the last dimension.
// @param rank  Number of dimensions of the data input.
// @return      clDNN axis enumerator matching the (normalized) axis.
// Throws via IE_THROW() when the axis is outside [-rank, rank) or maps to an
// axis index that has no clDNN enumerator.
static cldnn::gather_elements::gather_elements_axis GetGatherAxis(int axis, unsigned rank) {
    using AxisKind = cldnn::gather_elements::gather_elements_axis;

    // Normalize a negative axis, then make sure it addresses an existing dimension.
    if (axis < 0)
        axis += rank;
    const bool axisInRange = axis >= 0 && static_cast<unsigned>(axis) < rank;
    if (!axisInRange) {
        IE_THROW() << "GatherElements axis is not correspond to number of dimensions";
    }

    // Batch and feature keep their position; spatial dimensions are stored in
    // reverse order in clDNN, so spatial axes are mirrored.
    unsigned clAxis = axis;
    if (axis >= 2) {
        // Tensors with fewer than 4 dimensions are handled as 4D ones.
        const unsigned spatialRank = std::max(rank, 4u) - 2;
        const unsigned spatialPos = axis - 2;
        clAxis = spatialRank - spatialPos - 1 + 2;
    }

    // Enumerators indexed by the clDNN axis position.
    static constexpr AxisKind kAxisByIndex[] = {
        AxisKind::along_b,
        AxisKind::along_f,
        AxisKind::along_x,
        AxisKind::along_y,
        AxisKind::along_z,
        AxisKind::along_w,
    };
    constexpr unsigned kAxisCount = sizeof(kAxisByIndex) / sizeof(kAxisByIndex[0]);
    if (clAxis >= kAxisCount) {
        IE_THROW() << "Unsupported GatherElements axis: " << axis;
    }
    return kAxisByIndex[clAxis];
}
|
||||
|
||||
// Builds the clDNN gather_elements primitive for an ngraph v6 GatherElements node
// and registers it (and its profiling entry) in the program.
//
// @param p   Program the primitive is added to.
// @param op  GatherElements node; must have exactly two inputs (data, indices).
void CreateGatherElementsOp(Program& p, const std::shared_ptr<ngraph::op::v6::GatherElements>& op) {
    p.ValidateInputs(op, {2});
    auto inputIds = p.GetInputPrimitiveIDs(op);
    std::string primitiveName = layer_type_name_ID(op);

    // The axis is interpreted relative to the rank of the data input.
    const size_t dataRank = op->get_input_shape(0).size();
    const int32_t rawAxis = static_cast<int32_t>(op->get_axis());

    auto outputFormat = DefaultFormatForDims(op->get_output_shape(0).size());

    auto gatherElementsPrim = cldnn::gather_elements(primitiveName,
                                                     inputIds[0],   // data
                                                     inputIds[1],   // indices
                                                     outputFormat,
                                                     CldnnTensorFromIEDims(op->get_output_shape(0)),
                                                     GetGatherAxis(rawAxis, dataRank));

    p.AddPrimitive(gatherElementsPrim);
    p.AddPrimitiveToProfiler(op);
}
|
||||
|
||||
REGISTER_FACTORY_IMPL(v6, GatherElements);
|
||||
|
||||
} // namespace CLDNNPlugin
|
@ -4,7 +4,8 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "shared_test_classes/single_layer/gather_elements.hpp"
|
||||
#include "single_layer_tests/gather_elements.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
@ -16,8 +17,6 @@ const std::vector<InferenceEngine::Precision> dPrecisions = {
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64,
|
||||
InferenceEngine::Precision::I16,
|
||||
InferenceEngine::Precision::U8,
|
||||
InferenceEngine::Precision::I8
|
||||
};
|
||||
const std::vector<InferenceEngine::Precision> iPrecisions = {
|
||||
InferenceEngine::Precision::I32,
|
||||
|
@ -0,0 +1,227 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
|
||||
#include "single_layer_tests/gather_elements.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
using namespace ngraph::opset6;
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<InferenceEngine::Precision> inputPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16,
|
||||
InferenceEngine::Precision::I32,
|
||||
};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> idxPrecisions = {
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64,
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 2})),
|
||||
::testing::Values(std::vector<size_t>({2, 2})),
|
||||
::testing::ValuesIn(std::vector<int>({-1, 0, 1})),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 2, 1})),
|
||||
::testing::Values(std::vector<size_t>({4, 2, 1})),
|
||||
::testing::ValuesIn(std::vector<int>({0, -3})),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 2, 3, 5})),
|
||||
::testing::Values(std::vector<size_t>({2, 2, 3, 7})),
|
||||
::testing::Values(3, -1),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({3, 2, 3, 8})),
|
||||
::testing::Values(std::vector<size_t>({2, 2, 3, 8})),
|
||||
::testing::Values(0, -4),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({3, 2, 3, 4, 8})),
|
||||
::testing::Values(std::vector<size_t>({3, 2, 3, 5, 8})),
|
||||
::testing::Values(3, -2),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis0, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{7, 7, 8, 4}),
|
||||
::testing::Values(std::vector<size_t>{2, 7, 8, 4}),
|
||||
::testing::Values(0),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis1, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{6, 1, 8, 4}),
|
||||
::testing::Values(std::vector<size_t>{6, 8, 8, 4}),
|
||||
::testing::Values(1, -3),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis2, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{6, 7, 4, 4}),
|
||||
::testing::Values(std::vector<size_t>{6, 7, 2, 4}),
|
||||
::testing::Values(2, -2),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank4axis3, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{6, 5, 8, 7}),
|
||||
::testing::Values(std::vector<size_t>{6, 5, 8, 7}),
|
||||
::testing::Values(3, -1),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis0, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{2, 3, 9, 4, 9}),
|
||||
::testing::Values(std::vector<size_t>{1, 3, 9, 4, 9}),
|
||||
::testing::Values(0),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis1, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{2, 3, 5, 4, 7}),
|
||||
::testing::Values(std::vector<size_t>{2, 9, 5, 4, 7}),
|
||||
::testing::Values(1, -4),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis2, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{1, 2, 6, 8, 9}),
|
||||
::testing::Values(std::vector<size_t>{1, 2, 6, 8, 9}),
|
||||
::testing::Values(2, -3),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis3, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{2, 2, 4, 7, 7}),
|
||||
::testing::Values(std::vector<size_t>{2, 2, 4, 3, 7}),
|
||||
::testing::Values(3, -2),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank5axis4, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{1, 3, 9, 3, 2}),
|
||||
::testing::Values(std::vector<size_t>{1, 3, 9, 3, 9}),
|
||||
::testing::Values(4, -1),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis0, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{3, 3, 2, 4, 4, 3}),
|
||||
::testing::Values(std::vector<size_t>{7, 3, 2, 4, 4, 3}),
|
||||
::testing::Values(0),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis1, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{1, 6, 2, 3, 5, 9}),
|
||||
::testing::Values(std::vector<size_t>{1, 6, 2, 3, 5, 9}),
|
||||
::testing::Values(1, -5),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis2, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{2, 3, 9, 7, 2, 1}),
|
||||
::testing::Values(std::vector<size_t>{2, 3, 5, 7, 2, 1}),
|
||||
::testing::Values(2, -4),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis3, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{1, 3, 4, 5, 1, 3}),
|
||||
::testing::Values(std::vector<size_t>{1, 3, 4, 4, 1, 3}),
|
||||
::testing::Values(3, -3),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis4, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{1, 3, 2, 4, 3, 3}),
|
||||
::testing::Values(std::vector<size_t>{1, 3, 2, 4, 6, 3}),
|
||||
::testing::Values(4, -2),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_GatherElements_rank6axis5, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{2, 1, 7, 8, 1, 6}),
|
||||
::testing::Values(std::vector<size_t>{2, 1, 7, 8, 1, 5}),
|
||||
::testing::Values(5, -1),
|
||||
::testing::ValuesIn(inputPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
@ -0,0 +1,15 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "shared_test_classes/single_layer/gather_elements.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
// Parameterized test body: executes the fixture's Run() for every instantiated parameter set.
TEST_P(GatherElementsLayerTest, CompareWithRefs) {
    this->Run();
}
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -48,7 +48,4 @@ void GatherElementsLayerTest::SetUp() {
|
||||
function = std::make_shared<ngraph::Function>(results, params, "gatherEl");
|
||||
}
|
||||
|
||||
TEST_P(GatherElementsLayerTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
58
inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp
vendored
Normal file
58
inference-engine/thirdparty/clDNN/api/cldnn/primitives/gather_elements.hpp
vendored
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#pragma once
|
||||
#include "primitive.hpp"
|
||||
|
||||
namespace cldnn {
|
||||
/// @addtogroup cpp_api C++ API
|
||||
/// @{
|
||||
/// @addtogroup cpp_topology Network Topology
|
||||
/// @{
|
||||
/// @addtogroup cpp_primitives Primitives
|
||||
/// @{
|
||||
|
||||
/// @brief
|
||||
/// @details
|
||||
struct gather_elements : public primitive_base<gather_elements> {
|
||||
CLDNN_DECLARE_PRIMITIVE(gather_elements)
|
||||
|
||||
enum gather_elements_axis {
|
||||
along_b,
|
||||
along_f,
|
||||
along_x,
|
||||
along_y,
|
||||
along_z,
|
||||
along_w
|
||||
};
|
||||
|
||||
/// @brief Constructs gather_elements primitive.
|
||||
/// @param id This primitive id.
|
||||
/// @param data Input data primitive id.
|
||||
/// @param indices Input indexes primitive id.
|
||||
/// @param output_format Output format.
|
||||
/// @param output_shape Output shape.
|
||||
/// @param axis Gathering axis.
|
||||
gather_elements(const primitive_id& id,
|
||||
const primitive_id& data,
|
||||
const primitive_id& indices,
|
||||
const format& output_format,
|
||||
const tensor& output_shape,
|
||||
const gather_elements_axis axis,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {data, indices}, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {}
|
||||
|
||||
/// @brief Gather Elements output format
|
||||
format output_format;
|
||||
/// @brief Gather Elements output shape
|
||||
tensor output_shape;
|
||||
|
||||
/// @brief Which axis to gather on.
|
||||
gather_elements_axis axis;
|
||||
};
|
||||
/// @}
|
||||
/// @}
|
||||
/// @}
|
||||
} // namespace cldnn
|
@ -48,6 +48,7 @@ enum class KernelType {
|
||||
ONE_HOT,
|
||||
GATHER,
|
||||
GATHER_ND,
|
||||
GATHER_ELEMENTS,
|
||||
SCATTER_UPDATE,
|
||||
SCATTER_ND_UPDATE,
|
||||
SCATTER_ELEMENTS_UPDATE,
|
||||
|
@ -0,0 +1,154 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gather_elements_kernel_ref.h"
|
||||
#include "kernel_selector_utils.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace kernel_selector {
|
||||
static size_t GetGatherElementsChannelIndex(const gather_elements_params& params) {
|
||||
Tensor::DataChannelName name = Tensor::DataChannelName::X;
|
||||
|
||||
size_t inputSize = params.inputs[0].GetDims().size();
|
||||
|
||||
switch (params.axis) {
|
||||
case GatherAxis::X:
|
||||
return inputSize - 1;
|
||||
case GatherAxis::Y:
|
||||
return inputSize - 2;
|
||||
case GatherAxis::Z:
|
||||
return inputSize - 3;
|
||||
case GatherAxis::W:
|
||||
return 2;
|
||||
case GatherAxis::FEATURE:
|
||||
return 1;
|
||||
case GatherAxis::BATCH:
|
||||
return 0;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return DataTensor::Channelndex(params.output.GetLayout(), name);
|
||||
}
|
||||
|
||||
// Describes what the reference kernel accepts: F16/F32/INT32 for inputs and
// outputs, the bfyx/bfzyx/bfwzyx layouts, plus tensor offsets, pitches,
// batching and differing input/output types.
ParamsKey GatherElementsKernelRef::GetSupportedKey() const {
    ParamsKey key;

    for (const auto type : {Datatype::F16, Datatype::F32, Datatype::INT32}) {
        key.EnableInputDataType(type);
    }
    for (const auto type : {Datatype::F16, Datatype::F32, Datatype::INT32}) {
        key.EnableOutputDataType(type);
    }

    // 4D, 5D and 6D plain layouts, enabled symmetrically for input and output.
    for (const auto layout : {DataLayout::bfyx, DataLayout::bfzyx, DataLayout::bfwzyx}) {
        key.EnableInputLayout(layout);
        key.EnableOutputLayout(layout);
    }

    key.EnableTensorOffset();
    key.EnableTensorPitches();
    key.EnableBatching();
    key.EnableDifferentTypes();
    return key;
}
|
||||
|
||||
// Returns the index-variable names, outermost first, for a tensor with `size` dimensions.
// Up to 4 dimensions use b,f,y,x; 5 adds z; 6 adds w and z.
// More than 6 dimensions yields an empty list.
static inline std::vector<std::string> GetDefaultOrder(size_t size) {
    switch (size) {
        case 5:
            return { "b", "f", "z", "y", "x" };
        case 6:
            return { "b", "f", "w", "z", "y", "x" };
        default:
            break;
    }

    if (size > 6) {
        return {};
    }
    return { "b", "f", "y", "x" };
}
|
||||
|
||||
// Computes the dispatch configuration: one work item per output element, with
// the output dimensions folded into three global work sizes according to the
// layout of the indices input.
//
// Throws std::invalid_argument for layouts other than bfyx, bfzyx and bfwzyx.
CommonDispatchData GatherElementsKernelRef::SetDefault(const gather_elements_params& params, const optional_params&) const {
    CommonDispatchData dispatchData;

    const auto& out = params.output;
    const auto indicesLayout = params.inputs[1].GetLayout();

    if (indicesLayout == DataLayout::bfyx) {
        dispatchData.gws = {out.X().v, out.Y().v, out.Feature().v * out.Batch().v};
    } else if (indicesLayout == DataLayout::bfzyx) {
        // Y and Z share the second dimension.
        dispatchData.gws = {out.X().v, out.Y().v * out.Z().v, out.Feature().v * out.Batch().v};
    } else if (indicesLayout == DataLayout::bfwzyx) {
        // X/Y share the first dimension, Z/W the second.
        dispatchData.gws = {out.X().v * out.Y().v, out.Z().v * out.W().v, out.Feature().v * out.Batch().v};
    } else {
        throw std::invalid_argument("Unsupported data layout for gather elements primitive");
    }

    dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);

    return dispatchData;
}
|
||||
|
||||
// Builds the JIT constants for the kernel: the base parameter constants, the
// AXIS constant (dimension position of the gather axis) and, when fused
// operations are present, their constants as well.
JitConstants GatherElementsKernelRef::GetJitConstants(const gather_elements_params& params) const {
    JitConstants jit = MakeBaseParamsJitConstants(params);
    jit.AddConstant(MakeJitConstant("AXIS", GetGatherElementsChannelIndex(params)));

    // Nothing more to add when no operation is fused into this kernel.
    if (params.fused_ops.empty()) {
        return jit;
    }

    // Fused ops read the gathered value from "val", indexed by the default order for the input rank.
    const auto& dataInput = params.inputs[0];
    std::vector<std::string> order = GetDefaultOrder(dataInput.GetDims().size());
    FusedOpsConfiguration conf = { "", order, "val", dataInput.GetDType() };
    jit.Merge(MakeFusedOpsJitConstants(params, { conf }));

    return jit;
}
|
||||
|
||||
bool GatherElementsKernelRef::Validate(const Params& p, const optional_params& o) const {
|
||||
if (p.GetType() != KernelType::GATHER_ELEMENTS || o.GetType() != KernelType::GATHER_ELEMENTS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const gather_elements_params& params = static_cast<const gather_elements_params&>(p);
|
||||
auto input_dims = params.inputs[0].LogicalDims();
|
||||
auto indices_dims = params.inputs[1].LogicalDims();
|
||||
|
||||
if (input_dims.size() != indices_dims.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto& fused_op : params.fused_ops) {
|
||||
if (!IsFusedPrimitiveSupported(fused_op))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Produces the kernel data for the given parameters, or an empty list when
// validation fails. The single kernel is filled with the dispatch sizes, the
// generated JIT and two regular inputs (data and indices) plus any fused-op inputs.
KernelsData GatherElementsKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
    if (!Validate(params, options)) {
        return {};
    }

    KernelData kd = KernelData::Default<gather_elements_params>(params);
    // Work on the copy of the parameters owned by the kernel data.
    auto& ownedParams = *static_cast<gather_elements_params*>(kd.params.get());

    auto dispatch = SetDefault(ownedParams, options);
    auto jitConstants = GetJitConstants(ownedParams);

    auto entryPoint = GetEntryPoint(kernelName, ownedParams.layerID, params, options);
    auto jitCode = CreateJit(kernelName, jitConstants, entryPoint);

    auto& clKernel = kd.kernels[0];
    FillCLKernelData(clKernel, dispatch, params.engineInfo, kernelName, jitCode, entryPoint, "", false, false, 2, GetFusedPrimitiveInputsCount(params));

    return { kd };
}
|
||||
|
||||
// Always reports the DONT_USE_IF_HAVE_SOMETHING_ELSE priority, regardless of the arguments.
KernelsPriority GatherElementsKernelRef::GetKernelsPriority(const Params&, const optional_params&) const {
    return DONT_USE_IF_HAVE_SOMETHING_ELSE;
}
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernel_base_opencl.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// gather_elements_params
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct gather_elements_params : public base_params {
|
||||
gather_elements_params() : base_params(KernelType::GATHER_ELEMENTS), axis(GatherAxis::BATCH) {}
|
||||
|
||||
GatherAxis axis;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// gather_elements_optional_params
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct gather_elements_optional_params : optional_params {
|
||||
gather_elements_optional_params() : optional_params(KernelType::GATHER_ELEMENTS) {}
|
||||
};
|
||||
|
||||
class GatherElementsKernelRef : public KernelBaseOpenCL {
|
||||
public:
|
||||
GatherElementsKernelRef() : KernelBaseOpenCL("gather_elements_ref") {}
|
||||
virtual ~GatherElementsKernelRef() {}
|
||||
virtual JitConstants GetJitConstants(const gather_elements_params& params) const;
|
||||
virtual CommonDispatchData SetDefault(const gather_elements_params& params, const optional_params&) const;
|
||||
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
|
||||
KernelsPriority GetKernelsPriority(const Params& params, const optional_params& options) const override;
|
||||
ParamsKey GetSupportedKey() const override;
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
return { FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ACTIVATION,
|
||||
FusedOpType::ELTWISE };
|
||||
}
|
||||
|
||||
protected:
|
||||
bool Validate(const Params& p, const optional_params& o) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,27 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gather_elements_kernel_selector.h"
|
||||
#include "gather_elements_kernel_ref.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
// Registers the available implementations; currently only the reference kernel.
gather_elements_kernel_selector::gather_elements_kernel_selector() {
    Attach<GatherElementsKernelRef>();
}
|
||||
|
||||
// Delegates the choice among attached GATHER_ELEMENTS kernels to GetNaiveBestKernel.
KernelsData gather_elements_kernel_selector::GetBestKernels(const Params& params,
                                                            const optional_params& options) const {
    return GetNaiveBestKernel(params, options, KernelType::GATHER_ELEMENTS);
}
|
||||
} // namespace kernel_selector
|
@ -0,0 +1,35 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernel_selector.h"
|
||||
|
||||
namespace kernel_selector {
|
||||
class gather_elements_kernel_selector : public kernel_selector_base {
|
||||
public:
|
||||
static gather_elements_kernel_selector& Instance() {
|
||||
static gather_elements_kernel_selector instance_;
|
||||
return instance_;
|
||||
}
|
||||
|
||||
gather_elements_kernel_selector();
|
||||
|
||||
virtual ~gather_elements_kernel_selector() {}
|
||||
|
||||
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
|
||||
};
|
||||
} // namespace kernel_selector
|
86
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl
vendored
Normal file
86
inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/gather_elements_ref.cl
vendored
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "include/data_types.cl"
|
||||
#include "include/fetch_data.cl"
|
||||
|
||||
#define GET_OUTPUT_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
|
||||
|
||||
// Reference GatherElements kernel: each work item produces one output element.
//
// data    - tensor the values are read from (INPUT0)
// indices - tensor holding, for every output position, the coordinate to read
//           along AXIS (INPUT1)
// output  - destination tensor
//
// AXIS is the position of the gather dimension counted from batch (0).
// NOTE(review): the offset obtained from INPUT1's index macro is also used to
// address `output` and is split arithmetically as a dense linear index, so the
// kernel presumably expects unpadded indices/output of equal shape - confirm.
KERNEL(gather_elements_ref)(const __global INPUT0_TYPE* data,
                            const __global INPUT1_TYPE* indices,
                            __global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
                            , FUSED_OPS_DECLS
#endif
)
{
    const uint dim0 = get_global_id(0);
    const uint dim1 = get_global_id(1);
    const uint dim2 = get_global_id(2);

    // Unpack the work-item ids into per-dimension coordinates
    // (inverse of the folding done when the global work sizes are set up).
#if INPUT1_DIMS == 4
    #define ORDER b,f,y,x
    const uint x = dim0;
    const uint y = dim1;
#elif INPUT1_DIMS == 5
    #define ORDER b,f,z,y,x
    const uint x = dim0;
    const uint y = dim1 % OUTPUT_SIZE_Y;
    const uint z = dim1 / OUTPUT_SIZE_Y;
#else
    #define ORDER b,f,w,z,y,x
    const uint x = dim0 % OUTPUT_SIZE_X;
    const uint y = dim0 / OUTPUT_SIZE_X;
    const uint z = dim1 % OUTPUT_SIZE_Z;
    const uint w = dim1 / OUTPUT_SIZE_Z;
#endif
    const uint f = dim2 % OUTPUT_FEATURE_NUM;
    const uint b = dim2 / OUTPUT_FEATURE_NUM;

    // Offset of this element in the indices tensor; also used for the output.
    const int out_idx = GET_OUTPUT_INDEX(INPUT1, ORDER);

    // Shapes listed outermost first, so that AXIS indexes them directly.
#if INPUT1_DIMS == 4
    size_t data_shape[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
    size_t indices_shape[4] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#elif INPUT1_DIMS == 5
    size_t data_shape[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
    size_t indices_shape[5] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#else
    size_t data_shape[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
    size_t indices_shape[6] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#endif

    // max_inner_sum: number of elements in the dimensions after AXIS (taken from the indices shape).
    size_t max_inner_sum = 1;
    for (size_t i = AXIS + 1; i < INPUT1_DIMS; i++)
        max_inner_sum *= indices_shape[i];

    // Size of one "outer" slice (AXIS and everything after it) in each tensor.
    size_t outer_sum_inc_data = 1, outer_sum_inc_indices = 1;
    for (size_t i = AXIS; i < INPUT1_DIMS; i++) {
        outer_sum_inc_data *= data_shape[i];
        outer_sum_inc_indices *= indices_shape[i];
    }

    // Start of the matching outer slice in data, and the position within the inner block.
    size_t outer_sum = (out_idx / outer_sum_inc_indices) * outer_sum_inc_data;
    size_t inner_sum = out_idx % max_inner_sum;

    // The coordinate along AXIS comes from the indices tensor.
    uint idx = outer_sum + max_inner_sum * indices[out_idx] + inner_sum;
    INPUT0_TYPE val = data[idx];

#if HAS_FUSED_OPS
    FUSED_OPS;
    output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT);
#else
    output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
}
|
||||
|
||||
#undef ORDER
|
||||
#undef GET_OUTPUT_INDEX
|
@ -402,6 +402,8 @@ std::string toString(GatherAxis a) {
|
||||
switch (a) {
|
||||
case GatherAxis::X: return "X";
|
||||
case GatherAxis::Y: return "Y";
|
||||
case GatherAxis::Z: return "Z";
|
||||
case GatherAxis::W: return "W";
|
||||
case GatherAxis::FEATURE: return "FEATURE";
|
||||
case GatherAxis::BATCH: return "BATCH";
|
||||
default: return "";
|
||||
|
62
inference-engine/thirdparty/clDNN/src/gather_elements.cpp
vendored
Normal file
62
inference-engine/thirdparty/clDNN/src/gather_elements.cpp
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gather_elements_inst.h"
|
||||
|
||||
#include "primitive_type_base.h"
|
||||
#include "cldnn/runtime/error_handler.hpp"
|
||||
#include "json_object.h"
|
||||
#include <string>
|
||||
|
||||
namespace cldnn {
|
||||
/// Returns the type descriptor of the gather_elements primitive.
/// The descriptor is a function-local static, so every call returns the same address.
primitive_type_id gather_elements::type_id() {
    static primitive_type_base<gather_elements> descriptor;
    return &descriptor;
}
|
||||
|
||||
/// Computes the output layout of a gather_elements node.
///
/// Format and shape are taken verbatim from the primitive descriptor
/// (`output_format` / `output_shape`).
/// The element type is the type of the data input (input 0), because the
/// operation only selects elements of the data tensor; the indices input
/// (input 1) has no influence on the output type. When other primitives have
/// been fused into this node, the type of the fused output layout is used
/// instead.
///
/// @param node  Program node describing the gather_elements primitive.
/// @return      Layout (type, format, shape) of the node's output.
layout gather_elements_inst::calc_output_layout(gather_elements_node const& node) {
    const auto op = node.get_primitive();

    // Start from the data input's element type ...
    auto output_type = node.input(0).get_output_layout().data_type;

    // ... and let fused primitives (e.g. a quantize producing i8) override it.
    if (node.has_fused_primitives()) {
        output_type = node.get_fused_output_layout().data_type;
    }

    return layout(output_type, op->output_format, op->output_shape);
}
|
||||
|
||||
/// Produces a JSON description of the node: the generic node description
/// extended with a "gather_elements info" section (input id, input/indices
/// shapes, output format/shape and axis).
std::string gather_elements_inst::to_string(gather_elements_node const& node) {
    auto prim = node.get_primitive();
    auto json_root = node.desc_to_json();

    // The output layout feeds two entries below; query it once.
    auto out_layout = calc_output_layout(node);

    json_composite info;
    info.add("input id", node.input().id());
    info.add("input shape", node.input(0).get_output_layout().size.to_string());
    info.add("indices shape", node.input(1).get_output_layout().size.to_string());
    info.add("output format", out_layout.format);
    info.add("output shape", out_layout.size.to_string());
    info.add("axis", prim->axis);

    json_root->add("gather_elements info", info);

    std::stringstream dumped;
    json_root->dump(dumped);
    return dumped.str();
}
|
||||
|
||||
// Forwards the network and node to the parent instance; no additional work is done here.
gather_elements_inst::typed_primitive_inst(network_impl& network, gather_elements_node const& node)
    : parent(network, node) {
}
|
||||
|
||||
} // namespace cldnn
|
@ -32,6 +32,7 @@
|
||||
#include "space_to_depth_inst.h"
|
||||
#include "gather_inst.h"
|
||||
#include "gather_nd_inst.h"
|
||||
#include "gather_elements_inst.h"
|
||||
#include "scatter_update_inst.h"
|
||||
#include "scatter_nd_update_inst.h"
|
||||
#include "scatter_elements_update_inst.h"
|
||||
@ -200,6 +201,7 @@ void prepare_primitive_fusing::fuse_activations(program_impl &p) {
|
||||
!input.is_type<space_to_batch>() && !input.is_type<gather>() && !input.is_type<scatter_update>() && !input.is_type<shuffle_channels>() &&
|
||||
!input.is_type<scatter_nd_update>() &&
|
||||
!input.is_type<gather_nd>() &&
|
||||
!input.is_type<gather_elements>() &&
|
||||
!input.is_type<strided_slice>() && !input.is_type<cum_sum>() && !input.is_type<reverse_sequence>() &&
|
||||
!input.is_type<embedding_bag>() && !input.is_type<extract_image_patches>() &&
|
||||
!input.is_type<fused_conv_eltwise>() && !input.is_type<activation>()))
|
||||
@ -609,6 +611,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather_nd>();
|
||||
|
||||
should_fuse |= input_data.is_type<gather_elements>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_update>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_nd_update>();
|
||||
@ -677,6 +681,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather_nd>();
|
||||
|
||||
should_fuse |= input_data.is_type<gather_elements>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_update>();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_nd_update>();
|
||||
@ -767,6 +773,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
|
||||
should_fuse |= input_data.is_type<gather_nd>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<gather_elements>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_update>() && quantize_node.get_scale_shift_opt();
|
||||
|
||||
should_fuse |= input_data.is_type<scatter_nd_update>() && quantize_node.get_scale_shift_opt();
|
||||
@ -829,6 +837,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program_impl &p) {
|
||||
(parents[i]->is_type<eltwise>() && eltwise_supports_fusings(parents[i]->as<eltwise>())) ||
|
||||
(parents[i]->is_type<scale>()) ||
|
||||
(parents[i]->is_type<gather_nd>()) ||
|
||||
(parents[i]->is_type<gather_elements>()) ||
|
||||
(parents[i]->is_type<scatter_nd_update>()) ||
|
||||
(parents[i]->is_type<scatter_elements_update>()) ||
|
||||
(parents[i]->is_type<pooling>() && pooling_supports_fusings(parents[i]->as<pooling>())) ||
|
||||
|
86
inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp
vendored
Normal file
86
inference-engine/thirdparty/clDNN/src/impls/ocl/gather_elements.cpp
vendored
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gather_elements_inst.h"
|
||||
#include "primitive_base.hpp"
|
||||
#include "impls/implementation_map.hpp"
|
||||
#include "kernel_selector_helper.h"
|
||||
#include "gather/gather_elements_kernel_selector.h"
|
||||
#include "gather/gather_elements_kernel_ref.h"
|
||||
#include "cldnn/runtime/error_handler.hpp"
|
||||
|
||||
using namespace cldnn;
|
||||
|
||||
namespace cldnn {
|
||||
namespace ocl {
|
||||
/// Translates the primitive-level axis enumerator into the kernel selector's
/// axis enumerator. Any value that is not one of x/y/z/w/f maps to BATCH
/// (this covers `along_b` as well as unknown values).
kernel_selector::gather_elements_axis convert_axis(gather_elements::gather_elements_axis axis) {
    using ks_axis = kernel_selector::gather_elements_axis;

    if (axis == gather_elements::along_x)
        return ks_axis::X;
    if (axis == gather_elements::along_y)
        return ks_axis::Y;
    if (axis == gather_elements::along_z)
        return ks_axis::Z;
    if (axis == gather_elements::along_w)
        return ks_axis::W;
    if (axis == gather_elements::along_f)
        return ks_axis::FEATURE;

    return ks_axis::BATCH;
}
|
||||
|
||||
/// OCL implementation wrapper for the gather_elements primitive.
struct gather_elements_impl : typed_primitive_impl_ocl<gather_elements> {
    using parent = typed_primitive_impl_ocl<gather_elements>;
    using parent::parent;

    /// Returns a copy of this implementation.
    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<gather_elements_impl>(*this);
    }

    /// Factory registered in the implementation map.
    /// Builds kernel-selector parameters for `arg`, asks the selector for the
    /// best kernels and wraps the first one. The returned object is allocated
    /// with `new`; the caller receives the raw pointer.
    static primitive_impl* create(const gather_elements_node& arg) {
        auto params = get_default_params<kernel_selector::gather_elements_params>(arg);
        auto optional_params =
            get_default_optional_params<kernel_selector::gather_elements_optional_params>(arg.get_program());

        params.axis = convert_axis(arg.get_primitive()->axis);

        // The indices tensor (dependency 1) is appended as an additional kernel input.
        params.inputs.push_back(convert_data_tensor(arg.input(1).get_output_layout()));

        auto best_kernels =
            kernel_selector::gather_elements_kernel_selector::Instance().GetBestKernels(params, optional_params);

        CLDNN_ERROR_BOOL(arg.id(),
                         "Best_kernel.empty()",
                         best_kernels.empty(),
                         "Cannot find a proper kernel with this arguments");

        return new gather_elements_impl(arg, best_kernels[0]);
    }
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Registers the OCL factory for gather_elements for every supported
// (element type, format) pair: f32 / f16 / i32 in bfyx, bfzyx and bfwzyx.
attach_gather_elements_impl::attach_gather_elements_impl() {
    using dt = data_types;

    implementation_map<gather_elements>::add(impl_types::ocl, gather_elements_impl::create, {
        // 4D
        std::make_tuple(dt::f32, format::bfyx),
        std::make_tuple(dt::f16, format::bfyx),
        std::make_tuple(dt::i32, format::bfyx),
        // 5D
        std::make_tuple(dt::f32, format::bfzyx),
        std::make_tuple(dt::f16, format::bfzyx),
        std::make_tuple(dt::i32, format::bfzyx),
        // 6D
        std::make_tuple(dt::f32, format::bfwzyx),
        std::make_tuple(dt::f16, format::bfwzyx),
        std::make_tuple(dt::i32, format::bfwzyx),
    });
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ocl
|
||||
} // namespace cldnn
|
@ -30,6 +30,7 @@ void register_implementations() {
|
||||
REGISTER_OCL(eltwise);
|
||||
REGISTER_OCL(fully_connected);
|
||||
REGISTER_OCL(gather);
|
||||
REGISTER_OCL(gather_elements);
|
||||
REGISTER_OCL(gather_nd);
|
||||
REGISTER_OCL(gemm);
|
||||
REGISTER_OCL(lrn);
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "cldnn/primitives/fully_connected.hpp"
|
||||
#include "cldnn/primitives/gather.hpp"
|
||||
#include "cldnn/primitives/gather_nd.hpp"
|
||||
#include "cldnn/primitives/gather_elements.hpp"
|
||||
#include "cldnn/primitives/gemm.hpp"
|
||||
#include "cldnn/primitives/lrn.hpp"
|
||||
#include "cldnn/primitives/lstm.hpp"
|
||||
@ -94,6 +95,7 @@ REGISTER_OCL(embed);
|
||||
REGISTER_OCL(fully_connected);
|
||||
REGISTER_OCL(gather);
|
||||
REGISTER_OCL(gather_nd);
|
||||
REGISTER_OCL(gather_elements);
|
||||
REGISTER_OCL(gemm);
|
||||
REGISTER_OCL(lrn);
|
||||
REGISTER_OCL(lstm_gemm);
|
||||
|
49
inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h
vendored
Normal file
49
inference-engine/thirdparty/clDNN/src/include/gather_elements_inst.h
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
// Copyright (c) 2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#pragma once
|
||||
#include "cldnn/primitives/gather_elements.hpp"
|
||||
#include "primitive_inst.h"
|
||||
#include <string>
|
||||
|
||||
namespace cldnn {
|
||||
template <>
|
||||
struct typed_program_node<gather_elements> : public typed_program_node_base<gather_elements> {
|
||||
using parent = typed_program_node_base<gather_elements>;
|
||||
|
||||
public:
|
||||
using parent::parent;
|
||||
|
||||
program_node& input(size_t index = 0) const { return get_dependency(index); }
|
||||
};
|
||||
|
||||
using gather_elements_node = typed_program_node<gather_elements>;
|
||||
|
||||
template <>
|
||||
class typed_primitive_inst<gather_elements> : public typed_primitive_inst_base<gather_elements> {
|
||||
using parent = typed_primitive_inst_base<gather_elements>;
|
||||
|
||||
public:
|
||||
static layout calc_output_layout(gather_elements_node const& node);
|
||||
static std::string to_string(gather_elements_node const& node);
|
||||
|
||||
public:
|
||||
typed_primitive_inst(network_impl& network, gather_elements_node const& desc);
|
||||
};
|
||||
|
||||
using gather_elements_inst = typed_primitive_inst<gather_elements>;
|
||||
} // namespace cldnn
|
@ -72,6 +72,7 @@ using shape_calculation_mode = kernel_selector::ShapeCalculationMode;
|
||||
using interpolate_axis = kernel_selector::InterpolateAxis;
|
||||
using border_type = kernel_selector::BorderType;
|
||||
using gather_axis = kernel_selector::GatherAxis;
|
||||
using gather_elements_axis = kernel_selector::GatherAxis;
|
||||
using scatter_update_axis = kernel_selector::ScatterUpdateAxis;
|
||||
using reduce_mode = kernel_selector::ReduceMode;
|
||||
using cum_sum_axis = kernel_selector::CumSumAxis;
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include <cldnn/primitives/permute.hpp>
|
||||
#include <cldnn/primitives/gather.hpp>
|
||||
#include <cldnn/primitives/gather_nd.hpp>
|
||||
#include <cldnn/primitives/gather_elements.hpp>
|
||||
#include <cldnn/primitives/scatter_update.hpp>
|
||||
#include <cldnn/primitives/scatter_nd_update.hpp>
|
||||
#include <cldnn/primitives/scatter_elements_update.hpp>
|
||||
@ -8413,3 +8414,228 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_nd_activation_scale_eltwise,
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_3, 2, 5 },
|
||||
gather_nd_test_params{ CASE_GATHER_ND_FP32_6D_4, 2, 5 },
|
||||
}));
|
||||
|
||||
|
||||
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
/* ------------------------------------------ GatherElements cases ------------------------------------- */
|
||||
/* ----------------------------------------------------------------------------------------------------- */
|
||||
// Parameter pack for the GatherElements fusing tests.
// Instances are built by aggregate initialization, so the member order below
// is part of the contract and must match the CASE_GATHER_ELEMENTS_* macros.
struct gather_elements_test_params {
|
||||
// Element type used for the input, indices and output layouts in the fixture.
data_types data_type;
|
||||
|
||||
// Format and shape of the data input.
format input_format;
|
||||
tensor input_shape;
|
||||
|
||||
// Format and shape of the indices input.
format indices_format;
|
||||
tensor indices_shape;
|
||||
|
||||
// Format and shape passed to the gather_elements primitive as its output description.
format output_format;
|
||||
tensor output_shape;
|
||||
|
||||
// Axis along which elements are gathered.
cldnn::gather_elements::gather_elements_axis axis;
|
||||
|
||||
// Type/format used for auxiliary per-channel data and the final reorder.
data_types default_type;
|
||||
format default_format;
|
||||
|
||||
// Expected primitive counts in the fused and non-fused networks
// (presumably checked by the base fixture's compare() - TODO confirm).
size_t expected_fused_primitives;
|
||||
size_t expected_not_fused_primitives;
|
||||
};
|
||||
|
||||
#define CASE_GATHER_ELEMENTS_FP16_4D_1 data_types::f16, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, cldnn::gather_elements::gather_elements_axis::along_y, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ELEMENTS_FP16_4D_2 data_types::f16, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, cldnn::gather_elements::gather_elements_axis::along_b, data_types::f16, format::bfyx
|
||||
#define CASE_GATHER_ELEMENTS_FP16_4D_3 data_types::f16, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfyx
|
||||
|
||||
#define CASE_GATHER_ELEMENTS_FP16_5D_1 data_types::f16, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfzyx
|
||||
#define CASE_GATHER_ELEMENTS_FP16_5D_2 data_types::f16, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, cldnn::gather_elements::gather_elements_axis::along_z, data_types::f16, format::bfzyx
|
||||
|
||||
#define CASE_GATHER_ELEMENTS_FP16_6D_1 data_types::f16, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, cldnn::gather_elements::gather_elements_axis::along_f, data_types::f16, format::bfwzyx
|
||||
#define CASE_GATHER_ELEMENTS_FP16_6D_2 data_types::f16, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_w, data_types::f16, format::bfwzyx
|
||||
#define CASE_GATHER_ELEMENTS_FP16_6D_3 data_types::f16, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f16, format::bfwzyx
|
||||
|
||||
|
||||
#define CASE_GATHER_ELEMENTS_FP32_4D_1 data_types::f32, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, format::bfyx, {3, 7, 9, 8}, cldnn::gather_elements::gather_elements_axis::along_y, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ELEMENTS_FP32_4D_2 data_types::f32, format::bfyx, {3, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, format::bfyx, {2, 2, 8, 3}, cldnn::gather_elements::gather_elements_axis::along_b, data_types::f32, format::bfyx
|
||||
#define CASE_GATHER_ELEMENTS_FP32_4D_3 data_types::f32, format::bfyx, {1, 3, 2, 9}, format::bfyx, {1, 3, 5, 9}, format::bfyx, {1, 3, 5, 9}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfyx
|
||||
|
||||
#define CASE_GATHER_ELEMENTS_FP32_5D_1 data_types::f32, format::bfzyx, {3, 2, 5, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, format::bfzyx, {3, 2, 2, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfzyx
|
||||
#define CASE_GATHER_ELEMENTS_FP32_5D_2 data_types::f32, format::bfzyx, {5, 4, 7, 4, 4}, format::bfzyx, {5, 4, 7, 4, 3}, format::bfzyx, {5, 4, 7, 4, 3}, cldnn::gather_elements::gather_elements_axis::along_z, data_types::f32, format::bfzyx
|
||||
|
||||
#define CASE_GATHER_ELEMENTS_FP32_6D_1 data_types::f32, format::bfwzyx, {5, 4, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, format::bfwzyx, {5, 2, 6, 7, 8, 2}, cldnn::gather_elements::gather_elements_axis::along_f, data_types::f32, format::bfwzyx
|
||||
#define CASE_GATHER_ELEMENTS_FP32_6D_2 data_types::f32, format::bfwzyx, {2, 1, 2, 3, 2, 1}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, format::bfwzyx, {2, 1, 2, 3, 2, 3}, cldnn::gather_elements::gather_elements_axis::along_w, data_types::f32, format::bfwzyx
|
||||
#define CASE_GATHER_ELEMENTS_FP32_6D_3 data_types::f32, format::bfwzyx, {2, 2, 3, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, format::bfwzyx, {2, 2, 6, 4, 4, 2}, cldnn::gather_elements::gather_elements_axis::along_x, data_types::f32, format::bfwzyx
|
||||
|
||||
/// Fixture shared by all GatherElements fusing tests.
class GatherElementsPrimitiveFusingTest : public ::BaseFusingTest<gather_elements_test_params> {
public:
    /// Builds the non-fused and the fused network, feeds both the same input
    /// buffer and hands them to compare().
    void execute(gather_elements_test_params& p) {
        auto shared_input = get_mem(get_input_layout(p));

        network network_not_fused(this->engine, this->topology_non_fused, bo_not_fused);
        network network_fused(this->engine, this->topology_fused, bo_fused);

        network_fused.set_input_data("input", shared_input);
        network_not_fused.set_input_data("input", shared_input);

        compare(network_not_fused, network_fused, p);
    }

    /// Returns the extent of the input shape along the gather axis
    /// (1 for an unrecognized axis value).
    size_t get_axis_dim(gather_elements_test_params& p) {
        using axis_t = cldnn::gather_elements::gather_elements_axis;
        auto& shape = p.input_shape;

        if (p.axis == axis_t::along_x)
            return shape.spatial[0];
        if (p.axis == axis_t::along_y)
            return shape.spatial[1];
        if (p.axis == axis_t::along_z)
            return shape.spatial[2];
        if (p.axis == axis_t::along_w)
            return shape.spatial[3];
        if (p.axis == axis_t::along_f)
            return shape.feature[0];
        if (p.axis == axis_t::along_b)
            return shape.batch[0];

        return 1;
    }

    /// Layout of the data input.
    layout get_input_layout(gather_elements_test_params& p) {
        return layout{ p.data_type, p.input_format, p.input_shape };
    }

    /// Layout of the indices input; note that it uses the same element type as the data.
    layout get_indices_layout(gather_elements_test_params& p) {
        return layout{ p.data_type, p.indices_format, p.indices_shape };
    }

    /// Layout of the gather_elements output.
    layout get_output_layout(gather_elements_test_params& p) {
        return layout{ p.data_type, p.output_format, p.output_shape };
    }

    /// Layout with one value per output feature: shape {1, F, 1, 1}.
    layout get_per_channel_layout(gather_elements_test_params& p) {
        auto channels = p.output_shape.feature[0];
        return layout{ p.default_type, p.default_format, tensor{1, channels, 1, 1} };
    }
};
|
||||
|
||||
// Fusing test: gather_elements followed by a quantize to i8 (255 levels),
// then a reorder to f32 in the default format.
class gather_elements_quantize : public GatherElementsPrimitiveFusingTest {};
|
||||
TEST_P(gather_elements_quantize, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
// Indices buffer; presumably filled with values in [0, axis_dim - 1] - confirm against get_mem().
data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p))-1)),
|
||||
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
||||
gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis),
|
||||
quantize("quantize", "gather_elements_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
|
||||
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
|
||||
);
|
||||
// Comparison tolerance used for this quantized test (much looser than the 1e-5f of the other GatherElements tests).
tolerance = 1.f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_quantize,
|
||||
::testing::ValuesIn(std::vector<gather_elements_test_params>{
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 3 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 3 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 3 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 3 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 3 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 3 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 3 },
|
||||
}));
|
||||
|
||||
|
||||
// Fusing test: gather_elements followed by an abs activation and a
// per-channel scale, then a reorder to f32 in the default format.
class gather_elements_scale_activation : public GatherElementsPrimitiveFusingTest {};
|
||||
TEST_P(gather_elements_scale_activation, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
// Indices buffer; presumably filled with values in [0, axis_dim - 1] - confirm against get_mem().
data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p))-1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
|
||||
gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis),
|
||||
activation("activation", "gather_elements_prim", activation_func::abs),
|
||||
scale("scale", "activation", "scale_data"),
|
||||
reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
// Comparison tolerance used for this test.
tolerance = 1e-5f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_scale_activation,
|
||||
::testing::ValuesIn(std::vector<gather_elements_test_params>{
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 4 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 4 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 4 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 4 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 4 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 4 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 4 },
|
||||
}));
|
||||
|
||||
|
||||
// Fusing test: gather_elements followed by an abs activation, a per-channel
// scale and an eltwise sum with a constant tensor of the output layout,
// then a reorder to f32 in the default format.
class gather_elements_activation_scale_eltwise : public GatherElementsPrimitiveFusingTest {};
|
||||
TEST_P(gather_elements_activation_scale_eltwise, basic) {
|
||||
auto p = GetParam();
|
||||
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
// Indices buffer; presumably filled with values in [0, axis_dim - 1] - confirm against get_mem().
data("gather_elements_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p))-1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)),
|
||||
data("eltwise_data", get_mem(get_output_layout(p))),
|
||||
gather_elements("gather_elements_prim", "input", "gather_elements_indices", p.output_format, p.output_shape, p.axis),
|
||||
activation("activation", "gather_elements_prim", activation_func::abs),
|
||||
scale("scale", "activation", "scale_data"),
|
||||
eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),
|
||||
reorder("reorder_bfyx", "eltwise", p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
// Comparison tolerance used for this test.
tolerance = 1e-5f;
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, gather_elements_activation_scale_eltwise,
|
||||
::testing::ValuesIn(std::vector<gather_elements_test_params>{
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_1, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_2, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_4D_3, 2, 5 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_1, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_5D_2, 2, 5 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_1, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_2, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP16_6D_3, 2, 5 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_1, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_2, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_4D_3, 2, 5 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_1, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_5D_2, 2, 5 },
|
||||
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_1, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_2, 2, 5 },
|
||||
gather_elements_test_params{ CASE_GATHER_ELEMENTS_FP32_6D_3, 2, 5 },
|
||||
}));
|
||||
|
1141
inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp
vendored
Normal file
1141
inference-engine/thirdparty/clDNN/tests/test_cases/gather_elements_gpu_test.cpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user