[CPU] GatherND implementation. (#2757)
This commit is contained in:
parent
6fec63862b
commit
257bfc9944
@ -57,6 +57,7 @@ set(LAYERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/extract_image_patches.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/fill.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_nd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
|
||||
|
230
inference-engine/src/mkldnn_plugin/nodes/gather_nd.cpp
Normal file
230
inference-engine/src/mkldnn_plugin/nodes/gather_nd.cpp
Normal file
@ -0,0 +1,230 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "base.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ie_parallel.hpp"
|
||||
#include "common/cpu_memcpy.h"
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace Extensions {
|
||||
namespace Cpu {
|
||||
|
||||
class GatherNDImpl: public ExtLayerBase {
|
||||
public:
|
||||
explicit GatherNDImpl(const CNNLayer* layer) {
|
||||
_errorPrefix = std::string("Layer GatherND with name '") + layer->name + "'";
|
||||
|
||||
if (layer->insData.size() != 2 || layer->outData.size() != 1)
|
||||
THROW_IE_EXCEPTION << _errorPrefix << " has invalid number of input/output edges.";
|
||||
|
||||
auto data = layer->insData[_dataIndex].lock();
|
||||
auto indices = layer->insData[_indicesIndex].lock();
|
||||
if (!data || !indices)
|
||||
THROW_IE_EXCEPTION << _errorPrefix << " has nullable inputs.";
|
||||
Precision dataPrecision = data->getTensorDesc().getPrecision();
|
||||
if (dataPrecision.size() != sizeof(PrecisionTrait<Precision::I32>::value_type) &&
|
||||
dataPrecision.size() != sizeof(PrecisionTrait<Precision::I16>::value_type) &&
|
||||
dataPrecision.size() != sizeof(PrecisionTrait<Precision::I8>::value_type)) {
|
||||
THROW_IE_EXCEPTION << _errorPrefix << " has unsupported 'data' input precision: " << dataPrecision;
|
||||
}
|
||||
|
||||
Precision indicesPrecision = indices->getTensorDesc().getPrecision();
|
||||
if (indicesPrecision != Precision::I32 &&
|
||||
indicesPrecision != Precision::I16 && indicesPrecision != Precision::U16 &&
|
||||
indicesPrecision != Precision::I8 && indicesPrecision != Precision::U8) {
|
||||
THROW_IE_EXCEPTION << _errorPrefix << " has unsupported 'indices' input precision: " << indicesPrecision;
|
||||
}
|
||||
|
||||
_dataTypeSize = dataPrecision.size();
|
||||
const auto& dataDims = data->getTensorDesc().getDims();
|
||||
const auto& indicesDims = indices->getTensorDesc().getDims();
|
||||
|
||||
_batchDims = layer->GetParamAsInt("batch_dims", 0);
|
||||
if (_batchDims >= std::min(dataDims.size(), indicesDims.size()))
|
||||
THROW_IE_EXCEPTION << _errorPrefix << " has invalid batch_dims attribute: " << _batchDims;
|
||||
|
||||
_batchNum = 1lu;
|
||||
for (size_t i = 0; i < _batchDims; i++) {
|
||||
_batchNum *= indicesDims[i];
|
||||
}
|
||||
|
||||
_sliceRank = indicesDims[indicesDims.size() - 1];
|
||||
_dataRank = dataDims.size() - _batchDims;
|
||||
if (_sliceRank > _dataRank)
|
||||
THROW_IE_EXCEPTION << _errorPrefix << " has invalid inputs shapes.";
|
||||
|
||||
_blockSize = 1;
|
||||
for (size_t i = _sliceRank + _batchDims; i < dataDims.size(); i++) {
|
||||
_blockSize *= dataDims[i];
|
||||
}
|
||||
_batchStep = 1;
|
||||
for (size_t i = _batchDims; i < dataDims.size(); i++) {
|
||||
_batchStep *= dataDims[i];
|
||||
}
|
||||
|
||||
LayerConfig config;
|
||||
DataConfig dataConfig, indicesConfig, outConfig;
|
||||
dataConfig.desc = TensorDesc(dataPrecision, dataDims,
|
||||
data->getTensorDesc().getLayoutByDims(dataDims));
|
||||
config.inConfs.push_back(dataConfig);
|
||||
indicesConfig.desc = TensorDesc(Precision::I32, indicesDims,
|
||||
indices->getTensorDesc().getLayoutByDims(indicesDims));
|
||||
config.inConfs.push_back(indicesConfig);
|
||||
|
||||
const auto& outDims = layer->outData[0]->getTensorDesc().getDims();
|
||||
outConfig.desc = TensorDesc(dataPrecision, outDims,
|
||||
layer->outData[0]->getTensorDesc().getLayoutByDims(outDims));
|
||||
config.outConfs.push_back(outConfig);
|
||||
config.dynBatchSupport = false;
|
||||
|
||||
confs.push_back(config);
|
||||
}
|
||||
|
||||
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
|
||||
if (_blockSize > 1) {
|
||||
gatherBlocks(inputs, outputs, resp);
|
||||
} else {
|
||||
switch (_dataTypeSize) {
|
||||
case sizeof(PrecisionTrait<Precision::I32>::value_type):
|
||||
gatherElementwise<PrecisionTrait<Precision::I32>::value_type>(inputs, outputs, resp);
|
||||
break;
|
||||
case sizeof(PrecisionTrait<Precision::I16>::value_type):
|
||||
gatherElementwise<PrecisionTrait<Precision::I16>::value_type>(inputs, outputs, resp);
|
||||
break;
|
||||
case sizeof(PrecisionTrait<Precision::I8>::value_type):
|
||||
gatherElementwise<PrecisionTrait<Precision::I8>::value_type>(inputs, outputs, resp);
|
||||
break;
|
||||
default:
|
||||
std::string errMsg = _errorPrefix + " has data input with unsupported precision: " +
|
||||
inputs[_dataIndex]->getTensorDesc().getPrecision().name();
|
||||
errMsg.copy(resp->msg, sizeof(resp->msg) - 1);
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename dataType>
|
||||
void gatherElementwise(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept {
|
||||
const dataType* srcData = inputs[_dataIndex]->cbuffer().as<const dataType*>() +
|
||||
inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
const int* indices = inputs[_indicesIndex]->cbuffer().as<const int*>() +
|
||||
inputs[_indicesIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
dataType* dstData = outputs[0]->buffer().as<dataType*>() +
|
||||
outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
|
||||
const size_t* srcMultipliers = inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getStrides().data() + _batchDims;
|
||||
|
||||
const size_t cycles = outputs[0]->byteSize() / (sizeof(dataType) * _batchNum);
|
||||
const size_t CS = cycles * _sliceRank;
|
||||
const size_t CB = cycles * _blockSize;
|
||||
const size_t workAmount = _batchNum * cycles;
|
||||
|
||||
auto threadBody = [&](const int ithr, const int nthr) {
|
||||
size_t start(0lu), end(0lu);
|
||||
splitter(workAmount, nthr, ithr, start, end);
|
||||
if (start >= end)
|
||||
return;
|
||||
size_t bStart = start / cycles;
|
||||
size_t cStart = start % cycles;
|
||||
size_t workCounter = start;
|
||||
|
||||
const dataType* shiftedSrcData = srcData + bStart * _batchStep;
|
||||
const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
|
||||
dataType* shiftedDstData = dstData + bStart * CB + cStart * _blockSize;
|
||||
|
||||
for (size_t b = bStart; b < _batchNum; b++) {
|
||||
for (size_t j = cStart; j < cycles; j++) {
|
||||
size_t dataIdx = 0lu;
|
||||
for (size_t i = 0lu; i < _sliceRank; i++)
|
||||
dataIdx += srcMultipliers[i] * shiftedIndices[i];
|
||||
shiftedDstData[0] = shiftedSrcData[dataIdx];
|
||||
shiftedDstData++;
|
||||
shiftedIndices += _sliceRank;
|
||||
if (++workCounter == end) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
cStart = 0lu;
|
||||
shiftedSrcData += _batchStep;
|
||||
}
|
||||
};
|
||||
|
||||
parallel_nt(0, threadBody);
|
||||
}
|
||||
|
||||
void gatherBlocks(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept {
|
||||
const uint8_t* srcData = inputs[_dataIndex]->cbuffer().as<const uint8_t*>() +
|
||||
inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
const int* indices = inputs[_indicesIndex]->cbuffer().as<const int*>() +
|
||||
inputs[_indicesIndex]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
uint8_t* dstData = outputs[0]->buffer().as<uint8_t*>() +
|
||||
outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
|
||||
std::vector<size_t> srcMultipliers(_sliceRank);
|
||||
for (size_t i = 0; i < _sliceRank ; i++)
|
||||
srcMultipliers[i] = _dataTypeSize * inputs[_dataIndex]->getTensorDesc().getBlockingDesc().getStrides()[i + _batchDims];
|
||||
|
||||
const size_t batchStep = _batchStep * _dataTypeSize;
|
||||
const size_t dataStep = _blockSize * _dataTypeSize;
|
||||
const size_t cycles = outputs[0]->byteSize() / (dataStep * _batchNum);
|
||||
const size_t CS = cycles * _sliceRank;
|
||||
const size_t CB = cycles * dataStep;
|
||||
const size_t workAmount = _batchNum * cycles;
|
||||
|
||||
auto threadBody = [&](const int ithr, const int nthr) {
|
||||
size_t start(0lu), end(0lu);
|
||||
splitter(workAmount, nthr, ithr, start, end);
|
||||
if (start >= end)
|
||||
return;
|
||||
size_t bStart = start / cycles;
|
||||
size_t cStart = start % cycles;
|
||||
size_t workCounter = start;
|
||||
|
||||
const uint8_t* shiftedSrcData = srcData + bStart * batchStep;
|
||||
const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
|
||||
uint8_t* shiftedDstData = dstData + bStart * CB + cStart * dataStep;
|
||||
|
||||
for (size_t b = bStart; b < _batchNum; b++) {
|
||||
for (size_t j = cStart; j < cycles; j++) {
|
||||
size_t dataIdx = 0lu;
|
||||
for (size_t i = 0; i < _sliceRank ; i++)
|
||||
dataIdx += srcMultipliers[i] * shiftedIndices[i];
|
||||
cpu_memcpy(shiftedDstData, &(shiftedSrcData[dataIdx]), dataStep);
|
||||
shiftedDstData += dataStep;
|
||||
shiftedIndices += _sliceRank;
|
||||
if (++workCounter == end) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
cStart = 0;
|
||||
shiftedSrcData += batchStep;
|
||||
}
|
||||
};
|
||||
|
||||
parallel_nt(0, threadBody);
|
||||
}
|
||||
|
||||
size_t _dataRank;
|
||||
size_t _sliceRank;
|
||||
size_t _blockSize;
|
||||
size_t _batchDims;
|
||||
size_t _batchNum;
|
||||
size_t _batchStep;
|
||||
size_t _dataTypeSize;
|
||||
const size_t _dataIndex = 0;
|
||||
const size_t _indicesIndex = 1;
|
||||
std::string _errorPrefix;
|
||||
};
|
||||
|
||||
|
||||
REG_FACTORY_FOR(GatherNDImpl, GatherND);
|
||||
} // namespace Cpu
|
||||
} // namespace Extensions
|
||||
} // namespace InferenceEngine
|
@ -73,6 +73,7 @@ MKLDNN_EXTENSION_NODE(SparseFillEmptyRowsImpl, SparseFillEmptyRows);
|
||||
MKLDNN_EXTENSION_NODE(BucketizeImpl, Bucketize);
|
||||
MKLDNN_EXTENSION_NODE(CTCGreedyDecoderImpl, CTCGreedyDecoder);
|
||||
MKLDNN_EXTENSION_NODE(GatherImpl, Gather);
|
||||
MKLDNN_EXTENSION_NODE(GatherNDImpl, GatherND);
|
||||
MKLDNN_EXTENSION_NODE(ProposalImpl, Proposal);
|
||||
MKLDNN_EXTENSION_NODE(RangeImpl, Range);
|
||||
MKLDNN_EXTENSION_NODE(SelectImpl, Select);
|
||||
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "single_layer_tests/gather_nd.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<InferenceEngine::Precision> dPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16,
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64,
|
||||
InferenceEngine::Precision::I16,
|
||||
InferenceEngine::Precision::U8,
|
||||
InferenceEngine::Precision::I8
|
||||
};
|
||||
const std::vector<InferenceEngine::Precision> iPrecisions = {
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64
|
||||
};
|
||||
|
||||
const auto gatherNDArgsSubset1 = ::testing::Combine(
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{{2, 2}, {2, 3, 4}})), // Data shape
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{{2, 1}, {2, 1, 1}})), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({0, 1})) // Batch dims
|
||||
);
|
||||
INSTANTIATE_TEST_CASE_P(smoke_Set1, GatherNDLayerTest,
|
||||
::testing::Combine(
|
||||
gatherNDArgsSubset1,
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
::testing::ValuesIn(iPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherNDLayerTest::getTestCaseName);
|
||||
|
||||
const auto gatherNDArgsSubset2 = ::testing::Combine(
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{{15, 12, 20, 15, 2}, {15, 12, 18, 7, 17}})), // Data shape
|
||||
::testing::ValuesIn(std::vector<std::vector<size_t>>(
|
||||
{{15, 12, 2}, {15, 12, 5, 9, 1, 3}})), // Indices shape
|
||||
::testing::ValuesIn(std::vector<int>({0, 1, 2})) // Batch dims
|
||||
);
|
||||
INSTANTIATE_TEST_CASE_P(smoke_Set2, GatherNDLayerTest,
|
||||
::testing::Combine(
|
||||
gatherNDArgsSubset2,
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
::testing::ValuesIn(iPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherNDLayerTest::getTestCaseName);
|
||||
} // namespace
|
@ -0,0 +1,38 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "functional_test_utils/layer_test_utils.hpp"
|
||||
|
||||
|
||||
typedef std::tuple<
|
||||
std::vector<size_t>, // Data shapes
|
||||
std::vector<size_t>, // Indices shape
|
||||
int // batch dims
|
||||
> GatherNDParamsSubset;
|
||||
|
||||
typedef std::tuple<
|
||||
GatherNDParamsSubset,
|
||||
InferenceEngine::Precision, // Data precision
|
||||
InferenceEngine::Precision, // Indices precision
|
||||
LayerTestsUtils::TargetDevice // Device name
|
||||
> GatherNDParams;
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
class GatherNDLayerTest : public testing::WithParamInterface<GatherNDParams>,
|
||||
public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<GatherNDParams> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,58 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "single_layer_tests/gather_nd.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::string GatherNDLayerTest::getTestCaseName(const testing::TestParamInfo<GatherNDParams>& obj) {
|
||||
InferenceEngine::SizeVector dataShape, indicesShape;
|
||||
InferenceEngine::Precision dPrecision, iPrecision;
|
||||
int batchDims;
|
||||
std::string device;
|
||||
GatherNDParamsSubset gatherArgsSubset;
|
||||
std::tie(gatherArgsSubset, dPrecision, iPrecision, device) = obj.param;
|
||||
std::tie(dataShape, indicesShape, batchDims) = gatherArgsSubset;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "DS=" << CommonTestUtils::vec2str(dataShape) << "_";
|
||||
result << "IS=" << CommonTestUtils::vec2str(indicesShape) << "_";
|
||||
result << "BD=" << batchDims << "_";
|
||||
result << "DP=" << dPrecision.name() << "_";
|
||||
result << "IP=" << iPrecision.name() << "_";
|
||||
result << "device=" << device;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void GatherNDLayerTest::SetUp() {
|
||||
InferenceEngine::SizeVector dataShape, indicesShape;
|
||||
InferenceEngine::Precision dPrecision, iPrecision;
|
||||
int batchDims;
|
||||
GatherNDParamsSubset gatherArgsSubset;
|
||||
std::tie(gatherArgsSubset, dPrecision, iPrecision, targetDevice) = this->GetParam();
|
||||
std::tie(dataShape, indicesShape, batchDims) = gatherArgsSubset;
|
||||
|
||||
auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
|
||||
auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);
|
||||
|
||||
auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
|
||||
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
|
||||
auto dataNode = paramOuts[0];
|
||||
auto gather = std::dynamic_pointer_cast<ngraph::opset5::GatherND>(
|
||||
ngraph::builder::makeGatherND(dataNode, indicesShape, ngIPrc, batchDims));
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gather)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "gatherND");
|
||||
}
|
||||
|
||||
TEST_P(GatherNDLayerTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
} // namespace LayerTestsDefinitions
|
@ -23,6 +23,8 @@ inline ::ngraph::element::Type convertIE2nGraphPrc(const InferenceEngine::Precis
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::f32);
|
||||
case InferenceEngine::Precision::FP16:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::f16);
|
||||
case InferenceEngine::Precision::BF16:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::bf16);
|
||||
case InferenceEngine::Precision::U8:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::u8);
|
||||
case InferenceEngine::Precision::I8:
|
||||
@ -50,4 +52,4 @@ inline ::ngraph::element::Type convertIE2nGraphPrc(const InferenceEngine::Precis
|
||||
}
|
||||
|
||||
} // namespace PrecisionUtils
|
||||
} // namespace FuncTestUtils
|
||||
} // namespace FuncTestUtils
|
||||
|
@ -426,6 +426,12 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
|
||||
bool make_sequence = false,
|
||||
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeGatherND(
|
||||
const ngraph::Output<Node>& dataNode,
|
||||
const ngraph::Shape& indicesShape,
|
||||
const element::Type& indicesType,
|
||||
const std::size_t batchDims);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeTile(const ngraph::Output<Node>& in,
|
||||
const std::vector<size_t>& repeats);
|
||||
|
||||
|
45
inference-engine/tests/ngraph_functions/src/gather_nd.cpp
Normal file
45
inference-engine/tests/ngraph_functions/src/gather_nd.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
|
||||
std::shared_ptr<Node> makeGatherND(
|
||||
const ngraph::Output<Node>& dataNode,
|
||||
const ngraph::Shape& indicesShape,
|
||||
const element::Type& indicesType,
|
||||
const std::size_t batchDims) {
|
||||
const auto indices = [&] {
|
||||
const auto& dataShape = dataNode.get_shape();
|
||||
const auto indicesCount = std::accumulate(begin(indicesShape), prev(end(indicesShape)),
|
||||
1ull, std::multiplies<std::size_t>{});
|
||||
const auto sliceRank = indicesShape.back();
|
||||
|
||||
const auto maxDim = *std::max_element(begin(dataShape), end(dataShape));
|
||||
|
||||
auto indicesValues = NGraphFunctions::Utils::generateVector<element::Type_t::i32>(indicesCount * sliceRank, maxDim, 0);
|
||||
auto indicesData = indicesValues.data();
|
||||
for (int i = 0; i < indicesCount; i++) {
|
||||
for (int dim = 0; dim < sliceRank; dim++) {
|
||||
indicesData[0] = indicesData[0] % dataShape[dim + batchDims];
|
||||
indicesData++;
|
||||
}
|
||||
}
|
||||
return opset5::Constant::create(indicesType, indicesShape, indicesValues);
|
||||
}();
|
||||
|
||||
auto gatherNdNode = std::make_shared<opset5::GatherND>(dataNode, indices, batchDims);
|
||||
gatherNdNode->set_friendly_name("GatherND");
|
||||
|
||||
return gatherNdNode;
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
Loading…
Reference in New Issue
Block a user