[CPU] GatherND implementation. (#2757)
This commit is contained in:
parent
6fec63862b
commit
257bfc9944
@ -57,6 +57,7 @@ set(LAYERS
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/extract_image_patches.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/extract_image_patches.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/fill.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/fill.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_nd.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
|
||||||
|
230
inference-engine/src/mkldnn_plugin/nodes/gather_nd.cpp
Normal file
230
inference-engine/src/mkldnn_plugin/nodes/gather_nd.cpp
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "base.hpp"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "ie_parallel.hpp"
|
||||||
|
#include "common/cpu_memcpy.h"
|
||||||
|
|
||||||
|
namespace InferenceEngine {
|
||||||
|
namespace Extensions {
|
||||||
|
namespace Cpu {
|
||||||
|
|
||||||
|
class GatherNDImpl: public ExtLayerBase {
|
||||||
|
public:
|
||||||
|
explicit GatherNDImpl(const CNNLayer* layer) {
|
||||||
|
_errorPrefix = std::string("Layer GatherND with name '") + layer->name + "'";
|
||||||
|
|
||||||
|
if (layer->insData.size() != 2 || layer->outData.size() != 1)
|
||||||
|
THROW_IE_EXCEPTION << _errorPrefix << " has invalid number of input/output edges.";
|
||||||
|
|
||||||
|
auto data = layer->insData[_dataIndex].lock();
|
||||||
|
auto indices = layer->insData[_indicesIndex].lock();
|
||||||
|
if (!data || !indices)
|
||||||
|
THROW_IE_EXCEPTION << _errorPrefix << " has nullable inputs.";
|
||||||
|
Precision dataPrecision = data->getTensorDesc().getPrecision();
|
||||||
|
if (dataPrecision.size() != sizeof(PrecisionTrait<Precision::I32>::value_type) &&
|
||||||
|
dataPrecision.size() != sizeof(PrecisionTrait<Precision::I16>::value_type) &&
|
||||||
|
dataPrecision.size() != sizeof(PrecisionTrait<Precision::I8>::value_type)) {
|
||||||
|
THROW_IE_EXCEPTION << _errorPrefix << " has unsupported 'data' input precision: " << dataPrecision;
|
||||||
|
}
|
||||||
|
|
||||||
|
Precision indicesPrecision = indices->getTensorDesc().getPrecision();
|
||||||
|
if (indicesPrecision != Precision::I32 &&
|
||||||
|
indicesPrecision != Precision::I16 && indicesPrecision != Precision::U16 &&
|
||||||
|
indicesPrecision != Precision::I8 && indicesPrecision != Precision::U8) {
|
||||||
|
THROW_IE_EXCEPTION << _errorPrefix << " has unsupported 'indices' input precision: " << indicesPrecision;
|
||||||
|
}
|
||||||
|
|
||||||
|
_dataTypeSize = dataPrecision.size();
|
||||||
|
const auto& dataDims = data->getTensorDesc().getDims();
|
||||||
|
const auto& indicesDims = indices->getTensorDesc().getDims();
|
||||||
|
|
||||||
|
_batchDims = layer->GetParamAsInt("batch_dims", 0);
|
||||||
|
if (_batchDims >= std::min(dataDims.size(), indicesDims.size()))
|
||||||
|
THROW_IE_EXCEPTION << _errorPrefix << " has invalid batch_dims attribute: " << _batchDims;
|
||||||
|
|
||||||
|
_batchNum = 1lu;
|
||||||
|
for (size_t i = 0; i < _batchDims; i++) {
|
||||||
|
_batchNum *= indicesDims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
_sliceRank = indicesDims[indicesDims.size() - 1];
|
||||||
|
_dataRank = dataDims.size() - _batchDims;
|
||||||
|
if (_sliceRank > _dataRank)
|
||||||
|
THROW_IE_EXCEPTION << _errorPrefix << " has invalid inputs shapes.";
|
||||||
|
|
||||||
|
_blockSize = 1;
|
||||||
|
for (size_t i = _sliceRank + _batchDims; i < dataDims.size(); i++) {
|
||||||
|
_blockSize *= dataDims[i];
|
||||||
|
}
|
||||||
|
_batchStep = 1;
|
||||||
|
for (size_t i = _batchDims; i < dataDims.size(); i++) {
|
||||||
|
_batchStep *= dataDims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
LayerConfig config;
|
||||||
|
DataConfig dataConfig, indicesConfig, outConfig;
|
||||||
|
dataConfig.desc = TensorDesc(dataPrecision, dataDims,
|
||||||
|
data->getTensorDesc().getLayoutByDims(dataDims));
|
||||||
|
config.inConfs.push_back(dataConfig);
|
||||||
|
indicesConfig.desc = TensorDesc(Precision::I32, indicesDims,
|
||||||
|
indices->getTensorDesc().getLayoutByDims(indicesDims));
|
||||||
|
config.inConfs.push_back(indicesConfig);
|
||||||
|
|
||||||
|
const auto& outDims = layer->outData[0]->getTensorDesc().getDims();
|
||||||
|
outConfig.desc = TensorDesc(dataPrecision, outDims,
|
||||||
|
layer->outData[0]->getTensorDesc().getLayoutByDims(outDims));
|
||||||
|
config.outConfs.push_back(outConfig);
|
||||||
|
config.dynBatchSupport = false;
|
||||||
|
|
||||||
|
confs.push_back(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
|
||||||
|
if (_blockSize > 1) {
|
||||||
|
gatherBlocks(inputs, outputs, resp);
|
||||||
|
} else {
|
||||||
|
switch (_dataTypeSize) {
|
||||||
|
case sizeof(PrecisionTrait<Precision::I32>::value_type):
|
||||||
|
gatherElementwise<PrecisionTrait<Precision::I32>::value_type>(inputs, outputs, resp);
|
||||||
|
break;
|
||||||
|
case sizeof(PrecisionTrait<Precision::I16>::value_type):
|
||||||
|
gatherElementwise<PrecisionTrait<Precision::I16>::value_type>(inputs, outputs, resp);
|
||||||
|
break;
|
||||||
|
case sizeof(PrecisionTrait<Precision::I8>::value_type):
|
||||||
|
gatherElementwise<PrecisionTrait<Precision::I8>::value_type>(inputs, outputs, resp);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
std::string errMsg = _errorPrefix + " has data input with unsupported precision: " +
|
||||||
|
inputs[_dataIndex]->getTensorDesc().getPrecision().name();
|
||||||
|
errMsg.copy(resp->msg, sizeof(resp->msg) - 1);
|
||||||
|
return GENERAL_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
template <typename dataType>
|
||||||
|
void gatherElementwise(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept {
|
||||||
|
const dataType* srcData = inputs[_dataIndex]->cbuffer().as<const dataType*>() +
|
||||||
|
inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
|
const int* indices = inputs[_indicesIndex]->cbuffer().as<const int*>() +
|
||||||
|
inputs[_indicesIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
|
dataType* dstData = outputs[0]->buffer().as<dataType*>() +
|
||||||
|
outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
|
|
||||||
|
const size_t* srcMultipliers = inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getStrides().data() + _batchDims;
|
||||||
|
|
||||||
|
const size_t cycles = outputs[0]->byteSize() / (sizeof(dataType) * _batchNum);
|
||||||
|
const size_t CS = cycles * _sliceRank;
|
||||||
|
const size_t CB = cycles * _blockSize;
|
||||||
|
const size_t workAmount = _batchNum * cycles;
|
||||||
|
|
||||||
|
auto threadBody = [&](const int ithr, const int nthr) {
|
||||||
|
size_t start(0lu), end(0lu);
|
||||||
|
splitter(workAmount, nthr, ithr, start, end);
|
||||||
|
if (start >= end)
|
||||||
|
return;
|
||||||
|
size_t bStart = start / cycles;
|
||||||
|
size_t cStart = start % cycles;
|
||||||
|
size_t workCounter = start;
|
||||||
|
|
||||||
|
const dataType* shiftedSrcData = srcData + bStart * _batchStep;
|
||||||
|
const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
|
||||||
|
dataType* shiftedDstData = dstData + bStart * CB + cStart * _blockSize;
|
||||||
|
|
||||||
|
for (size_t b = bStart; b < _batchNum; b++) {
|
||||||
|
for (size_t j = cStart; j < cycles; j++) {
|
||||||
|
size_t dataIdx = 0lu;
|
||||||
|
for (size_t i = 0lu; i < _sliceRank; i++)
|
||||||
|
dataIdx += srcMultipliers[i] * shiftedIndices[i];
|
||||||
|
shiftedDstData[0] = shiftedSrcData[dataIdx];
|
||||||
|
shiftedDstData++;
|
||||||
|
shiftedIndices += _sliceRank;
|
||||||
|
if (++workCounter == end) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cStart = 0lu;
|
||||||
|
shiftedSrcData += _batchStep;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
parallel_nt(0, threadBody);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gatherBlocks(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept {
|
||||||
|
const uint8_t* srcData = inputs[_dataIndex]->cbuffer().as<const uint8_t*>() +
|
||||||
|
inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
|
const int* indices = inputs[_indicesIndex]->cbuffer().as<const int*>() +
|
||||||
|
inputs[_indicesIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
|
uint8_t* dstData = outputs[0]->buffer().as<uint8_t*>() +
|
||||||
|
outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
|
|
||||||
|
std::vector<size_t> srcMultipliers(_sliceRank);
|
||||||
|
for (size_t i = 0; i < _sliceRank ; i++)
|
||||||
|
srcMultipliers[i] = _dataTypeSize * inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getStrides()[i + _batchDims];
|
||||||
|
|
||||||
|
const size_t batchStep = _batchStep * _dataTypeSize;
|
||||||
|
const size_t dataStep = _blockSize * _dataTypeSize;
|
||||||
|
const size_t cycles = outputs[0]->byteSize() / (dataStep * _batchNum);
|
||||||
|
const size_t CS = cycles * _sliceRank;
|
||||||
|
const size_t CB = cycles * dataStep;
|
||||||
|
const size_t workAmount = _batchNum * cycles;
|
||||||
|
|
||||||
|
auto threadBody = [&](const int ithr, const int nthr) {
|
||||||
|
size_t start(0lu), end(0lu);
|
||||||
|
splitter(workAmount, nthr, ithr, start, end);
|
||||||
|
if (start >= end)
|
||||||
|
return;
|
||||||
|
size_t bStart = start / cycles;
|
||||||
|
size_t cStart = start % cycles;
|
||||||
|
size_t workCounter = start;
|
||||||
|
|
||||||
|
const uint8_t* shiftedSrcData = srcData + bStart * batchStep;
|
||||||
|
const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
|
||||||
|
uint8_t* shiftedDstData = dstData + bStart * CB + cStart * dataStep;
|
||||||
|
|
||||||
|
for (size_t b = bStart; b < _batchNum; b++) {
|
||||||
|
for (size_t j = cStart; j < cycles; j++) {
|
||||||
|
size_t dataIdx = 0lu;
|
||||||
|
for (size_t i = 0; i < _sliceRank ; i++)
|
||||||
|
dataIdx += srcMultipliers[i] * shiftedIndices[i];
|
||||||
|
cpu_memcpy(shiftedDstData, &(shiftedSrcData[dataIdx]), dataStep);
|
||||||
|
shiftedDstData += dataStep;
|
||||||
|
shiftedIndices += _sliceRank;
|
||||||
|
if (++workCounter == end) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cStart = 0;
|
||||||
|
shiftedSrcData += batchStep;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
parallel_nt(0, threadBody);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t _dataRank;
|
||||||
|
size_t _sliceRank;
|
||||||
|
size_t _blockSize;
|
||||||
|
size_t _batchDims;
|
||||||
|
size_t _batchNum;
|
||||||
|
size_t _batchStep;
|
||||||
|
size_t _dataTypeSize;
|
||||||
|
const size_t _dataIndex = 0;
|
||||||
|
const size_t _indicesIndex = 1;
|
||||||
|
std::string _errorPrefix;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
REG_FACTORY_FOR(GatherNDImpl, GatherND);
|
||||||
|
} // namespace Cpu
|
||||||
|
} // namespace Extensions
|
||||||
|
} // namespace InferenceEngine
|
@ -73,6 +73,7 @@ MKLDNN_EXTENSION_NODE(SparseFillEmptyRowsImpl, SparseFillEmptyRows);
|
|||||||
MKLDNN_EXTENSION_NODE(BucketizeImpl, Bucketize);
|
MKLDNN_EXTENSION_NODE(BucketizeImpl, Bucketize);
|
||||||
MKLDNN_EXTENSION_NODE(CTCGreedyDecoderImpl, CTCGreedyDecoder);
|
MKLDNN_EXTENSION_NODE(CTCGreedyDecoderImpl, CTCGreedyDecoder);
|
||||||
MKLDNN_EXTENSION_NODE(GatherImpl, Gather);
|
MKLDNN_EXTENSION_NODE(GatherImpl, Gather);
|
||||||
|
MKLDNN_EXTENSION_NODE(GatherNDImpl, GatherND);
|
||||||
MKLDNN_EXTENSION_NODE(ProposalImpl, Proposal);
|
MKLDNN_EXTENSION_NODE(ProposalImpl, Proposal);
|
||||||
MKLDNN_EXTENSION_NODE(RangeImpl, Range);
|
MKLDNN_EXTENSION_NODE(RangeImpl, Range);
|
||||||
MKLDNN_EXTENSION_NODE(SelectImpl, Select);
|
MKLDNN_EXTENSION_NODE(SelectImpl, Select);
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "single_layer_tests/gather_nd.hpp"
|
||||||
|
|
||||||
|
using namespace LayerTestsDefinitions;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
const std::vector<InferenceEngine::Precision> dPrecisions = {
|
||||||
|
InferenceEngine::Precision::FP32,
|
||||||
|
InferenceEngine::Precision::FP16,
|
||||||
|
InferenceEngine::Precision::I32,
|
||||||
|
InferenceEngine::Precision::I64,
|
||||||
|
InferenceEngine::Precision::I16,
|
||||||
|
InferenceEngine::Precision::U8,
|
||||||
|
InferenceEngine::Precision::I8
|
||||||
|
};
|
||||||
|
const std::vector<InferenceEngine::Precision> iPrecisions = {
|
||||||
|
InferenceEngine::Precision::I32,
|
||||||
|
InferenceEngine::Precision::I64
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto gatherNDArgsSubset1 = ::testing::Combine(
|
||||||
|
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||||
|
{{2, 2}, {2, 3, 4}})), // Data shape
|
||||||
|
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||||
|
{{2, 1}, {2, 1, 1}})), // Indices shape
|
||||||
|
::testing::ValuesIn(std::vector<int>({0, 1})) // Batch dims
|
||||||
|
);
|
||||||
|
INSTANTIATE_TEST_CASE_P(smoke_Set1, GatherNDLayerTest,
|
||||||
|
::testing::Combine(
|
||||||
|
gatherNDArgsSubset1,
|
||||||
|
::testing::ValuesIn(dPrecisions),
|
||||||
|
::testing::ValuesIn(iPrecisions),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||||
|
GatherNDLayerTest::getTestCaseName);
|
||||||
|
|
||||||
|
const auto gatherNDArgsSubset2 = ::testing::Combine(
|
||||||
|
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||||
|
{{15, 12, 20, 15, 2}, {15, 12, 18, 7, 17}})), // Data shape
|
||||||
|
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||||
|
{{15, 12, 2}, {15, 12, 5, 9, 1, 3}})), // Indices shape
|
||||||
|
::testing::ValuesIn(std::vector<int>({0, 1, 2})) // Batch dims
|
||||||
|
);
|
||||||
|
INSTANTIATE_TEST_CASE_P(smoke_Set2, GatherNDLayerTest,
|
||||||
|
::testing::Combine(
|
||||||
|
gatherNDArgsSubset2,
|
||||||
|
::testing::ValuesIn(dPrecisions),
|
||||||
|
::testing::ValuesIn(iPrecisions),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||||
|
GatherNDLayerTest::getTestCaseName);
|
||||||
|
} // namespace
|
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "functional_test_utils/layer_test_utils.hpp"
|
||||||
|
|
||||||
|
|
||||||
|
// Shape-related part of the GatherND test parameters.
using GatherNDParamsSubset = std::tuple<
        std::vector<size_t>,    // Data shapes
        std::vector<size_t>,    // Indices shape
        int>;                   // batch dims
|
||||||
|
|
||||||
|
typedef std::tuple<
|
||||||
|
GatherNDParamsSubset,
|
||||||
|
InferenceEngine::Precision, // Data precision
|
||||||
|
InferenceEngine::Precision, // Indices precision
|
||||||
|
LayerTestsUtils::TargetDevice // Device name
|
||||||
|
> GatherNDParams;
|
||||||
|
|
||||||
|
namespace LayerTestsDefinitions {
|
||||||
|
|
||||||
|
class GatherNDLayerTest : public testing::WithParamInterface<GatherNDParams>,
|
||||||
|
public LayerTestsUtils::LayerTestsCommon {
|
||||||
|
public:
|
||||||
|
static std::string getTestCaseName(const testing::TestParamInfo<GatherNDParams> &obj);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ngraph_functions/builders.hpp"
|
||||||
|
#include "single_layer_tests/gather_nd.hpp"
|
||||||
|
|
||||||
|
namespace LayerTestsDefinitions {
|
||||||
|
|
||||||
|
std::string GatherNDLayerTest::getTestCaseName(const testing::TestParamInfo<GatherNDParams>& obj) {
|
||||||
|
InferenceEngine::SizeVector dataShape, indicesShape;
|
||||||
|
InferenceEngine::Precision dPrecision, iPrecision;
|
||||||
|
int batchDims;
|
||||||
|
std::string device;
|
||||||
|
GatherNDParamsSubset gatherArgsSubset;
|
||||||
|
std::tie(gatherArgsSubset, dPrecision, iPrecision, device) = obj.param;
|
||||||
|
std::tie(dataShape, indicesShape, batchDims) = gatherArgsSubset;
|
||||||
|
|
||||||
|
std::ostringstream result;
|
||||||
|
result << "DS=" << CommonTestUtils::vec2str(dataShape) << "_";
|
||||||
|
result << "IS=" << CommonTestUtils::vec2str(indicesShape) << "_";
|
||||||
|
result << "BD=" << batchDims << "_";
|
||||||
|
result << "DP=" << dPrecision.name() << "_";
|
||||||
|
result << "IP=" << iPrecision.name() << "_";
|
||||||
|
result << "device=" << device;
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GatherNDLayerTest::SetUp() {
|
||||||
|
InferenceEngine::SizeVector dataShape, indicesShape;
|
||||||
|
InferenceEngine::Precision dPrecision, iPrecision;
|
||||||
|
int batchDims;
|
||||||
|
GatherNDParamsSubset gatherArgsSubset;
|
||||||
|
std::tie(gatherArgsSubset, dPrecision, iPrecision, targetDevice) = this->GetParam();
|
||||||
|
std::tie(dataShape, indicesShape, batchDims) = gatherArgsSubset;
|
||||||
|
|
||||||
|
auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
|
||||||
|
auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);
|
||||||
|
|
||||||
|
auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
|
||||||
|
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||||
|
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
|
||||||
|
auto dataNode = paramOuts[0];
|
||||||
|
auto gather = std::dynamic_pointer_cast<ngraph::opset5::GatherND>(
|
||||||
|
ngraph::builder::makeGatherND(dataNode, indicesShape, ngIPrc, batchDims));
|
||||||
|
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gather)};
|
||||||
|
function = std::make_shared<ngraph::Function>(results, params, "gatherND");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(GatherNDLayerTest, CompareWithRefs) {
|
||||||
|
Run();
|
||||||
|
}
|
||||||
|
} // namespace LayerTestsDefinitions
|
@ -23,6 +23,8 @@ inline ::ngraph::element::Type convertIE2nGraphPrc(const InferenceEngine::Precis
|
|||||||
return ::ngraph::element::Type(::ngraph::element::Type_t::f32);
|
return ::ngraph::element::Type(::ngraph::element::Type_t::f32);
|
||||||
case InferenceEngine::Precision::FP16:
|
case InferenceEngine::Precision::FP16:
|
||||||
return ::ngraph::element::Type(::ngraph::element::Type_t::f16);
|
return ::ngraph::element::Type(::ngraph::element::Type_t::f16);
|
||||||
|
case InferenceEngine::Precision::BF16:
|
||||||
|
return ::ngraph::element::Type(::ngraph::element::Type_t::bf16);
|
||||||
case InferenceEngine::Precision::U8:
|
case InferenceEngine::Precision::U8:
|
||||||
return ::ngraph::element::Type(::ngraph::element::Type_t::u8);
|
return ::ngraph::element::Type(::ngraph::element::Type_t::u8);
|
||||||
case InferenceEngine::Precision::I8:
|
case InferenceEngine::Precision::I8:
|
||||||
|
@ -426,6 +426,12 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
|
|||||||
bool make_sequence = false,
|
bool make_sequence = false,
|
||||||
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
|
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> makeGatherND(
|
||||||
|
const ngraph::Output<Node>& dataNode,
|
||||||
|
const ngraph::Shape& indicesShape,
|
||||||
|
const element::Type& indicesType,
|
||||||
|
const std::size_t batchDims);
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> makeTile(const ngraph::Output<Node>& in,
|
std::shared_ptr<ngraph::Node> makeTile(const ngraph::Output<Node>& in,
|
||||||
const std::vector<size_t>& repeats);
|
const std::vector<size_t>& repeats);
|
||||||
|
|
||||||
|
45
inference-engine/tests/ngraph_functions/src/gather_nd.cpp
Normal file
45
inference-engine/tests/ngraph_functions/src/gather_nd.cpp
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <numeric>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ngraph_functions/builders.hpp"
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
std::shared_ptr<Node> makeGatherND(
|
||||||
|
const ngraph::Output<Node>& dataNode,
|
||||||
|
const ngraph::Shape& indicesShape,
|
||||||
|
const element::Type& indicesType,
|
||||||
|
const std::size_t batchDims) {
|
||||||
|
const auto indices = [&] {
|
||||||
|
const auto& dataShape = dataNode.get_shape();
|
||||||
|
const auto indicesCount = std::accumulate(begin(indicesShape), prev(end(indicesShape)),
|
||||||
|
1ull, std::multiplies<std::size_t>{});
|
||||||
|
const auto sliceRank = indicesShape.back();
|
||||||
|
|
||||||
|
const auto maxDim = *std::max_element(begin(dataShape), end(dataShape));
|
||||||
|
|
||||||
|
auto indicesValues = NGraphFunctions::Utils::generateVector<element::Type_t::i32>(indicesCount * sliceRank, maxDim, 0);
|
||||||
|
auto indicesData = indicesValues.data();
|
||||||
|
for (int i = 0; i < indicesCount; i++) {
|
||||||
|
for (int dim = 0; dim < sliceRank; dim++) {
|
||||||
|
indicesData[0] = indicesData[0] % dataShape[dim + batchDims];
|
||||||
|
indicesData++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return opset5::Constant::create(indicesType, indicesShape, indicesValues);
|
||||||
|
}();
|
||||||
|
|
||||||
|
auto gatherNdNode = std::make_shared<opset5::GatherND>(dataNode, indices, batchDims);
|
||||||
|
gatherNdNode->set_friendly_name("GatherND");
|
||||||
|
|
||||||
|
return gatherNdNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace builder
|
||||||
|
} // namespace ngraph
|
Loading…
Reference in New Issue
Block a user