[CPU] GatherElements implementation. (#3860)
This commit is contained in:
parent
d5aa6d4fa1
commit
245bc33e8a
@ -1,4 +1,4 @@
|
||||
# Copyright (C) 2018-2020 Intel Corporation
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
@ -62,6 +62,7 @@ set(LAYERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/extract_image_patches.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/fill.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_elements.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_nd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -18,7 +18,7 @@ class BF16Transformer {
|
||||
const InferenceEngine::details::caseless_set<std::string> _complementbf16 =
|
||||
{ "relu", "tanh", "elu", "square", "abs", "sqrt", "linear", "bounded_relu", "soft_relu", "normalize",
|
||||
"sigmoid", "ReLU6", "not", "activation", "HSwish", "mish", "logistic", "mod", "resample",
|
||||
"exp", "gelu", "clamp", "swish", "prelu", "pooling", "norm", "gather", "memory", "mvn", "crop", "activation",
|
||||
"exp", "gelu", "clamp", "swish", "prelu", "pooling", "norm", "gather", "gather_elements", "memory", "mvn", "crop", "activation",
|
||||
"broadcast", "convert", "BatchToSpace", "DepthToSpace", "ExtractImagePatches", "concat", "power", "lrn",
|
||||
"permute", "ScatterUpdate", "ScatterElementsUpdate", "ScatterNDUpdate", "depthwise",
|
||||
"select", "ShuffleChannels", "SpaceToBatch", "SpaceToDepth", "squeeze", "StridedSlice", "unsqueeze", "eltwise",
|
||||
|
149
inference-engine/src/mkldnn_plugin/nodes/gather_elements.cpp
Normal file
149
inference-engine/src/mkldnn_plugin/nodes/gather_elements.cpp
Normal file
@ -0,0 +1,149 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "base.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ie_parallel.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace Extensions {
|
||||
namespace Cpu {
|
||||
|
||||
class GatherElementsImpl: public ExtLayerBase {
|
||||
public:
|
||||
explicit GatherElementsImpl(const CNNLayer* layer) : strideAx1Diff_(0) {
|
||||
errorPrefix_ = std::string("Layer GatherElements with name '") + layer->name + "'";
|
||||
|
||||
if (layer->insData.size() != 2 || layer->outData.size() != 1)
|
||||
THROW_IE_EXCEPTION << errorPrefix_ << " has invalid number of input/output edges.";
|
||||
|
||||
auto inputData = layer->insData[dataIndex_].lock();
|
||||
auto indices = layer->insData[indicesIndex_].lock();
|
||||
if (!inputData || !indices)
|
||||
THROW_IE_EXCEPTION << errorPrefix_ << " has nullable inputs.";
|
||||
|
||||
const auto& dataDims = inputData->getTensorDesc().getDims();
|
||||
const auto& indicesDims = indices->getTensorDesc().getDims();
|
||||
if (dataDims.size() != indicesDims.size())
|
||||
THROW_IE_EXCEPTION << errorPrefix_ << " has invalid input shapes. Inputs 'Data' and 'Indices' must have equal ranks.";
|
||||
|
||||
Precision dataPrecision = inputData->getTensorDesc().getPrecision();
|
||||
if (dataPrecision.size() != sizeof(PrecisionTrait<Precision::I32>::value_type) &&
|
||||
dataPrecision.size() != sizeof(PrecisionTrait<Precision::I16>::value_type) &&
|
||||
dataPrecision.size() != sizeof(PrecisionTrait<Precision::I8>::value_type)) {
|
||||
THROW_IE_EXCEPTION << errorPrefix_ << " has unsupported 'inputData' input precision: " << dataPrecision;
|
||||
}
|
||||
|
||||
Precision indicesPrecision = indices->getTensorDesc().getPrecision();
|
||||
if (indicesPrecision != Precision::I32) {
|
||||
THROW_IE_EXCEPTION << errorPrefix_ << " has unsupported 'indices' input precision: " << indicesPrecision;
|
||||
}
|
||||
|
||||
dataTypeSize_ = dataPrecision.size();
|
||||
|
||||
int axis = layer->GetParamAsInt("axis");
|
||||
if (axis < 0)
|
||||
axis += dataDims.size();
|
||||
if (axis < 0 || axis >= static_cast<int>(dataDims.size()))
|
||||
THROW_IE_EXCEPTION << errorPrefix_ << " has invalid axis attribute: " << axis;
|
||||
axis_ = axis;
|
||||
|
||||
auto& outputData = layer->outData[0];
|
||||
strideAxDst_ = outputData->getTensorDesc().getBlockingDesc().getStrides()[axis_];
|
||||
dstAxDim_ = outputData->getTensorDesc().getDims()[axis_];
|
||||
if (axis_ > 0) {
|
||||
strideAx1Diff_ = inputData->getTensorDesc().getBlockingDesc().getStrides()[axis_ - 1] -
|
||||
outputData->getTensorDesc().getBlockingDesc().getStrides()[axis_ - 1];
|
||||
}
|
||||
|
||||
LayerConfig config;
|
||||
DataConfig dataConfig, indicesConfig, outConfig;
|
||||
dataConfig.desc = TensorDesc(dataPrecision, dataDims,
|
||||
inputData->getTensorDesc().getLayoutByDims(dataDims));
|
||||
config.inConfs.push_back(dataConfig);
|
||||
indicesConfig.desc = TensorDesc(Precision::I32, indicesDims,
|
||||
indices->getTensorDesc().getLayoutByDims(indicesDims));
|
||||
config.inConfs.push_back(indicesConfig);
|
||||
|
||||
const auto& outDims = outputData->getTensorDesc().getDims();
|
||||
outConfig.desc = TensorDesc(dataPrecision, outDims,
|
||||
outputData->getTensorDesc().getLayoutByDims(outDims));
|
||||
config.outConfs.push_back(outConfig);
|
||||
|
||||
config.dynBatchSupport = false;
|
||||
|
||||
confs.push_back(config);
|
||||
}
|
||||
|
||||
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
|
||||
switch (dataTypeSize_) {
|
||||
case sizeof(PrecisionTrait<Precision::I32>::value_type):
|
||||
return directExecution<PrecisionTrait<Precision::I32>::value_type>(inputs, outputs, resp);
|
||||
case sizeof(PrecisionTrait<Precision::I16>::value_type):
|
||||
return directExecution<PrecisionTrait<Precision::I16>::value_type>(inputs, outputs, resp);
|
||||
case sizeof(PrecisionTrait<Precision::I8>::value_type):
|
||||
return directExecution<PrecisionTrait<Precision::I8>::value_type>(inputs, outputs, resp);
|
||||
default:
|
||||
std::string errMsg = errorPrefix_ + " has inputData input with unsupported precision: " +
|
||||
inputs[dataIndex_]->getTensorDesc().getPrecision().name();
|
||||
errMsg.copy(resp->msg, sizeof(resp->msg) - 1);
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename dataType>
|
||||
StatusCode directExecution(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept {
|
||||
const dataType* srcData = inputs[dataIndex_]->cbuffer().as<const dataType*>() +
|
||||
inputs[dataIndex_]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
const int* indices = inputs[indicesIndex_]->cbuffer().as<const int*>() +
|
||||
inputs[indicesIndex_]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
dataType* dstData = outputs[0]->buffer().as<dataType*>() +
|
||||
outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
|
||||
const int outSize = outputs[0]->size();
|
||||
auto threadBody = [&](const int ithr, const int nthr) {
|
||||
int start(0lu), end(0lu);
|
||||
splitter(outSize, nthr, ithr, start, end);
|
||||
if (start >= end)
|
||||
return;
|
||||
|
||||
int axStrideIt = start % strideAxDst_;
|
||||
int dstAxIdx = (start / strideAxDst_) % dstAxDim_;
|
||||
int dstShift0 = (start / strideAxDst_ / dstAxDim_) * strideAx1Diff_;
|
||||
|
||||
for (size_t o = start; o < end; o++, axStrideIt++) {
|
||||
if (axStrideIt == strideAxDst_) {
|
||||
axStrideIt = 0;
|
||||
dstAxIdx++;
|
||||
if (dstAxIdx == dstAxDim_) {
|
||||
dstAxIdx = 0;
|
||||
dstShift0 += strideAx1Diff_;
|
||||
}
|
||||
}
|
||||
dstData[o] = srcData[o + dstShift0 + (indices[o] - dstAxIdx) * strideAxDst_];
|
||||
}
|
||||
};
|
||||
parallel_nt(0, threadBody);
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
const size_t dataIndex_ = 0;
|
||||
const size_t indicesIndex_ = 1;
|
||||
|
||||
size_t axis_;
|
||||
size_t dataTypeSize_;
|
||||
int strideAxDst_;
|
||||
int dstAxDim_;
|
||||
int strideAx1Diff_;
|
||||
std::string errorPrefix_;
|
||||
};
|
||||
|
||||
REG_FACTORY_FOR(GatherElementsImpl, GatherElements);
|
||||
} // namespace Cpu
|
||||
} // namespace Extensions
|
||||
} // namespace InferenceEngine
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -72,6 +72,7 @@ MKLDNN_EXTENSION_NODE(SparseFillEmptyRowsImpl, SparseFillEmptyRows);
|
||||
MKLDNN_EXTENSION_NODE(BucketizeImpl, Bucketize);
|
||||
MKLDNN_EXTENSION_NODE(CTCGreedyDecoderImpl, CTCGreedyDecoder);
|
||||
MKLDNN_EXTENSION_NODE(GatherImpl, Gather);
|
||||
MKLDNN_EXTENSION_NODE(GatherElementsImpl, GatherElements);
|
||||
MKLDNN_EXTENSION_NODE(GatherNDImpl, GatherND);
|
||||
MKLDNN_EXTENSION_NODE(ProposalImpl, Proposal);
|
||||
MKLDNN_EXTENSION_NODE(RangeImpl, Range);
|
||||
|
@ -0,0 +1,76 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "shared_test_classes/single_layer/gather_elements.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<InferenceEngine::Precision> dPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16,
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64,
|
||||
InferenceEngine::Precision::I16,
|
||||
InferenceEngine::Precision::U8,
|
||||
InferenceEngine::Precision::I8
|
||||
};
|
||||
const std::vector<InferenceEngine::Precision> iPrecisions = {
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 2})), // Data shape
|
||||
::testing::Values(std::vector<size_t>({2, 2})), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({-1, 0, 1})), // Axis
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
::testing::ValuesIn(iPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set2, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 2, 1})), // Data shape
|
||||
::testing::Values(std::vector<size_t>({4, 2, 1})), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({0, -3})), // Axis
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
::testing::ValuesIn(iPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set3, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 2, 3, 5})), // Data shape
|
||||
::testing::Values(std::vector<size_t>({2, 2, 3, 7})), // Indices shape
|
||||
::testing::Values(3, -1), // Axis
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
::testing::ValuesIn(iPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set4, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({3, 2, 3, 8})), // Data shape
|
||||
::testing::Values(std::vector<size_t>({2, 2, 3, 8})), // Indices shape
|
||||
::testing::Values(0, -4), // Axis
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
::testing::ValuesIn(iPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set5, GatherElementsLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({3, 2, 3, 4, 8})), // Data shape
|
||||
::testing::Values(std::vector<size_t>({3, 2, 3, 5, 8})), // Indices shape
|
||||
::testing::Values(3, -2), // Axis
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
::testing::ValuesIn(iPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherElementsLayerTest::getTestCaseName);
|
||||
} // namespace
|
@ -0,0 +1,92 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <shared_test_classes/single_layer/gather_elements.hpp>
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
using namespace ngraph::helpers;
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace CPULayerTestsDefinitions {
|
||||
|
||||
typedef std::tuple<
|
||||
GatherElementsParams,
|
||||
CPUSpecificParams
|
||||
> GatherElementsCPUTestParamSet;
|
||||
|
||||
// CPU-plugin specific GatherElements test: builds a single-node function and
// attaches the CPU-specific runtime info to the node.
class GatherElementsCPUTest : public testing::WithParamInterface<GatherElementsCPUTestParamSet>,
                              virtual public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
public:
    // Test name = name produced by the shared layer test + CPU-specific suffix.
    static std::string getTestCaseName(const testing::TestParamInfo<GatherElementsCPUTestParamSet> &obj) {
        const GatherElementsParams& layerParams = std::get<0>(obj.param);
        const CPUSpecificParams& cpuSpecific = std::get<1>(obj.param);

        std::ostringstream name;
        name << GatherElementsLayerTest::getTestCaseName(testing::TestParamInfo<GatherElementsParams>(layerParams, 0));
        name << CPUTestsBase::getTestCaseName(cpuSpecific);
        return name.str();
    }

    // Fills the network input through the common blob-filling helper with fixed arguments.
    InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &info) const override {
        return FuncTestUtils::createAndFillBlob(info.getTensorDesc(), 15, 0, 32768);
    }

protected:
    void SetUp() override {
        const auto& testParams = this->GetParam();
        const GatherElementsParams& layerParams = std::get<0>(testParams);

        // CPU-specific part of the parameters.
        std::tie(inFmts, outFmts, priority, selectedType) = std::get<1>(testParams);

        // Shared layer part of the parameters.
        const auto& dataShape = std::get<0>(layerParams);
        const auto& indicesShape = std::get<1>(layerParams);
        const int axis = std::get<2>(layerParams);
        const auto& dPrecision = std::get<3>(layerParams);
        const auto& iPrecision = std::get<4>(layerParams);
        targetDevice = std::get<5>(layerParams);

        // The selected type from the CPU parameters is replaced by one derived from the data precision.
        selectedType = std::string("unknown_") + dPrecision.name();

        const auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
        const auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);

        auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
        auto activation = ngraph::builder::makeGatherElements(params[0], indicesShape, ngIPrc, axis);
        activation->get_rt_info() = getCPUInfo();

        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{activation}, params, "GatherElements");
    }
};
|
||||
|
||||
// Runs the test and then checks the plugin-related results for the
// node type "GatherElements".
TEST_P(GatherElementsCPUTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()

    this->Run();
    this->CheckPluginRelatedResults(executableNetwork, "GatherElements");
}
|
||||
|
||||
|
||||
namespace {
|
||||
std::vector<CPUSpecificParams> cpuParams_4D = {
|
||||
CPUSpecificParams({nchw}, {nchw}, {}, {})
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set1, GatherElementsCPUTest,
|
||||
::testing::Combine(
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>({2, 3, 5, 7})), // Data shape
|
||||
::testing::Values(std::vector<size_t>({2, 3, 9, 7})), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({2, -2})), // Axis
|
||||
::testing::Values(Precision::BF16),
|
||||
::testing::Values(Precision::I32),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D))),
|
||||
GatherElementsCPUTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace CPULayerTestsDefinitions
|
@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
typedef std::tuple<
|
||||
std::vector<size_t>, // Data shapes
|
||||
std::vector<size_t>, // Indices shape
|
||||
int, // Axis
|
||||
InferenceEngine::Precision, // Data precision
|
||||
InferenceEngine::Precision, // Indices precision
|
||||
LayerTestsUtils::TargetDevice // Device name
|
||||
> GatherElementsParams;
|
||||
|
||||
class GatherElementsLayerTest : public testing::WithParamInterface<GatherElementsParams>,
|
||||
public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<GatherElementsParams>& obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,54 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "shared_test_classes/single_layer/gather_elements.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
// Produces a name of the form
// "DS=<data shape>_IS=<indices shape>_Ax=<axis>_DP=<data prc>_IP=<indices prc>_device=<device>".
std::string GatherElementsLayerTest::getTestCaseName(const testing::TestParamInfo<GatherElementsParams>& obj) {
    const auto& dataShape = std::get<0>(obj.param);
    const auto& indicesShape = std::get<1>(obj.param);
    const int axis = std::get<2>(obj.param);
    const auto& dPrecision = std::get<3>(obj.param);
    const auto& iPrecision = std::get<4>(obj.param);
    const auto& device = std::get<5>(obj.param);

    std::ostringstream name;
    name << "DS=" << CommonTestUtils::vec2str(dataShape) << "_"
         << "IS=" << CommonTestUtils::vec2str(indicesShape) << "_"
         << "Ax=" << axis << "_"
         << "DP=" << dPrecision.name() << "_"
         << "IP=" << iPrecision.name() << "_"
         << "device=" << device;
    return name.str();
}
|
||||
|
||||
void GatherElementsLayerTest::SetUp() {
|
||||
InferenceEngine::SizeVector dataShape, indicesShape;
|
||||
InferenceEngine::Precision dPrecision, iPrecision;
|
||||
int axis;
|
||||
std::tie(dataShape, indicesShape, axis, dPrecision, iPrecision, targetDevice) = this->GetParam();
|
||||
|
||||
auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
|
||||
auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);
|
||||
|
||||
auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
|
||||
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
|
||||
auto gather = std::dynamic_pointer_cast<ngraph::op::v6::GatherElements>(
|
||||
ngraph::builder::makeGatherElements(paramOuts[0], indicesShape, ngIPrc, axis));
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gather)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "gatherEl");
|
||||
}
|
||||
|
||||
// Executes the test flow inherited from the common layer test base.
TEST_P(GatherElementsLayerTest, CompareWithRefs) {
    this->Run();
}
|
||||
} // namespace LayerTestsDefinitions
|
@ -440,6 +440,12 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
|
||||
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD,
|
||||
ngraph::helpers::SequenceTestsMode mode = ngraph::helpers::SequenceTestsMode::PURE_SEQ);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeGatherElements(
|
||||
const ngraph::Output<Node>& dataNode,
|
||||
const ngraph::Shape& indicesShape,
|
||||
const element::Type& indicesType,
|
||||
const int axis);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeGatherND(
|
||||
const ngraph::Output<Node>& dataNode,
|
||||
const ngraph::Shape& indicesShape,
|
||||
|
@ -0,0 +1,37 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
|
||||
std::shared_ptr<Node> makeGatherElements(
|
||||
const Output<Node>& dataNode,
|
||||
const Shape& indicesShape,
|
||||
const element::Type& indicesType,
|
||||
const int axis) {
|
||||
const auto& dataShape = dataNode.get_shape();
|
||||
int posAxis = axis;
|
||||
if (posAxis < 0)
|
||||
posAxis += dataShape.size();
|
||||
const auto axisDim = dataShape[posAxis];
|
||||
const auto indicesSize = std::accumulate(begin(indicesShape), end(indicesShape),
|
||||
1ull, std::multiplies<std::size_t>{});
|
||||
|
||||
auto indicesValues = NGraphFunctions::Utils::generateVector<element::Type_t::i32>(indicesSize, axisDim - 1, 0);
|
||||
auto indicesNode = opset5::Constant::create(indicesType, indicesShape, indicesValues);
|
||||
|
||||
auto gatherElNode = std::make_shared<op::v6::GatherElements>(dataNode, indicesNode, axis);
|
||||
gatherElNode->set_friendly_name("GatherElements");
|
||||
|
||||
return gatherElNode;
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -438,9 +438,7 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_expanded_cpu", # noqa
|
||||
"OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_expanded_cpu", # noqa
|
||||
"OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_expanded_cpu", # noqa
|
||||
"OnnxBackendNodeModelTest.test_gather_elements_0_cpu",
|
||||
"OnnxBackendNodeModelTest.test_gather_elements_negative_indices_cpu",
|
||||
"OnnxBackendNodeModelTest.test_gather_elements_1_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NC_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NC_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_cpu",
|
||||
|
Loading…
Reference in New Issue
Block a user