[CPU] Added dynamism support for GatherND and added GatherND-8 support (#8495)
Commit 27901a87af (parent 2518f88d7c)
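For readers unfamiliar with the operation this commit reworks: GatherND takes a data tensor and an indices tensor whose last dimension selects a slice of data, and with batch_dims = b the first b axes of data and indices are iterated jointly. The standalone sketch below mirrors the slice-copy logic that the new GatherNDExecutor implements (per-batch stride, per-index strides, memcpy of the trailing block). It is illustrative only: the gatherND helper, its float-only signature, and the example shapes are assumptions made for this note, not code from the commit.

// Illustrative sketch of GatherND semantics with batch_dims; not plugin code.
#include <cstring>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Gathers slices of `src` (shape srcDims) at positions given by `indices`
// (shape idxDims, last dimension = sliceRank); the first batchDims axes of
// both inputs are traversed together.
std::vector<float> gatherND(const std::vector<float>& src, const std::vector<size_t>& srcDims,
                            const std::vector<int>& indices, const std::vector<size_t>& idxDims,
                            size_t batchDims) {
    const size_t sliceRank = idxDims.back();
    // Number of shared leading batches.
    const size_t batchSize = std::accumulate(srcDims.begin(), srcDims.begin() + batchDims,
                                             size_t(1), std::multiplies<size_t>());
    // Elements in one gathered slice (trailing dims not addressed by the index tuple).
    const size_t dataLength = std::accumulate(srcDims.begin() + batchDims + sliceRank, srcDims.end(),
                                              size_t(1), std::multiplies<size_t>());
    // Index tuples per batch.
    const size_t cycles = std::accumulate(idxDims.begin() + batchDims, idxDims.end() - 1,
                                          size_t(1), std::multiplies<size_t>());
    // Strides (in elements) of the addressed dimensions inside one batch of `src`.
    std::vector<size_t> strides(sliceRank);
    size_t stride = dataLength;
    for (size_t i = sliceRank; i > 0; --i) {
        strides[i - 1] = stride;
        stride *= srcDims[batchDims + i - 1];
    }
    const size_t srcBatchStride = stride;  // elements per batch of src

    std::vector<float> dst(batchSize * cycles * dataLength);
    for (size_t b = 0; b < batchSize; ++b) {
        for (size_t j = 0; j < cycles; ++j) {
            size_t srcOffset = b * srcBatchStride;
            for (size_t i = 0; i < sliceRank; ++i)
                srcOffset += strides[i] * indices[(b * cycles + j) * sliceRank + i];
            std::memcpy(dst.data() + (b * cycles + j) * dataLength,
                        src.data() + srcOffset, dataLength * sizeof(float));
        }
    }
    return dst;
}

int main() {
    // data shape {2, 3}, indices shape {2, 1}, batch_dims = 1:
    // each batch picks one element from its own row.
    std::vector<float> data = {1, 2, 3, 4, 5, 6};
    std::vector<int> idx = {2, 0};
    for (float v : gatherND(data, {2, 3}, idx, {2, 1}, 1))
        std::cout << v << ' ';  // expected: 3 4
    std::cout << '\n';
}

Built with any C++11 compiler, the example prints "3 4": with batch_dims = 1, each of the two batches picks one element from its own row of data.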
@@ -8,7 +8,7 @@
 #include <mkldnn_types.h>
 #include "ie_parallel.hpp"
 #include "mkldnn_gather_nd_node.h"
-#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset8.hpp>
 #include <precision_utils.h>
 #include <utils/general_utils.h>
 #include "common/cpu_memcpy.h"
@@ -16,15 +16,12 @@
 using namespace MKLDNNPlugin;
 using namespace InferenceEngine;
 
+#define THROW_ERROR IE_THROW() << "GatherND layer with name '" << getName() << "' "
+
 bool MKLDNNGatherNDNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
     try {
-        if (isDynamicNgraphNode(op)) {
-            errorMessage = "Doesn't support op with dynamic shapes";
-            return false;
-        }
-        const auto gatherElementsOp = ngraph::as_type_ptr<const ngraph::op::v5::GatherND>(op);
-        if (!gatherElementsOp) {
-            errorMessage = "Node is not an instance of the GatherND operation from operation set v5.";
+        if (!MKLDNNPlugin::one_of(op->get_type_info(), ngraph::op::v5::GatherND::get_type_info_static(), ngraph::op::v8::GatherND::get_type_info_static())) {
+            errorMessage = "Node is not an instance of the GatherND operation from operation set v5 or v8.";
             return false;
         }
     } catch (...) {
@@ -40,58 +37,36 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr<ngraph::Node>& op,
     if (!isSupportedOperation(op, errorMessage)) {
         IE_THROW(NotImplemented) << errorMessage;
     }
-    _errorPrefix = std::string("Layer GatherND with name '") + op->get_friendly_name() + "'";
 
-    if (op->get_input_size() != 2 || op->get_output_size() != 1)
-        IE_THROW() << _errorPrefix << " has invalid number of input/output edges.";
+    if (inputShapes.size() != 2 || outputShapes.size() != 1)
+        THROW_ERROR << "has invalid number of input/output edges.";
 
-    const auto& dataDims = op->get_input_shape(_dataIndex);
-    const auto& indicesDims = op->get_input_shape(_indicesIndex);
+    const size_t inputDataRank = getInputShapeAtPort(GATHERND_DATA).getRank();
+    const size_t indicesDimsRank = getInputShapeAtPort(GATHERND_INDEXES).getRank();
 
-    auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v5::GatherND>(op);
-    _batchDims = gatherNdOp->get_batch_dims();
-    if (_batchDims >= std::min(dataDims.size(), indicesDims.size()))
-        IE_THROW() << _errorPrefix << " has invalid batch_dims attribute: " << _batchDims;
-
-    _batchNum = 1lu;
-    for (size_t i = 0; i < _batchDims; i++) {
-        _batchNum *= indicesDims[i];
+    if (auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v8::GatherND>(op)) {
+        attrs.batchDims = gatherNdOp->get_batch_dims();
+    } else if (auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v5::GatherND>(op)) {
+        attrs.batchDims = gatherNdOp->get_batch_dims();
+    } else {
+        THROW_ERROR << "supports only opset5 and opset8 versions of the GatherND operation.";
     }
 
-    _sliceRank = indicesDims[indicesDims.size() - 1];
-    _dataRank = dataDims.size() - _batchDims;
-    if (_sliceRank > _dataRank)
-        IE_THROW() << _errorPrefix << " has invalid inputs shapes.";
-
-    _blockSize = 1;
-    for (size_t i = _sliceRank + _batchDims; i < dataDims.size(); i++) {
-        _blockSize *= dataDims[i];
-    }
-    _batchStep = 1;
-    for (size_t i = _batchDims; i < dataDims.size(); i++) {
-        _batchStep *= dataDims[i];
-    }
+    if (attrs.batchDims >= std::min(inputDataRank, indicesDimsRank))
+        THROW_ERROR << "has invalid batch_dims attribute: " << attrs.batchDims;
 }
 
 void MKLDNNGatherNDNode::initSupportedPrimitiveDescriptors() {
     if (!supportedPrimitiveDescriptors.empty())
         return;
 
-    Precision inDataPrecision = getOriginalInputPrecisionAtPort(_dataIndex);
-    if (!MKLDNNPlugin::one_of(inDataPrecision.size(),
-                              sizeof(PrecisionTrait<Precision::I32>::value_type),
-                              sizeof(PrecisionTrait<Precision::I16>::value_type),
-                              sizeof(PrecisionTrait<Precision::I8>::value_type))) {
-        IE_THROW() << _errorPrefix << " has unsupported 'data' input precision: " << inDataPrecision;
-    }
-
-    Precision indicesPrecision = getOriginalInputPrecisionAtPort(_indicesIndex);
+    Precision inDataPrecision = getOriginalInputPrecisionAtPort(GATHERND_DATA);
+    Precision indicesPrecision = getOriginalInputPrecisionAtPort(GATHERND_INDEXES);
     if (!MKLDNNPlugin::one_of(indicesPrecision,
                               Precision::I32, Precision::I64, Precision::I16, Precision::U16, Precision::I8, Precision::U8)) {
-        IE_THROW() << _errorPrefix << " has unsupported 'indices' input precision: " << indicesPrecision;
+        THROW_ERROR << "has unsupported 'indices' input precision: " << indicesPrecision;
     }
-
-    _dataTypeSize = inDataPrecision.size();
+    attrs.dataSize = inDataPrecision.size();
 
     addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision},
                           {LayoutType::ncsp, Precision::I32}},
@@ -99,121 +74,77 @@ void MKLDNNGatherNDNode::initSupportedPrimitiveDescriptors() {
                          impl_desc_type::ref_any);
 }
 
-template <typename dataType>
-void MKLDNNGatherNDNode::gatherElementwise() {
-    const auto *srcData = reinterpret_cast<const dataType *>(getParentEdgeAt(_dataIndex)->getMemoryPtr()->GetPtr());
-    const auto *indices = reinterpret_cast<const int *>(getParentEdgeAt(_indicesIndex)->getMemoryPtr()->GetPtr());
-    auto *dstData = reinterpret_cast<dataType *>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
-
-    auto strides = getParentEdgeAt(_dataIndex)->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
-    const size_t* srcMultipliers = strides.data() + _batchDims;
-
-    const size_t cycles = getChildEdgeAt(0)->getMemory().GetShape().getElementsCount() *
-                          getChildEdgeAt(0)->getMemory().getDesc().getPrecision().size() / (sizeof(dataType) * _batchNum);
-    const size_t CS = cycles * _sliceRank;
-    const size_t CB = cycles * _blockSize;
-    const size_t workAmount = _batchNum * cycles;
-
-    auto threadBody = [&](const int ithr, const int nthr) {
-        size_t start(0lu), end(0lu);
-        splitter(workAmount, nthr, ithr, start, end);
-        if (start >= end)
-            return;
-        size_t bStart = start / cycles;
-        size_t cStart = start % cycles;
-        size_t workCounter = start;
-
-        const dataType* shiftedSrcData = srcData + bStart * _batchStep;
-        const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
-        dataType* shiftedDstData = dstData + bStart * CB + cStart * _blockSize;
-
-        for (size_t b = bStart; b < _batchNum; b++) {
-            for (size_t j = cStart; j < cycles; j++) {
-                size_t dataIdx = 0lu;
-                for (size_t i = 0lu; i < _sliceRank; i++)
-                    dataIdx += srcMultipliers[i] * shiftedIndices[i];
-                shiftedDstData[0] = shiftedSrcData[dataIdx];
-                shiftedDstData++;
-                shiftedIndices += _sliceRank;
-                if (++workCounter == end) {
-                    return;
-                }
-            }
-            cStart = 0lu;
-            shiftedSrcData += _batchStep;
-        }
-    };
-
-    parallel_nt(0, threadBody);
+void MKLDNNGatherNDNode::createPrimitive() {
+    if (inputShapesDefined()) {
+        if (needPrepareParams())
+            prepareParams();
+        updateLastInputDims();
+    }
 }
 
-void MKLDNNGatherNDNode::gatherBlocks() {
-    const uint8_t* srcData = reinterpret_cast<const uint8_t *>(getParentEdgeAt(_dataIndex)->getMemoryPtr()->GetPtr());
-    const int* indices = reinterpret_cast<const int *>(getParentEdgeAt(_indicesIndex)->getMemoryPtr()->GetPtr());
-    uint8_t* dstData = reinterpret_cast<uint8_t *>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
+void MKLDNNGatherNDNode::prepareParams() {
+    auto& srcMemPtr = getParentEdgeAt(GATHERND_DATA)->getMemoryPtr();
+    auto& idxMemPtr = getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr();
+    auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
+    if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
+        THROW_ERROR << " has not allocated input memory of 'data'.";
+    if (!idxMemPtr || !idxMemPtr->GetPrimitivePtr())
+        THROW_ERROR << " has not allocated input memory of 'indices'.";
+    if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
+        THROW_ERROR << " has not allocated output memory.";
+    if (getSelectedPrimitiveDescriptor() == nullptr)
+        THROW_ERROR << " has unidentified preferable primitive descriptor.";
 
-    std::vector<size_t> srcMultipliers(_sliceRank);
-    for (size_t i = 0; i < _sliceRank ; i++)
-        srcMultipliers[i] = _dataTypeSize * getParentEdgeAt(_dataIndex)->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides()[i + _batchDims];
+    attrs.srcDims = srcMemPtr->getStaticDims();
+    attrs.srcStrides = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getStrides();
+    attrs.dstSize = dstMemPtr->GetSize();
+    attrs.sliceRank = idxMemPtr->getStaticDims().back();
+    execPtr = std::make_shared<GatherNDExecutor>(attrs);
+}
 
-    const size_t batchStep = _batchStep * _dataTypeSize;
-    const size_t dataStep = _blockSize * _dataTypeSize;
-    const size_t cycles = getChildEdgeAt(0)->getMemory().GetSize() / (dataStep * _batchNum);
-    const size_t CS = cycles * _sliceRank;
-    const size_t CB = cycles * dataStep;
-    const size_t workAmount = _batchNum * cycles;
+MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : attrs(attrs) {
+    batchSize = std::accumulate(attrs.srcDims.begin(), attrs.srcDims.begin() + attrs.batchDims, 1lu, std::multiplies<size_t>());
+    dataLength = std::accumulate(attrs.srcDims.begin() + attrs.sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu,
+                                 std::multiplies<size_t>()) * attrs.dataSize;
+    cycles = attrs.dstSize / (dataLength * batchSize);
 
-    auto threadBody = [&](const int ithr, const int nthr) {
-        size_t start(0lu), end(0lu);
-        splitter(workAmount, nthr, ithr, start, end);
-        if (start >= end)
-            return;
-        size_t bStart = start / cycles;
-        size_t cStart = start % cycles;
-        size_t workCounter = start;
+    srcBatchStride = std::accumulate(attrs.srcDims.begin() + attrs.batchDims, attrs.srcDims.end(), 1lu,
+                                     std::multiplies<size_t>()) * attrs.dataSize;
+    idxBatchStride = cycles * attrs.sliceRank;
+    dstBatchStride = cycles * dataLength;
 
-        const uint8_t* shiftedSrcData = srcData + bStart * batchStep;
-        const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
-        uint8_t* shiftedDstData = dstData + bStart * CB + cStart * dataStep;
+    srcShifts.resize(attrs.sliceRank, 0);
+    for (size_t i = 0; i < attrs.sliceRank ; i++)
+        srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * attrs.dataSize;
+}
 
-        for (size_t b = bStart; b < _batchNum; b++) {
-            for (size_t j = cStart; j < cycles; j++) {
-                size_t dataIdx = 0lu;
-                for (size_t i = 0; i < _sliceRank ; i++)
-                    dataIdx += srcMultipliers[i] * shiftedIndices[i];
-                cpu_memcpy(shiftedDstData, &(shiftedSrcData[dataIdx]), dataStep);
-                shiftedDstData += dataStep;
-                shiftedIndices += _sliceRank;
-                if (++workCounter == end) {
-                    return;
-                }
-            }
-            cStart = 0;
-            shiftedSrcData += batchStep;
-        }
-    };
+void MKLDNNGatherNDNode::GatherNDExecutor::exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData) {
+    parallel_for2d(batchSize, cycles, [&](const size_t b, const size_t j) {
+        const size_t srcStride = b * srcBatchStride;
+        const size_t idxStride = b * idxBatchStride + j * attrs.sliceRank;
+        const size_t dstStride = b * dstBatchStride + j * dataLength;
 
-    parallel_nt(0, threadBody);
+        size_t dataIdx = 0lu;
+        for (size_t i = 0; i < attrs.sliceRank ; ++i)
+            dataIdx += srcShifts[i] * indices[idxStride + i];
+
+        cpu_memcpy(&dstData[dstStride], &srcData[srcStride + dataIdx], dataLength);
+    });
 }
 
 void MKLDNNGatherNDNode::execute(mkldnn::stream strm) {
-    if (_blockSize > 1) {
-        gatherBlocks();
-    } else {
-        switch (_dataTypeSize) {
-            case sizeof(PrecisionTrait<Precision::I32>::value_type):
-                gatherElementwise<PrecisionTrait<Precision::I32>::value_type>();
-                break;
-            case sizeof(PrecisionTrait<Precision::I16>::value_type):
-                gatherElementwise<PrecisionTrait<Precision::I16>::value_type>();
-                break;
-            case sizeof(PrecisionTrait<Precision::I8>::value_type):
-                gatherElementwise<PrecisionTrait<Precision::I8>::value_type>();
-                break;
-            default:
-                IE_THROW() << _errorPrefix + " has data input with unsupported precision: " + getOriginalInputPrecisionAtPort(_dataIndex).name();
-        }
-    }
+    if (!execPtr)
+        THROW_ERROR << "has not compiled executor.";
+
+    const uint8_t* srcData = reinterpret_cast<const uint8_t*>(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr()->GetPtr());
+    const int32_t* indices = reinterpret_cast<const int32_t*>(getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr()->GetPtr());
+    uint8_t* dstData = reinterpret_cast<uint8_t*>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
+
+    execPtr->exec(srcData, indices, dstData);
 }
 
+void MKLDNNGatherNDNode::executeDynamicImpl(dnnl::stream strm) {
+    execute(strm);
+}
+
 bool MKLDNNGatherNDNode::created() const {
@@ -18,27 +18,50 @@ public:
 
     void getSupportedDescriptors() override {};
     void initSupportedPrimitiveDescriptors() override;
-    void createPrimitive() override {};
+    void createPrimitive() override;
    void execute(mkldnn::stream strm) override;
     bool created() const override;
 
     static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
 
-private:
-    size_t _dataRank;
-    size_t _sliceRank;
-    size_t _blockSize;
-    size_t _batchDims;
-    size_t _batchNum;
-    size_t _batchStep;
-    size_t _dataTypeSize;
-    const size_t _dataIndex = 0;
-    const size_t _indicesIndex = 1;
-    std::string _errorPrefix;
+protected:
+    void executeDynamicImpl(mkldnn::stream strm) override;
+    void prepareParams() override;
 
-    template <typename dataType>
-    void gatherElementwise();
-    void gatherBlocks();
+private:
+    struct GatherNDAttributes {
+        size_t batchDims = 0lu;
+        size_t dataSize = 1lu;
+        size_t dstSize = 0lu;
+        size_t sliceRank = 0lu;
+
+        VectorDims srcDims;
+        VectorDims srcStrides;
+    } attrs;
+
+    struct GatherNDExecutor {
+        GatherNDExecutor(const GatherNDAttributes& attrs);
+        ~GatherNDExecutor() = default;
+        void exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData);
+
+    private:
+        size_t batchSize = 1lu;
+        size_t cycles = 1lu;
+        size_t dataLength = 1lu;
+
+        size_t srcBatchStride = 1lu;
+        size_t idxBatchStride = 1lu;
+        size_t dstBatchStride = 1lu;
+        VectorDims srcShifts;
+
+        GatherNDAttributes attrs;
+    };
+
+    static constexpr size_t GATHERND_DATA = 0lu;
+    static constexpr size_t GATHERND_INDEXES = 1lu;
+
+    using executorPtr = std::shared_ptr<GatherNDExecutor>;
+    executorPtr execPtr = nullptr;
 };
 
 } // namespace MKLDNNPlugin
@@ -31,7 +31,8 @@ const auto gatherNDArgsSubset1 = ::testing::Combine(
         {{2, 1}, {2, 1, 1}})),                              // Indices shape
     ::testing::ValuesIn(std::vector<int>({0, 1}))           // Batch dims
 );
-INSTANTIATE_TEST_SUITE_P(smoke_Set1, GatherNDLayerTest,
+
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set1, GatherNDLayerTest,
         ::testing::Combine(
             gatherNDArgsSubset1,
             ::testing::ValuesIn(dPrecisions),
@@ -40,6 +41,15 @@ INSTANTIATE_TEST_SUITE_P(smoke_Set1, GatherNDLayerTest,
             ::testing::Values<Config>({})),
         GatherNDLayerTest::getTestCaseName);
 
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherNDLayerTest,
+        ::testing::Combine(
+            gatherNDArgsSubset1,
+            ::testing::ValuesIn(dPrecisions),
+            ::testing::ValuesIn(iPrecisions),
+            ::testing::Values(CommonTestUtils::DEVICE_CPU),
+            ::testing::Values<Config>({})),
+        GatherNDLayerTest::getTestCaseName);
+
 const auto gatherNDArgsSubset2 = ::testing::Combine(
     ::testing::ValuesIn(std::vector<std::vector<size_t>>(
         {{15, 12, 20, 15, 2}, {15, 12, 18, 7, 17}})),       // Data shape
@@ -47,7 +57,8 @@ const auto gatherNDArgsSubset2 = ::testing::Combine(
         {{15, 12, 2}, {15, 12, 5, 9, 1, 3}})),              // Indices shape
     ::testing::ValuesIn(std::vector<int>({0, 1, 2}))        // Batch dims
 );
-INSTANTIATE_TEST_SUITE_P(smoke_Set2, GatherNDLayerTest,
+
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set2, GatherNDLayerTest,
         ::testing::Combine(
             gatherNDArgsSubset2,
            ::testing::ValuesIn(dPrecisions),
@@ -55,4 +66,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_Set2, GatherNDLayerTest,
             ::testing::Values(CommonTestUtils::DEVICE_CPU),
             ::testing::Values<Config>({})),
         GatherNDLayerTest::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherNDLayerTest,
+        ::testing::Combine(
+            gatherNDArgsSubset2,
+            ::testing::ValuesIn(dPrecisions),
+            ::testing::ValuesIn(iPrecisions),
+            ::testing::Values(CommonTestUtils::DEVICE_CPU),
+            ::testing::Values<Config>({})),
+        GatherNDLayerTest::getTestCaseName);
+
 } // namespace
@@ -0,0 +1,193 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <shared_test_classes/single_layer/gather_nd.hpp>
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "ngraph_functions/builders.hpp"
+
+using namespace InferenceEngine;
+using namespace ov;
+using namespace test;
+
+namespace CPULayerTestsDefinitions {
+
+using GatherNDLayerCPUTestParamSet = std::tuple<
+        InputShape,                            // Input shapes
+        std::pair<Shape, std::vector<int>>,    // Indexes shape and values
+        ElementType,                           // Input element type
+        ElementType,                           // Indices element type
+        int                                    // Batch dims
+>;
+
+class GatherNDLayerCPUTest : public testing::WithParamInterface<GatherNDLayerCPUTestParamSet>,
+                             virtual public SubgraphBaseTest {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<GatherNDLayerCPUTestParamSet> obj) {
+        InputShape shapes;
+        std::pair<Shape, std::vector<int>> indexes;
+        ElementType dataElementType, idxElementType;
+        int batchDims;
+        std::tie(shapes, indexes, dataElementType, idxElementType, batchDims) = obj.param;
+
+        std::ostringstream results;
+        results << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_";
+        results << "TS=";
+        for (const auto& item : shapes.second) {
+            results << CommonTestUtils::vec2str(item) << "_";
+        }
+        results << "IDXShape=" << CommonTestUtils::vec2str(indexes.first) << "_";
+        results << "SRCPrc=" << dataElementType << "_";
+        results << "IDXPrc=" << idxElementType << "_";
+        results << "BD=" << batchDims << "_";
+
+        return results.str();
+    }
+
+protected:
+    void SetUp() override {
+        InputShape shapes;
+        std::pair<Shape, std::vector<int>> indexes;
+        ElementType dataElementType, idxElementType;
+        int batchDims;
+        std::tie(shapes, indexes, dataElementType, idxElementType, batchDims) = this->GetParam();
+
+        targetDevice = CommonTestUtils::DEVICE_CPU;
+        init_input_shapes({shapes});
+
+        auto params = ngraph::builder::makeDynamicParams(dataElementType, inputDynamicShapes);
+        auto indexes_node = ngraph::opset3::Constant::create(idxElementType, indexes.first, indexes.second);
+        auto gather_nd = std::make_shared<ngraph::opset5::GatherND>(params[0], indexes_node, batchDims);
+        ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(gather_nd)};
+        function = std::make_shared<ngraph::Function>(results, params, "gatherND");
+    }
+};
+
+class GatherND8LayerCPUTest : public testing::WithParamInterface<GatherNDLayerCPUTestParamSet>,
+                              virtual public SubgraphBaseTest {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<GatherNDLayerCPUTestParamSet> obj) {
+        return GatherNDLayerCPUTest::getTestCaseName(obj);
+    }
+
+protected:
+    void SetUp() override {
+        InputShape shapes;
+        std::pair<Shape, std::vector<int>> indexes;
+        ElementType dataElementType, idxElementType;
+        int batchDims;
+        std::tie(shapes, indexes, dataElementType, idxElementType, batchDims) = this->GetParam();
+
+        targetDevice = CommonTestUtils::DEVICE_CPU;
+        init_input_shapes({shapes});
+
+        auto params = ngraph::builder::makeDynamicParams(dataElementType, inputDynamicShapes);
+        auto indexes_node = ngraph::opset3::Constant::create(idxElementType, indexes.first, indexes.second);
+        auto gather_nd = std::make_shared<ngraph::opset8::GatherND>(params[0], indexes_node, batchDims);
+        ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(gather_nd)};
+        function = std::make_shared<ngraph::Function>(results, params, "gatherND");
+    }
+};
+
+TEST_P(GatherNDLayerCPUTest, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+
+    run();
+}
+
+TEST_P(GatherND8LayerCPUTest, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+
+    run();
+}
+
+namespace {
+
+const std::vector<ElementType> inputPrecisions = {
+        ElementType::f32,
+        ElementType::bf16,
+        ElementType::i8
+};
+
+const std::vector<ElementType> indexesPrecisions = {
+        ElementType::i32
+};
+
+const std::vector<InputShape> inputShapesDynamicBD_0 = {
+        {{-1, -1, -1},                                               // dynamic
+         {{5, 10, 5}, {4, 12, 4}, {4, 12, 4}, {5, 5, 5}}},           // target
+
+        {{-1, 5, -1, -1},                                            // dynamic
+         {{8, 5, 5, 5}, {5, 5, 8, 4}, {4, 5, 4, 5}}},                // target
+
+        {{{4, 10}, {5, 10}, {5, 10}, {5, 10}, {5, 10}},              // dynamic
+         {{4, 5, 5, 5, 5}, {4, 5, 5, 8, 5}, {10, 8, 5, 5, 5}}},      // target
+};
+
+const std::vector<std::pair<Shape, std::vector<int>>> indexesShapesBD_0 = {
+        std::pair<Shape, std::vector<int>>{{2, 2}, {3, 3, 2, 1}},
+        std::pair<Shape, std::vector<int>>{{1, 2, 3}, {0, 1, 1, 1, 0, 2}},
+        std::pair<Shape, std::vector<int>>{{2, 1, 1, 2}, {0, 2, 1, 1}},
+};
+
+const auto subset_BD0 = ::testing::Combine(
+        ::testing::ValuesIn(inputShapesDynamicBD_0),
+        ::testing::ValuesIn(indexesShapesBD_0),
+        ::testing::ValuesIn(inputPrecisions),
+        ::testing::ValuesIn(indexesPrecisions),
+        ::testing::Values(0));
+
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND5DynamicBD_0, GatherNDLayerCPUTest, subset_BD0, GatherNDLayerCPUTest::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND8DynamicBD_0, GatherND8LayerCPUTest, subset_BD0, GatherNDLayerCPUTest::getTestCaseName);
+
+const std::vector<InputShape> inputShapesDynamicBD_1 = {
+        {{3, -1, -1},                                                // dynamic
+         {{3, 10, 5}, {3, 10, 5}, {3, 12, 8}, {3, 8, 8}}},           // target
+
+        {{3, {5, 10}, {5, 10}, {5, 10}, {5, 10}},                    // dynamic
+         {{3, 5, 5, 5, 5}, {3, 8, 10, 10, 10}, {3, 8, 6, 8, 7}}},    // target
+};
+
+const std::vector<std::pair<Shape, std::vector<int>>> indexesShapesBD_1 = {
+        std::pair<Shape, std::vector<int>>{{3, 2}, {0, 1, 2, 1, 0, 0}},
+        std::pair<Shape, std::vector<int>>{{3, 2, 2}, {0, 1, 1, 1, 0, 2, 0, 1, 1, 1, 0, 2}},
+        std::pair<Shape, std::vector<int>>{{3, 1, 1, 2}, {0, 2, 1, 1, 0, 2}},
+};
+
+const auto subset_BD1 = ::testing::Combine(
+        ::testing::ValuesIn(inputShapesDynamicBD_1),
+        ::testing::ValuesIn(indexesShapesBD_1),
+        ::testing::ValuesIn(inputPrecisions),
+        ::testing::ValuesIn(indexesPrecisions),
+        ::testing::Values(0));
+
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND5DynamicBD_1, GatherNDLayerCPUTest, subset_BD1, GatherNDLayerCPUTest::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND8DynamicBD_1, GatherND8LayerCPUTest, subset_BD1, GatherNDLayerCPUTest::getTestCaseName);
+
+const std::vector<InputShape> inputShapesDynamicBD_2 = {
+        {{2, 2, -1, -1, -1},                                                     // dynamic
+         {{2, 2, 5, 6, 5}, {2, 2, 2, 3, 3}, {2, 2, 2, 3, 3}, {2, 2, 7, 2, 3}}},  // target
+
+        {{2, 2, {5, 10}, {5, 10}, {5, 10}},                                      // dynamic
+         {{2, 2, 5, 5, 5}, {2, 2, 10, 10, 5}, {2, 2, 7, 8, 7}}},                 // target
+};
+
+const std::vector<std::pair<Shape, std::vector<int>>> indexesShapesBD_2 = {
+        std::pair<Shape, std::vector<int>>{{2, 2, 3}, {0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0}},
+        std::pair<Shape, std::vector<int>>{{2, 2, 2, 3}, {0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,
+                                                          0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0}},
+};
+
+const auto subset_BD2 = ::testing::Combine(
+        ::testing::ValuesIn(inputShapesDynamicBD_2),
+        ::testing::ValuesIn(indexesShapesBD_2),
+        ::testing::ValuesIn(inputPrecisions),
+        ::testing::ValuesIn(indexesPrecisions),
+        ::testing::Values(0));
+
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND5DynamicBD_2, GatherNDLayerCPUTest, subset_BD2, GatherNDLayerCPUTest::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND8DynamicBD_2, GatherND8LayerCPUTest, subset_BD2, GatherNDLayerCPUTest::getTestCaseName);
+
+
+} // namespace
+} // namespace CPULayerTestsDefinitions
@@ -12,4 +12,8 @@ TEST_P(GatherNDLayerTest, CompareWithRefs) {
     Run();
 }
 
+TEST_P(GatherND8LayerTest, CompareWithRefs) {
+    Run();
+}
+
 } // namespace LayerTestsDefinitions
@@ -36,4 +36,13 @@ protected:
     void SetUp() override;
 };
 
+class GatherND8LayerTest : public testing::WithParamInterface<GatherNDParams>,
+                           virtual public LayerTestsUtils::LayerTestsCommon {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<GatherNDParams> &obj);
+
+protected:
+    void SetUp() override;
+};
+
 } // namespace LayerTestsDefinitions
@@ -54,4 +54,30 @@ void GatherNDLayerTest::SetUp() {
     ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gather)};
     function = std::make_shared<ngraph::Function>(results, params, "gatherND");
 }
+
+std::string GatherND8LayerTest::getTestCaseName(const testing::TestParamInfo<GatherNDParams>& obj) {
+    return GatherNDLayerTest::getTestCaseName(obj);
+}
+
+void GatherND8LayerTest::SetUp() {
+    InferenceEngine::SizeVector dataShape, indicesShape;
+    InferenceEngine::Precision dPrecision, iPrecision;
+    int batchDims;
+    GatherNDParamsSubset gatherArgsSubset;
+    std::tie(gatherArgsSubset, dPrecision, iPrecision, targetDevice, configuration) = this->GetParam();
+    std::tie(dataShape, indicesShape, batchDims) = gatherArgsSubset;
+
+    auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
+    auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);
+
+    auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
+    auto paramOuts = ngraph::helpers::convert2OutputVector(
+            ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
+    auto dataNode = paramOuts[0];
+    auto gather = std::dynamic_pointer_cast<ngraph::opset8::GatherND>(
+            ngraph::builder::makeGatherND(dataNode, indicesShape, ngIPrc, batchDims));
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gather)};
+    function = std::make_shared<ngraph::Function>(results, params, "gatherND");
+}
+
 } // namespace LayerTestsDefinitions