[CPU] Fixed leftovers for GatherND (#8644)

Alexandra Sidorova 2021-11-23 15:11:58 +03:00 committed by GitHub
parent cfe33fdf08
commit 8c55c761c4
3 changed files with 138 additions and 34 deletions


@@ -41,8 +41,8 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr<ngraph::Node>& op,
if (inputShapes.size() != 2 && outputShapes.size() != 1)
THROW_ERROR << "has invalid number of input/output edges.";
-const size_t inputDataRank = getInputShapeAtPort(GATHERND_DATA).getRank();
-const size_t indicesDimsRank = getInputShapeAtPort(GATHERND_INDEXES).getRank();
+const size_t dataInputRank = getInputShapeAtPort(GATHERND_DATA).getRank();
+const size_t indicesInputRank = getInputShapeAtPort(GATHERND_INDEXES).getRank();
if (auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v8::GatherND>(op)) {
attrs.batchDims = gatherNdOp->get_batch_dims();
@@ -51,7 +51,7 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr<ngraph::Node>& op,
} else {
THROW_ERROR << "has support only opset5.";
}
-if (attrs.batchDims >= std::min(inputDataRank, indicesDimsRank))
+if (attrs.batchDims >= std::min(dataInputRank, indicesInputRank))
THROW_ERROR << "has invalid batch_dims attribute: " << attrs.batchDims;
}
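For reference, the guard above enforces the GatherND requirement that batch_dims be strictly smaller than both input ranks. A minimal standalone sketch of the same check (names are illustrative, not plugin code):

    #include <algorithm>
    #include <cstddef>
    #include <stdexcept>

    // With dataRank = 4 and indicesRank = 3, any batch_dims >= 3 is rejected here,
    // mirroring the std::min comparison in the node constructor above.
    inline void validateBatchDims(std::size_t batchDims, std::size_t dataRank, std::size_t indicesRank) {
        if (batchDims >= std::min(dataRank, indicesRank))
            throw std::runtime_error("GatherND: invalid batch_dims attribute");
    }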
@@ -60,12 +60,19 @@ void MKLDNNGatherNDNode::initSupportedPrimitiveDescriptors() {
return;
Precision inDataPrecision = getOriginalInputPrecisionAtPort(GATHERND_DATA);
+if (!MKLDNNPlugin::one_of(inDataPrecision.size(),
+sizeof(PrecisionTrait<Precision::I32>::value_type),
+sizeof(PrecisionTrait<Precision::I16>::value_type),
+sizeof(PrecisionTrait<Precision::I8>::value_type))) {
+THROW_ERROR << "has unsupported 'data' input precision: " << inDataPrecision;
+}
+attrs.dataSize = inDataPrecision.size();
Precision indicesPrecision = getOriginalInputPrecisionAtPort(GATHERND_INDEXES);
if (!MKLDNNPlugin::one_of(indicesPrecision,
Precision::I32, Precision::I64, Precision::I16, Precision::U16, Precision::I8, Precision::U8)) {
THROW_ERROR << "has unsupported 'indices' input precision: " << indicesPrecision;
}
-attrs.dataSize = inDataPrecision.size();
addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision},
{LayoutType::ncsp, Precision::I32}},
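The new 'data' check above deliberately validates only the element width: GatherND never does arithmetic on the data, it only relocates elements, so all 4-, 2- and 1-byte precisions can share one code path (e.g. FP32 and I32 both take the 4-byte path). A small sanity sketch using the same PrecisionTrait idiom (assumes the node's usual headers are in scope):

    static_assert(sizeof(PrecisionTrait<Precision::I32>::value_type) == 4, "4-byte path, e.g. FP32/I32");
    static_assert(sizeof(PrecisionTrait<Precision::I16>::value_type) == 2, "2-byte path, e.g. FP16/BF16/I16");
    static_assert(sizeof(PrecisionTrait<Precision::I8>::value_type) == 1, "1-byte path, e.g. I8/U8");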
@@ -96,50 +103,128 @@ void MKLDNNGatherNDNode::prepareParams() {
attrs.srcDims = srcMemPtr->getStaticDims();
attrs.srcStrides = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getStrides();
-attrs.dstSize = dstMemPtr->GetSize();
+attrs.dstElementCount = dstMemPtr->GetShape().getElementsCount();
attrs.sliceRank = idxMemPtr->getStaticDims().back();
execPtr = std::make_shared<GatherNDExecutor>(attrs);
}
-MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : attrs(attrs) {
+MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : dataSize(attrs.dataSize), sliceRank(attrs.sliceRank) {
batchSize = std::accumulate(attrs.srcDims.begin(), attrs.srcDims.begin() + attrs.batchDims, 1lu, std::multiplies<size_t>());
-dataLength = std::accumulate(attrs.srcDims.begin() + attrs.sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu,
-std::multiplies<size_t>()) * attrs.dataSize;
-cycles = attrs.dstSize / (dataLength * batchSize);
+dataLength = std::accumulate(attrs.srcDims.begin() + sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu,
+std::multiplies<size_t>());
+cycles = attrs.dstElementCount / (dataLength * batchSize);
+workAmount = batchSize * cycles;
srcBatchStride = std::accumulate(attrs.srcDims.begin() + attrs.batchDims, attrs.srcDims.end(), 1lu,
-std::multiplies<size_t>()) * attrs.dataSize;
-idxBatchStride = cycles * attrs.sliceRank;
+std::multiplies<size_t>());
+idxBatchStride = cycles * sliceRank;
dstBatchStride = cycles * dataLength;
srcShifts.resize(attrs.sliceRank, 0);
for (size_t i = 0; i < attrs.sliceRank ; i++)
-srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * attrs.dataSize;
-}
+srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * (dataLength > 1 ? dataSize : 1);
-void MKLDNNGatherNDNode::GatherNDExecutor::exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData) {
-parallel_for2d(batchSize, cycles, [&](const size_t b, const size_t j) {
-const size_t srcStride = b * srcBatchStride;
-const size_t idxStride = b * idxBatchStride + j * attrs.sliceRank;
-const size_t dstStride = b * dstBatchStride + j * dataLength;
-size_t dataIdx = 0lu;
-for (size_t i = 0; i < attrs.sliceRank ; ++i)
-dataIdx += srcShifts[i] * indices[idxStride + i];
-cpu_memcpy(&dstData[dstStride], &srcData[srcStride + dataIdx], dataLength);
-});
+// optimized implementation 'blocks' via memcpy
+if (dataLength > 1) {
+dataLength *= dataSize;
+srcBatchStride *= dataSize;
+dstBatchStride *= dataSize;
+}
}
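To make the derived quantities above concrete, here is a worked example with illustrative shapes (not taken from the tests): data {2, 3, 4, 5}, indices {2, 3, 2}, batch_dims = 1, so sliceRank = 2 and the output shape is {2, 3, 5}:

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
        const std::vector<std::size_t> srcDims{2, 3, 4, 5};
        const std::size_t batchDims = 1, sliceRank = 2, dstElementCount = 2 * 3 * 5;

        // Same arithmetic as the executor constructor above.
        const std::size_t batchSize = std::accumulate(srcDims.begin(), srcDims.begin() + batchDims,
                                                      std::size_t(1), std::multiplies<std::size_t>());
        const std::size_t dataLength = std::accumulate(srcDims.begin() + sliceRank + batchDims, srcDims.end(),
                                                       std::size_t(1), std::multiplies<std::size_t>());
        const std::size_t cycles = dstElementCount / (dataLength * batchSize);

        assert(batchSize == 2 && dataLength == 5 && cycles == 3);  // workAmount = batchSize * cycles = 6
        return 0;
    }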
void MKLDNNGatherNDNode::execute(mkldnn::stream strm) {
if (!execPtr)
THROW_ERROR << "has not compiled executor.";
-const uint8_t* srcData = reinterpret_cast<const uint8_t*>(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr()->GetPtr());
-const int32_t* indices = reinterpret_cast<const int32_t*>(getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr()->GetPtr());
-uint8_t* dstData = reinterpret_cast<uint8_t*>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
+execPtr->exec(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr(),
+getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr(),
+getChildEdgeAt(0)->getMemoryPtr());
+}
-execPtr->exec(srcData, indices, dstData);
+void MKLDNNGatherNDNode::GatherNDExecutor::exec(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
+if (dataLength > 1) {
+gatherBlocks(srcMemPtr, idxMemPtr, dstMemPtr);
+return;
+}
+GatherNDContext ctx { this, srcMemPtr, idxMemPtr, dstMemPtr };
+OV_SWITCH(MKLDNNPlugin, GatherNDEmitter, ctx, dataSize,
+OV_CASE(sizeof(PrecisionTrait<Precision::I32>::value_type), PrecisionTrait<Precision::I32>::value_type),
+OV_CASE(sizeof(PrecisionTrait<Precision::I16>::value_type), PrecisionTrait<Precision::I16>::value_type),
+OV_CASE(sizeof(PrecisionTrait<Precision::I8>::value_type), PrecisionTrait<Precision::I8>::value_type));
}
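For readers unfamiliar with the OV_SWITCH/OV_CASE machinery used above: it roughly matches the runtime dataSize against each case key and invokes the GatherNDEmitter functor instantiated with the paired type. A self-contained, plain-C++ analogue of that dispatch (the helper names are made up for illustration, not the macro's actual expansion):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    template <typename T>
    void gatherOneElementType() { std::printf("gather with %zu-byte elements\n", sizeof(T)); }

    void dispatchBySize(std::size_t dataSize) {
        switch (dataSize) {
            case sizeof(std::int32_t): gatherOneElementType<std::int32_t>(); break;  // FP32 / I32 / ...
            case sizeof(std::int16_t): gatherOneElementType<std::int16_t>(); break;  // FP16 / BF16 / I16 / ...
            case sizeof(std::int8_t):  gatherOneElementType<std::int8_t>();  break;  // I8 / U8
            default: break;  // other widths were already rejected in initSupportedPrimitiveDescriptors()
        }
    }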
+void MKLDNNGatherNDNode::GatherNDExecutor::gatherBlocks(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
+const uint8_t* srcData = reinterpret_cast<const uint8_t*>(srcMemPtr->GetPtr());
+const int32_t* indices = reinterpret_cast<const int32_t*>(idxMemPtr->GetPtr());
+uint8_t* dstData = reinterpret_cast<uint8_t*>(dstMemPtr->GetPtr());
+parallel_nt(0, [&](const int ithr, const int nthr) {
+size_t start(0lu), end(0lu);
+splitter(workAmount, nthr, ithr, start, end);
+if (start >= end)
+return;
+size_t bStart = start / cycles;
+size_t cStart = start % cycles;
+size_t workCounter = start;
+const uint8_t* shiftedSrcData = srcData + bStart * srcBatchStride;
+const int32_t* shiftedIndices = indices + bStart * idxBatchStride + cStart * sliceRank;
+uint8_t* shiftedDstData = dstData + bStart * dstBatchStride + cStart * dataLength;
+for (size_t b = bStart; b < batchSize; b++) {
+for (size_t j = cStart; j < cycles; j++) {
+size_t dataIdx = 0lu;
+for (size_t i = 0; i < sliceRank; i++)
+dataIdx += srcShifts[i] * shiftedIndices[i];
+cpu_memcpy(shiftedDstData, &(shiftedSrcData[dataIdx]), dataLength);
+shiftedDstData += dataLength;
+shiftedIndices += sliceRank;
+if (++workCounter == end) {
+return;
+}
+}
+cStart = 0;
+shiftedSrcData += srcBatchStride;
+}
+});
+}
+template <typename dataType>
+void MKLDNNGatherNDNode::GatherNDExecutor::gatherElementwise(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
+const dataType* srcData = reinterpret_cast<const dataType*>(srcMemPtr->GetPtr());
+const int32_t* indices = reinterpret_cast<const int32_t*>(idxMemPtr->GetPtr());
+dataType* dstData = reinterpret_cast<dataType*>(dstMemPtr->GetPtr());
+parallel_nt(0, [&](const int ithr, const int nthr) {
+size_t start(0lu), end(0lu);
+splitter(workAmount, nthr, ithr, start, end);
+if (start >= end)
+return;
+size_t bStart = start / cycles;
+size_t cStart = start % cycles;
+size_t workCounter = start;
+const dataType* shiftedSrcData = srcData + bStart * srcBatchStride;
+const int32_t* shiftedIndices = indices + bStart * idxBatchStride + cStart * sliceRank;
+dataType* shiftedDstData = dstData + bStart * dstBatchStride + cStart * dataLength;
+for (size_t b = bStart; b < batchSize; b++) {
+for (size_t j = cStart; j < cycles; j++) {
+size_t dataIdx = 0lu;
+for (size_t i = 0lu; i < sliceRank; i++)
+dataIdx += srcShifts[i] * shiftedIndices[i];
+shiftedDstData[0] = shiftedSrcData[dataIdx];
+shiftedDstData++;
+shiftedIndices += sliceRank;
+if (++workCounter == end) {
+return;
+}
+}
+cStart = 0lu;
+shiftedSrcData += srcBatchStride;
+}
+});
+}
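Both new kernels parallelize over the flat slice count (workAmount = batchSize * cycles) and recover the starting (batch, cycle) position from the per-thread offset returned by splitter(). A tiny self-contained illustration of that index math, reusing the numbers from the earlier example:

    #include <cassert>
    #include <cstddef>

    int main() {
        const std::size_t cycles = 3;               // slices per batch
        const std::size_t start  = 4;               // flat offset some thread might receive from splitter()
        const std::size_t bStart = start / cycles;  // = 1 -> resume in the second batch
        const std::size_t cStart = start % cycles;  // = 1 -> at its second slice within that batch
        assert(bStart == 1 && cStart == 1);
        return 0;
    }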
void MKLDNNGatherNDNode::executeDynamicImpl(dnnl::stream strm) {


@@ -32,7 +32,7 @@ private:
struct GatherNDAttributes {
size_t batchDims = 0lu;
size_t dataSize = 1lu;
-size_t dstSize = 0lu;
+size_t dstElementCount = 0lu;
size_t sliceRank = 0lu;
VectorDims srcDims;
@@ -42,19 +42,38 @@ private:
struct GatherNDExecutor {
GatherNDExecutor(const GatherNDAttributes& attrs);
~GatherNDExecutor() = default;
-void exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData);
+void exec(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr);
private:
+template <typename dataType>
+void gatherElementwise(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr);
+void gatherBlocks(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr);
size_t batchSize = 1lu;
size_t cycles = 1lu;
size_t dataLength = 1lu;
+size_t sliceRank = 0lu;
+size_t workAmount = 0lu;
+size_t dataSize = 1lu;
size_t srcBatchStride = 1lu;
size_t idxBatchStride = 1lu;
size_t dstBatchStride = 1lu;
VectorDims srcShifts;
-GatherNDAttributes attrs;
+struct GatherNDContext {
+GatherNDExecutor* executor;
+const MKLDNNMemoryPtr srcMemPtr;
+const MKLDNNMemoryPtr idxMemPtr;
+MKLDNNMemoryPtr dstMemPtr;
+};
+template<typename T>
+struct GatherNDEmitter {
+void operator()(GatherNDContext& ctx) {
+ctx.executor->gatherElementwise<T>(ctx.srcMemPtr, ctx.idxMemPtr, ctx.dstMemPtr);
+}
+};
};
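The two helper types above support the size-keyed dispatch: GatherNDContext bundles the executor and the three memory pointers into the single argument the OV_SWITCH functor receives, and GatherNDEmitter<T> forwards them to gatherElementwise<T>. A self-contained sketch of the same pattern with illustrative names:

    #include <cstdint>
    #include <cstdio>

    struct Context { int sliceCount; };          // stands in for GatherNDContext

    template <typename T>
    struct Emitter {                             // stands in for GatherNDEmitter<T>
        void operator()(Context& ctx) {
            std::printf("gather %d slices of %zu-byte elements\n", ctx.sliceCount, sizeof(T));
        }
    };

    int main() {
        Context ctx{6};
        Emitter<std::int32_t>{}(ctx);            // the real code picks T via OV_SWITCH on dataSize
        return 0;
    }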
static constexpr size_t GATHERND_DATA = 0lu;


@@ -41,7 +41,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set1, GatherNDLayerTest,
::testing::Values<Config>({})),
GatherNDLayerTest::getTestCaseName);
-INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherNDLayerTest,
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherND8LayerTest,
::testing::Combine(
gatherNDArgsSubset1,
::testing::ValuesIn(dPrecisions),
@@ -67,7 +67,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set2, GatherNDLayerTest,
::testing::Values<Config>({})),
GatherNDLayerTest::getTestCaseName);
-INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherNDLayerTest,
+INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherND8LayerTest,
::testing::Combine(
gatherNDArgsSubset2,
::testing::ValuesIn(dPrecisions),