[CPU] Added dynamism support for GatherND and added GatherND-8 support (#8495)

Alexandra Sidorova 2021-11-17 10:15:35 +03:00 committed by GitHub
parent 2518f88d7c
commit 27901a87af
7 changed files with 374 additions and 167 deletions
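For orientation before the per-file diffs: the sketch below is a minimal standalone reference of the gather that the reworked GatherNDExecutor performs, assuming float data, row-major (planar) layout and batch_dims = 0. The names gather_nd_ref, srcShifts and dataLength are illustrative only; they mirror the executor's fields but are not plugin API and are not part of this commit.

// Standalone reference sketch (not part of the commit): GatherND with batch_dims = 0.
// data    - flat buffer with shape dataDims
// indices - flat int32 buffer with shape idxDims; its last dimension (sliceRank)
//           addresses leading data dimensions, the other dimensions enumerate slices
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

std::vector<float> gather_nd_ref(const std::vector<float>& data,
                                 const std::vector<size_t>& dataDims,
                                 const std::vector<int32_t>& indices,
                                 const std::vector<size_t>& idxDims) {
    const size_t sliceRank = idxDims.back();
    // number of gathered slices = product of the indices dims except the last one
    size_t cycles = 1;
    for (size_t i = 0; i + 1 < idxDims.size(); ++i)
        cycles *= idxDims[i];
    // length of one copied slice = product of the data dims not addressed by indices
    size_t dataLength = 1;
    for (size_t i = sliceRank; i < dataDims.size(); ++i)
        dataLength *= dataDims[i];
    // row-major strides (in elements) of the addressed data dims
    std::vector<size_t> srcShifts(sliceRank, dataLength);
    for (size_t i = sliceRank; i-- > 1;)
        srcShifts[i - 1] = srcShifts[i] * dataDims[i];
    std::vector<float> dst(cycles * dataLength);
    for (size_t j = 0; j < cycles; ++j) {
        // resolve the flat offset of the addressed slice, then copy it as one block
        size_t dataIdx = 0;
        for (size_t i = 0; i < sliceRank; ++i)
            dataIdx += srcShifts[i] * static_cast<size_t>(indices[j * sliceRank + i]);
        std::memcpy(dst.data() + j * dataLength, data.data() + dataIdx,
                    dataLength * sizeof(float));
    }
    return dst;
}

int main() {
    // data shape {2, 2}, indices shape {2, 2}: gathers elements (0,0) and (1,1)
    const std::vector<float> data = {1, 2, 3, 4};
    const auto out = gather_nd_ref(data, {2, 2}, {0, 0, 1, 1}, {2, 2});
    std::cout << out[0] << " " << out[1] << std::endl;  // prints: 1 4
    return 0;
}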


@@ -8,7 +8,7 @@
#include <mkldnn_types.h>
#include "ie_parallel.hpp"
#include "mkldnn_gather_nd_node.h"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <precision_utils.h>
#include <utils/general_utils.h>
#include "common/cpu_memcpy.h"
@@ -16,15 +16,12 @@
using namespace MKLDNNPlugin;
using namespace InferenceEngine;
#define THROW_ERROR IE_THROW() << "GatherND layer with name '" << getName() << "' "
bool MKLDNNGatherNDNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
if (isDynamicNgraphNode(op)) {
errorMessage = "Doesn't support op with dynamic shapes";
return false;
}
const auto gatherElementsOp = ngraph::as_type_ptr<const ngraph::op::v5::GatherND>(op);
if (!gatherElementsOp) {
errorMessage = "Node is not an instance of the GatherND operation from operation set v5.";
if (!MKLDNNPlugin::one_of(op->get_type_info(), ngraph::op::v5::GatherND::get_type_info_static(), ngraph::op::v8::GatherND::get_type_info_static())) {
errorMessage = "Node is not an instance of the GatherND operation from operation set v5 and v8.";
return false;
}
} catch (...) {
@@ -40,58 +37,36 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr<ngraph::Node>& op,
if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage;
}
_errorPrefix = std::string("Layer GatherND with name '") + op->get_friendly_name() + "'";
if (op->get_input_size() != 2 || op->get_output_size() != 1)
IE_THROW() << _errorPrefix << " has invalid number of input/output edges.";
if (inputShapes.size() != 2 || outputShapes.size() != 1)
THROW_ERROR << "has invalid number of input/output edges.";
const auto& dataDims = op->get_input_shape(_dataIndex);
const auto& indicesDims = op->get_input_shape(_indicesIndex);
const size_t inputDataRank = getInputShapeAtPort(GATHERND_DATA).getRank();
const size_t indicesDimsRank = getInputShapeAtPort(GATHERND_INDEXES).getRank();
auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v5::GatherND>(op);
_batchDims = gatherNdOp->get_batch_dims();
if (_batchDims >= std::min(dataDims.size(), indicesDims.size()))
IE_THROW() << _errorPrefix << " has invalid batch_dims attribute: " << _batchDims;
_batchNum = 1lu;
for (size_t i = 0; i < _batchDims; i++) {
_batchNum *= indicesDims[i];
if (auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v8::GatherND>(op)) {
attrs.batchDims = gatherNdOp->get_batch_dims();
} else if (auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v5::GatherND>(op)) {
attrs.batchDims = gatherNdOp->get_batch_dims();
} else {
THROW_ERROR << "has support only opset5.";
}
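// Note: GatherND-8 differs from GatherND-5 only in output shape inference when
// batch_dims > 0: v8 keeps the batch dimensions separate, while v5 merges them
// into a single output dimension. The gathered data and its layout are identical,
// so the shared batchDims attribute is sufficient for both opset versions.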
_sliceRank = indicesDims[indicesDims.size() - 1];
_dataRank = dataDims.size() - _batchDims;
if (_sliceRank > _dataRank)
IE_THROW() << _errorPrefix << " has invalid inputs shapes.";
_blockSize = 1;
for (size_t i = _sliceRank + _batchDims; i < dataDims.size(); i++) {
_blockSize *= dataDims[i];
}
_batchStep = 1;
for (size_t i = _batchDims; i < dataDims.size(); i++) {
_batchStep *= dataDims[i];
}
if (attrs.batchDims >= std::min(inputDataRank, indicesDimsRank))
THROW_ERROR << "has invalid batch_dims attribute: " << attrs.batchDims;
}
void MKLDNNGatherNDNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
Precision inDataPrecision = getOriginalInputPrecisionAtPort(_dataIndex);
if (!MKLDNNPlugin::one_of(inDataPrecision.size(),
sizeof(PrecisionTrait<Precision::I32>::value_type),
sizeof(PrecisionTrait<Precision::I16>::value_type),
sizeof(PrecisionTrait<Precision::I8>::value_type))) {
IE_THROW() << _errorPrefix << " has unsupported 'data' input precision: " << inDataPrecision;
}
Precision indicesPrecision = getOriginalInputPrecisionAtPort(_indicesIndex);
Precision inDataPrecision = getOriginalInputPrecisionAtPort(GATHERND_DATA);
Precision indicesPrecision = getOriginalInputPrecisionAtPort(GATHERND_INDEXES);
if (!MKLDNNPlugin::one_of(indicesPrecision,
Precision::I32, Precision::I64, Precision::I16, Precision::U16, Precision::I8, Precision::U8)) {
IE_THROW() << _errorPrefix << " has unsupported 'indices' input precision: " << indicesPrecision;
THROW_ERROR << "has unsupported 'indices' input precision: " << indicesPrecision;
}
_dataTypeSize = inDataPrecision.size();
attrs.dataSize = inDataPrecision.size();
addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision},
{LayoutType::ncsp, Precision::I32}},
@@ -99,121 +74,77 @@ void MKLDNNGatherNDNode::initSupportedPrimitiveDescriptors() {
impl_desc_type::ref_any);
}
template <typename dataType>
void MKLDNNGatherNDNode::gatherElementwise() {
const auto *srcData = reinterpret_cast<const dataType *>(getParentEdgeAt(_dataIndex)->getMemoryPtr()->GetPtr());
const auto *indices = reinterpret_cast<const int *>(getParentEdgeAt(_indicesIndex)->getMemoryPtr()->GetPtr());
auto *dstData = reinterpret_cast<dataType *>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
auto strides = getParentEdgeAt(_dataIndex)->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
const size_t* srcMultipliers = strides.data() + _batchDims;
const size_t cycles = getChildEdgeAt(0)->getMemory().GetShape().getElementsCount() *
getChildEdgeAt(0)->getMemory().getDesc().getPrecision().size() / (sizeof(dataType) * _batchNum);
const size_t CS = cycles * _sliceRank;
const size_t CB = cycles * _blockSize;
const size_t workAmount = _batchNum * cycles;
auto threadBody = [&](const int ithr, const int nthr) {
size_t start(0lu), end(0lu);
splitter(workAmount, nthr, ithr, start, end);
if (start >= end)
return;
size_t bStart = start / cycles;
size_t cStart = start % cycles;
size_t workCounter = start;
const dataType* shiftedSrcData = srcData + bStart * _batchStep;
const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
dataType* shiftedDstData = dstData + bStart * CB + cStart * _blockSize;
for (size_t b = bStart; b < _batchNum; b++) {
for (size_t j = cStart; j < cycles; j++) {
size_t dataIdx = 0lu;
for (size_t i = 0lu; i < _sliceRank; i++)
dataIdx += srcMultipliers[i] * shiftedIndices[i];
shiftedDstData[0] = shiftedSrcData[dataIdx];
shiftedDstData++;
shiftedIndices += _sliceRank;
if (++workCounter == end) {
return;
}
}
cStart = 0lu;
shiftedSrcData += _batchStep;
}
};
parallel_nt(0, threadBody);
void MKLDNNGatherNDNode::createPrimitive() {
if (inputShapesDefined()) {
if (needPrepareParams())
prepareParams();
updateLastInputDims();
}
}
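// Dynamic-shape flow: createPrimitive() builds the executor only when the input
// shapes are already defined (the static case). For dynamic shapes, prepareParams()
// below is expected to be invoked by the base node once the actual dimensions of the
// current inference are known, and it recreates GatherNDExecutor from them.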
void MKLDNNGatherNDNode::gatherBlocks() {
const uint8_t* srcData = reinterpret_cast<const uint8_t *>(getParentEdgeAt(_dataIndex)->getMemoryPtr()->GetPtr());
const int* indices = reinterpret_cast<const int *>(getParentEdgeAt(_indicesIndex)->getMemoryPtr()->GetPtr());
uint8_t* dstData = reinterpret_cast<uint8_t *>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
void MKLDNNGatherNDNode::prepareParams() {
auto& srcMemPtr = getParentEdgeAt(GATHERND_DATA)->getMemoryPtr();
auto& idxMemPtr = getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr();
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
THROW_ERROR << "has not allocated input memory of 'data'.";
if (!idxMemPtr || !idxMemPtr->GetPrimitivePtr())
THROW_ERROR << "has not allocated input memory of 'indices'.";
if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
THROW_ERROR << "has not allocated output memory.";
if (getSelectedPrimitiveDescriptor() == nullptr)
THROW_ERROR << "has unidentified preferable primitive descriptor.";
std::vector<size_t> srcMultipliers(_sliceRank);
for (size_t i = 0; i < _sliceRank ; i++)
srcMultipliers[i] = _dataTypeSize * getParentEdgeAt(_dataIndex)->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides()[i + _batchDims];
attrs.srcDims = srcMemPtr->getStaticDims();
attrs.srcStrides = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getStrides();
attrs.dstSize = dstMemPtr->GetSize();
attrs.sliceRank = idxMemPtr->getStaticDims().back();
execPtr = std::make_shared<GatherNDExecutor>(attrs);
}
const size_t batchStep = _batchStep * _dataTypeSize;
const size_t dataStep = _blockSize * _dataTypeSize;
const size_t cycles = getChildEdgeAt(0)->getMemory().GetSize() / (dataStep * _batchNum);
const size_t CS = cycles * _sliceRank;
const size_t CB = cycles * dataStep;
const size_t workAmount = _batchNum * cycles;
MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : attrs(attrs) {
batchSize = std::accumulate(attrs.srcDims.begin(), attrs.srcDims.begin() + attrs.batchDims, 1lu, std::multiplies<size_t>());
dataLength = std::accumulate(attrs.srcDims.begin() + attrs.sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu,
std::multiplies<size_t>()) * attrs.dataSize;
cycles = attrs.dstSize / (dataLength * batchSize);
auto threadBody = [&](const int ithr, const int nthr) {
size_t start(0lu), end(0lu);
splitter(workAmount, nthr, ithr, start, end);
if (start >= end)
return;
size_t bStart = start / cycles;
size_t cStart = start % cycles;
size_t workCounter = start;
srcBatchStride = std::accumulate(attrs.srcDims.begin() + attrs.batchDims, attrs.srcDims.end(), 1lu,
std::multiplies<size_t>()) * attrs.dataSize;
idxBatchStride = cycles * attrs.sliceRank;
dstBatchStride = cycles * dataLength;
const uint8_t* shiftedSrcData = srcData + bStart * batchStep;
const int* shiftedIndices = indices + bStart * CS + cStart * _sliceRank;
uint8_t* shiftedDstData = dstData + bStart * CB + cStart * dataStep;
srcShifts.resize(attrs.sliceRank, 0);
for (size_t i = 0; i < attrs.sliceRank ; i++)
srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * attrs.dataSize;
}
for (size_t b = bStart; b < _batchNum; b++) {
for (size_t j = cStart; j < cycles; j++) {
size_t dataIdx = 0lu;
for (size_t i = 0; i < _sliceRank ; i++)
dataIdx += srcMultipliers[i] * shiftedIndices[i];
cpu_memcpy(shiftedDstData, &(shiftedSrcData[dataIdx]), dataStep);
shiftedDstData += dataStep;
shiftedIndices += _sliceRank;
if (++workCounter == end) {
return;
}
}
cStart = 0;
shiftedSrcData += batchStep;
}
};
void MKLDNNGatherNDNode::GatherNDExecutor::exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData) {
parallel_for2d(batchSize, cycles, [&](const size_t b, const size_t j) {
const size_t srcStride = b * srcBatchStride;
const size_t idxStride = b * idxBatchStride + j * attrs.sliceRank;
const size_t dstStride = b * dstBatchStride + j * dataLength;
parallel_nt(0, threadBody);
size_t dataIdx = 0lu;
for (size_t i = 0; i < attrs.sliceRank ; ++i)
dataIdx += srcShifts[i] * indices[idxStride + i];
cpu_memcpy(&dstData[dstStride], &srcData[srcStride + dataIdx], dataLength);
});
}
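// The refactored executor replaces the former gatherElementwise()/gatherBlocks()
// hand-rolled splitter loops with a single parallel_for2d over (batch, cycle):
// each work item resolves the slice offset from the precomputed srcShifts and
// copies dataLength bytes with cpu_memcpy.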
void MKLDNNGatherNDNode::execute(mkldnn::stream strm) {
if (_blockSize > 1) {
gatherBlocks();
} else {
switch (_dataTypeSize) {
case sizeof(PrecisionTrait<Precision::I32>::value_type):
gatherElementwise<PrecisionTrait<Precision::I32>::value_type>();
break;
case sizeof(PrecisionTrait<Precision::I16>::value_type):
gatherElementwise<PrecisionTrait<Precision::I16>::value_type>();
break;
case sizeof(PrecisionTrait<Precision::I8>::value_type):
gatherElementwise<PrecisionTrait<Precision::I8>::value_type>();
break;
default:
IE_THROW() << _errorPrefix + " has data input with unsupported precision: " + getOriginalInputPrecisionAtPort(_dataIndex).name();
}
}
if (!execPtr)
THROW_ERROR << "has not compiled executor.";
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr()->GetPtr());
const int32_t* indices = reinterpret_cast<const int32_t*>(getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr()->GetPtr());
uint8_t* dstData = reinterpret_cast<uint8_t*>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
execPtr->exec(srcData, indices, dstData);
}
void MKLDNNGatherNDNode::executeDynamicImpl(mkldnn::stream strm) {
execute(strm);
}
bool MKLDNNGatherNDNode::created() const {


@@ -18,27 +18,50 @@ public:
void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override;
void createPrimitive() override {};
void createPrimitive() override;
void execute(mkldnn::stream strm) override;
bool created() const override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
private:
size_t _dataRank;
size_t _sliceRank;
size_t _blockSize;
size_t _batchDims;
size_t _batchNum;
size_t _batchStep;
size_t _dataTypeSize;
const size_t _dataIndex = 0;
const size_t _indicesIndex = 1;
std::string _errorPrefix;
protected:
void executeDynamicImpl(mkldnn::stream strm) override;
void prepareParams() override;
template <typename dataType>
void gatherElementwise();
void gatherBlocks();
private:
struct GatherNDAttributes {
size_t batchDims = 0lu;
size_t dataSize = 1lu;
size_t dstSize = 0lu;
size_t sliceRank = 0lu;
VectorDims srcDims;
VectorDims srcStrides;
} attrs;
struct GatherNDExecutor {
GatherNDExecutor(const GatherNDAttributes& attrs);
~GatherNDExecutor() = default;
void exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData);
private:
size_t batchSize = 1lu;
size_t cycles = 1lu;
size_t dataLength = 1lu;
size_t srcBatchStride = 1lu;
size_t idxBatchStride = 1lu;
size_t dstBatchStride = 1lu;
VectorDims srcShifts;
GatherNDAttributes attrs;
};
static constexpr size_t GATHERND_DATA = 0lu;
static constexpr size_t GATHERND_INDEXES = 1lu;
using executorPtr = std::shared_ptr<GatherNDExecutor>;
executorPtr execPtr = nullptr;
};
} // namespace MKLDNNPlugin


@@ -31,7 +31,8 @@ const auto gatherNDArgsSubset1 = ::testing::Combine(
{{2, 1}, {2, 1, 1}})), // Indices shape
::testing::ValuesIn(std::vector<int>({0, 1})) // Batch dims
);
INSTANTIATE_TEST_SUITE_P(smoke_Set1, GatherNDLayerTest,
INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set1, GatherNDLayerTest,
::testing::Combine(
gatherNDArgsSubset1,
::testing::ValuesIn(dPrecisions),
@@ -40,6 +41,15 @@ INSTANTIATE_TEST_SUITE_P(smoke_Set1, GatherNDLayerTest,
::testing::Values<Config>({})),
GatherNDLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherNDLayerTest,
::testing::Combine(
gatherNDArgsSubset1,
::testing::ValuesIn(dPrecisions),
::testing::ValuesIn(iPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values<Config>({})),
GatherNDLayerTest::getTestCaseName);
const auto gatherNDArgsSubset2 = ::testing::Combine(
::testing::ValuesIn(std::vector<std::vector<size_t>>(
{{15, 12, 20, 15, 2}, {15, 12, 18, 7, 17}})), // Data shape
@@ -47,7 +57,8 @@ const auto gatherNDArgsSubset2 = ::testing::Combine(
{{15, 12, 2}, {15, 12, 5, 9, 1, 3}})), // Indices shape
::testing::ValuesIn(std::vector<int>({0, 1, 2})) // Batch dims
);
INSTANTIATE_TEST_SUITE_P(smoke_Set2, GatherNDLayerTest,
INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set2, GatherNDLayerTest,
::testing::Combine(
gatherNDArgsSubset2,
::testing::ValuesIn(dPrecisions),
@@ -55,4 +66,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_Set2, GatherNDLayerTest,
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values<Config>({})),
GatherNDLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherNDLayerTest,
::testing::Combine(
gatherNDArgsSubset2,
::testing::ValuesIn(dPrecisions),
::testing::ValuesIn(iPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values<Config>({})),
GatherNDLayerTest::getTestCaseName);
} // namespace


@@ -0,0 +1,193 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <shared_test_classes/single_layer/gather_nd.hpp>
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ngraph_functions/builders.hpp"
using namespace InferenceEngine;
using namespace ov;
using namespace test;
namespace CPULayerTestsDefinitions {
using GatherNDLayerCPUTestParamSet = std::tuple<
InputShape, // Input shapes
std::pair<Shape, std::vector<int>>, // Indexes shape and values
ElementType, // Input element type
ElementType, // Indices element type
int // Batch dims
>;
class GatherNDLayerCPUTest : public testing::WithParamInterface<GatherNDLayerCPUTestParamSet>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(testing::TestParamInfo<GatherNDLayerCPUTestParamSet> obj) {
InputShape shapes;
std::pair<Shape, std::vector<int>> indexes;
ElementType dataElementType, idxElementType;
int batchDims;
std::tie(shapes, indexes, dataElementType, idxElementType, batchDims) = obj.param;
std::ostringstream results;
results << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_";
results << "TS=";
for (const auto& item : shapes.second) {
results << CommonTestUtils::vec2str(item) << "_";
}
results << "IDXShape=" << CommonTestUtils::vec2str(indexes.first) << "_";
results << "SRCPrc=" << dataElementType << "_";
results << "IDXPrc=" << idxElementType << "_";
results << "BD=" << batchDims << "_";
return results.str();
}
protected:
void SetUp() override {
InputShape shapes;
std::pair<Shape, std::vector<int>> indexes;
ElementType dataElementType, idxElementType;
int batchDims;
std::tie(shapes, indexes, dataElementType, idxElementType, batchDims) = this->GetParam();
targetDevice = CommonTestUtils::DEVICE_CPU;
init_input_shapes({shapes});
auto params = ngraph::builder::makeDynamicParams(dataElementType, inputDynamicShapes);
auto indexes_node = ngraph::opset3::Constant::create(idxElementType, indexes.first, indexes.second);
auto gather_nd = std::make_shared<ngraph::opset5::GatherND>(params[0], indexes_node, batchDims);
ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(gather_nd)};
function = std::make_shared<ngraph::Function>(results, params, "gatherND");
}
};
class GatherND8LayerCPUTest : public testing::WithParamInterface<GatherNDLayerCPUTestParamSet>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(testing::TestParamInfo<GatherNDLayerCPUTestParamSet> obj) {
return GatherNDLayerCPUTest::getTestCaseName(obj);
}
protected:
void SetUp() override {
InputShape shapes;
std::pair<Shape, std::vector<int>> indexes;
ElementType dataElementType, idxElementType;
int batchDims;
std::tie(shapes, indexes, dataElementType, idxElementType, batchDims) = this->GetParam();
targetDevice = CommonTestUtils::DEVICE_CPU;
init_input_shapes({shapes});
auto params = ngraph::builder::makeDynamicParams(dataElementType, inputDynamicShapes);
auto indexes_node = ngraph::opset3::Constant::create(idxElementType, indexes.first, indexes.second);
auto gather_nd = std::make_shared<ngraph::opset8::GatherND>(params[0], indexes_node, batchDims);
ngraph::ResultVector results{std::make_shared<ngraph::opset3::Result>(gather_nd)};
function = std::make_shared<ngraph::Function>(results, params, "gatherND");
}
};
TEST_P(GatherNDLayerCPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
TEST_P(GatherND8LayerCPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
namespace {
const std::vector<ElementType> inputPrecisions = {
ElementType::f32,
ElementType::bf16,
ElementType::i8
};
const std::vector<ElementType> indexesPrecisions = {
ElementType::i32
};
const std::vector<InputShape> inputShapesDynamicBD_0 = {
{{-1, -1, -1}, // dynamic
{{5, 10, 5}, {4, 12, 4}, {4, 12, 4}, {5, 5, 5}}}, // target
{{-1, 5, -1, -1}, // dynamic
{{8, 5, 5, 5}, {5, 5, 8, 4}, {4, 5, 4, 5}}}, // target
{{{4, 10}, {5, 10}, {5, 10}, {5, 10}, {5, 10}}, // dynamic
{{4, 5, 5, 5, 5}, {4, 5, 5, 8, 5}, {10, 8, 5, 5, 5}}}, // target
};
const std::vector<std::pair<Shape, std::vector<int>>> indexesShapesBD_0 = {
std::pair<Shape, std::vector<int>>{{2, 2}, {3, 3, 2, 1}},
std::pair<Shape, std::vector<int>>{{1, 2, 3}, {0, 1, 1, 1, 0, 2}},
std::pair<Shape, std::vector<int>>{{2, 1, 1, 2}, {0, 2, 1, 1}},
};
const auto subset_BD0 = ::testing::Combine(
::testing::ValuesIn(inputShapesDynamicBD_0),
::testing::ValuesIn(indexesShapesBD_0),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(indexesPrecisions),
::testing::Values(0));
INSTANTIATE_TEST_SUITE_P(smoke_GatherND5DynamicBD_0, GatherNDLayerCPUTest, subset_BD0, GatherNDLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8DynamicBD_0, GatherND8LayerCPUTest, subset_BD0, GatherNDLayerCPUTest::getTestCaseName);
const std::vector<InputShape> inputShapesDynamicBD_1 = {
{{3, -1, -1}, // dynamic
{{3, 10, 5}, {3, 10, 5}, {3, 12, 8}, {3, 8, 8}}}, // target
{{3, {5, 10}, {5, 10}, {5, 10}, {5, 10}}, // dynamic
{{3, 5, 5, 5, 5}, {3, 8, 10, 10, 10}, {3, 8, 6, 8, 7}}}, // target
};
const std::vector<std::pair<Shape, std::vector<int>>> indexesShapesBD_1 = {
std::pair<Shape, std::vector<int>>{{3, 2}, {0, 1, 2, 1, 0, 0}},
std::pair<Shape, std::vector<int>>{{3, 2, 2}, {0, 1, 1, 1, 0, 2, 0, 1, 1, 1, 0, 2}},
std::pair<Shape, std::vector<int>>{{3, 1, 1, 2}, {0, 2, 1, 1, 0, 2}},
};
const auto subset_BD1 = ::testing::Combine(
::testing::ValuesIn(inputShapesDynamicBD_1),
::testing::ValuesIn(indexesShapesBD_1),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(indexesPrecisions),
::testing::Values(1));
INSTANTIATE_TEST_SUITE_P(smoke_GatherND5DynamicBD_1, GatherNDLayerCPUTest, subset_BD1, GatherNDLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8DynamicBD_1, GatherND8LayerCPUTest, subset_BD1, GatherNDLayerCPUTest::getTestCaseName);
const std::vector<InputShape> inputShapesDynamicBD_2 = {
{{2, 2, -1, -1, -1}, // dynamic
{{2, 2, 5, 6, 5}, {2, 2, 2, 3, 3}, {2, 2, 2, 3, 3}, {2, 2, 7, 2, 3}}}, // target
{{2, 2, {5, 10}, {5, 10}, {5, 10}}, // dynamic
{{2, 2, 5, 5, 5}, {2, 2, 10, 10, 5}, {2, 2, 7, 8, 7}}}, // target
};
const std::vector<std::pair<Shape, std::vector<int>>> indexesShapesBD_2 = {
std::pair<Shape, std::vector<int>>{{2, 2, 3}, {0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0}},
std::pair<Shape, std::vector<int>>{{2, 2, 2, 3}, {0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0}},
};
const auto subset_BD2 = ::testing::Combine(
::testing::ValuesIn(inputShapesDynamicBD_2),
::testing::ValuesIn(indexesShapesBD_2),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(indexesPrecisions),
::testing::Values(2));
INSTANTIATE_TEST_SUITE_P(smoke_GatherND5DynamicBD_2, GatherNDLayerCPUTest, subset_BD2, GatherNDLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8DynamicBD_2, GatherND8LayerCPUTest, subset_BD2, GatherNDLayerCPUTest::getTestCaseName);
} // namespace
} // namespace CPULayerTestsDefinitions


@@ -12,4 +12,8 @@ TEST_P(GatherNDLayerTest, CompareWithRefs) {
Run();
}
TEST_P(GatherND8LayerTest, CompareWithRefs) {
Run();
}
} // namespace LayerTestsDefinitions


@@ -36,4 +36,13 @@ protected:
void SetUp() override;
};
class GatherND8LayerTest : public testing::WithParamInterface<GatherNDParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<GatherNDParams> &obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions


@@ -54,4 +54,30 @@ void GatherNDLayerTest::SetUp() {
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gather)};
function = std::make_shared<ngraph::Function>(results, params, "gatherND");
}
std::string GatherND8LayerTest::getTestCaseName(const testing::TestParamInfo<GatherNDParams>& obj) {
return GatherNDLayerTest::getTestCaseName(obj);
}
void GatherND8LayerTest::SetUp() {
InferenceEngine::SizeVector dataShape, indicesShape;
InferenceEngine::Precision dPrecision, iPrecision;
int batchDims;
GatherNDParamsSubset gatherArgsSubset;
std::tie(gatherArgsSubset, dPrecision, iPrecision, targetDevice, configuration) = this->GetParam();
std::tie(dataShape, indicesShape, batchDims) = gatherArgsSubset;
auto ngDPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(dPrecision);
auto ngIPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(iPrecision);
auto params = ngraph::builder::makeParams(ngDPrc, {dataShape});
auto paramOuts = ngraph::helpers::convert2OutputVector(
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
auto dataNode = paramOuts[0];
auto gather = std::dynamic_pointer_cast<ngraph::opset8::GatherND>(
ngraph::builder::makeGatherND(dataNode, indicesShape, ngIPrc, batchDims));
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gather)};
function = std::make_shared<ngraph::Function>(results, params, "gatherND");
}
} // namespace LayerTestsDefinitions