[CPU] Fixed leftovers for GatherND (#8644)
This commit is contained in:
parent
cfe33fdf08
commit
8c55c761c4
@ -41,8 +41,8 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr<ngraph::Node>& op,
|
||||
if (inputShapes.size() != 2 && outputShapes.size() != 1)
|
||||
THROW_ERROR << "has invalid number of input/output edges.";
|
||||
|
||||
const size_t inputDataRank = getInputShapeAtPort(GATHERND_DATA).getRank();
|
||||
const size_t indicesDimsRank = getInputShapeAtPort(GATHERND_INDEXES).getRank();
|
||||
const size_t dataInputRank = getInputShapeAtPort(GATHERND_DATA).getRank();
|
||||
const size_t indicesInputRank = getInputShapeAtPort(GATHERND_INDEXES).getRank();
|
||||
|
||||
if (auto gatherNdOp = ngraph::as_type_ptr<const ngraph::op::v8::GatherND>(op)) {
|
||||
attrs.batchDims = gatherNdOp->get_batch_dims();
|
||||
@ -51,7 +51,7 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr<ngraph::Node>& op,
|
||||
} else {
|
||||
THROW_ERROR << "has support only opset5.";
|
||||
}
|
||||
if (attrs.batchDims >= std::min(inputDataRank, indicesDimsRank))
|
||||
if (attrs.batchDims >= std::min(dataInputRank, indicesInputRank))
|
||||
THROW_ERROR << "has invalid batch_dims attribute: " << attrs.batchDims;
|
||||
}
|
||||
|
||||
@ -60,12 +60,19 @@ void MKLDNNGatherNDNode::initSupportedPrimitiveDescriptors() {
|
||||
return;
|
||||
|
||||
Precision inDataPrecision = getOriginalInputPrecisionAtPort(GATHERND_DATA);
|
||||
if (!MKLDNNPlugin::one_of(inDataPrecision.size(),
|
||||
sizeof(PrecisionTrait<Precision::I32>::value_type),
|
||||
sizeof(PrecisionTrait<Precision::I16>::value_type),
|
||||
sizeof(PrecisionTrait<Precision::I8>::value_type))) {
|
||||
THROW_ERROR << "has unsupported 'data' input precision: " << inDataPrecision;
|
||||
}
|
||||
attrs.dataSize = inDataPrecision.size();
|
||||
|
||||
Precision indicesPrecision = getOriginalInputPrecisionAtPort(GATHERND_INDEXES);
|
||||
if (!MKLDNNPlugin::one_of(indicesPrecision,
|
||||
Precision::I32, Precision::I64, Precision::I16, Precision::U16, Precision::I8, Precision::U8)) {
|
||||
THROW_ERROR << "has unsupported 'indices' input precision: " << indicesPrecision;
|
||||
}
|
||||
attrs.dataSize = inDataPrecision.size();
|
||||
|
||||
addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision},
|
||||
{LayoutType::ncsp, Precision::I32}},
|
||||
@ -96,50 +103,128 @@ void MKLDNNGatherNDNode::prepareParams() {
|
||||
|
||||
attrs.srcDims = srcMemPtr->getStaticDims();
|
||||
attrs.srcStrides = srcMemPtr->GetDescWithType<BlockedMemoryDesc>()->getStrides();
|
||||
attrs.dstSize = dstMemPtr->GetSize();
|
||||
attrs.dstElementCount = dstMemPtr->GetShape().getElementsCount();
|
||||
attrs.sliceRank = idxMemPtr->getStaticDims().back();
|
||||
execPtr = std::make_shared<GatherNDExecutor>(attrs);
|
||||
}
|
||||
|
||||
MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : attrs(attrs) {
|
||||
MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : dataSize(attrs.dataSize), sliceRank(attrs.sliceRank) {
|
||||
batchSize = std::accumulate(attrs.srcDims.begin(), attrs.srcDims.begin() + attrs.batchDims, 1lu, std::multiplies<size_t>());
|
||||
dataLength = std::accumulate(attrs.srcDims.begin() + attrs.sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu,
|
||||
std::multiplies<size_t>()) * attrs.dataSize;
|
||||
cycles = attrs.dstSize / (dataLength * batchSize);
|
||||
dataLength = std::accumulate(attrs.srcDims.begin() + sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu,
|
||||
std::multiplies<size_t>());
|
||||
cycles = attrs.dstElementCount / (dataLength * batchSize);
|
||||
workAmount = batchSize * cycles;
|
||||
|
||||
srcBatchStride = std::accumulate(attrs.srcDims.begin() + attrs.batchDims, attrs.srcDims.end(), 1lu,
|
||||
std::multiplies<size_t>()) * attrs.dataSize;
|
||||
idxBatchStride = cycles * attrs.sliceRank;
|
||||
std::multiplies<size_t>());
|
||||
idxBatchStride = cycles * sliceRank;
|
||||
dstBatchStride = cycles * dataLength;
|
||||
|
||||
srcShifts.resize(attrs.sliceRank, 0);
|
||||
for (size_t i = 0; i < attrs.sliceRank ; i++)
|
||||
srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * attrs.dataSize;
|
||||
}
|
||||
srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * (dataLength > 1 ? dataSize : 1);
|
||||
|
||||
void MKLDNNGatherNDNode::GatherNDExecutor::exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData) {
|
||||
parallel_for2d(batchSize, cycles, [&](const size_t b, const size_t j) {
|
||||
const size_t srcStride = b * srcBatchStride;
|
||||
const size_t idxStride = b * idxBatchStride + j * attrs.sliceRank;
|
||||
const size_t dstStride = b * dstBatchStride + j * dataLength;
|
||||
|
||||
size_t dataIdx = 0lu;
|
||||
for (size_t i = 0; i < attrs.sliceRank ; ++i)
|
||||
dataIdx += srcShifts[i] * indices[idxStride + i];
|
||||
|
||||
cpu_memcpy(&dstData[dstStride], &srcData[srcStride + dataIdx], dataLength);
|
||||
});
|
||||
// optimized implementation 'blocks' via memcpy
|
||||
if (dataLength > 1) {
|
||||
dataLength *= dataSize;
|
||||
srcBatchStride *= dataSize;
|
||||
dstBatchStride *= dataSize;
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNGatherNDNode::execute(mkldnn::stream strm) {
|
||||
if (!execPtr)
|
||||
THROW_ERROR << "has not compiled executor.";
|
||||
|
||||
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr()->GetPtr());
|
||||
const int32_t* indices = reinterpret_cast<const int32_t*>(getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr()->GetPtr());
|
||||
uint8_t* dstData = reinterpret_cast<uint8_t*>(getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
|
||||
execPtr->exec(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr(),
|
||||
getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr(),
|
||||
getChildEdgeAt(0)->getMemoryPtr());
|
||||
}
|
||||
|
||||
execPtr->exec(srcData, indices, dstData);
|
||||
void MKLDNNGatherNDNode::GatherNDExecutor::exec(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
|
||||
if (dataLength > 1) {
|
||||
gatherBlocks(srcMemPtr, idxMemPtr, dstMemPtr);
|
||||
return;
|
||||
}
|
||||
|
||||
GatherNDContext ctx { this, srcMemPtr, idxMemPtr, dstMemPtr };
|
||||
OV_SWITCH(MKLDNNPlugin, GatherNDEmitter, ctx, dataSize,
|
||||
OV_CASE(sizeof(PrecisionTrait<Precision::I32>::value_type), PrecisionTrait<Precision::I32>::value_type),
|
||||
OV_CASE(sizeof(PrecisionTrait<Precision::I16>::value_type), PrecisionTrait<Precision::I16>::value_type),
|
||||
OV_CASE(sizeof(PrecisionTrait<Precision::I8>::value_type), PrecisionTrait<Precision::I8>::value_type));
|
||||
}
|
||||
|
||||
void MKLDNNGatherNDNode::GatherNDExecutor::gatherBlocks(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
|
||||
const uint8_t* srcData = reinterpret_cast<const uint8_t*>(srcMemPtr->GetPtr());
|
||||
const int32_t* indices = reinterpret_cast<const int32_t*>(idxMemPtr->GetPtr());
|
||||
uint8_t* dstData = reinterpret_cast<uint8_t*>(dstMemPtr->GetPtr());
|
||||
|
||||
parallel_nt(0, [&](const int ithr, const int nthr) {
|
||||
size_t start(0lu), end(0lu);
|
||||
splitter(workAmount, nthr, ithr, start, end);
|
||||
if (start >= end)
|
||||
return;
|
||||
size_t bStart = start / cycles;
|
||||
size_t cStart = start % cycles;
|
||||
size_t workCounter = start;
|
||||
|
||||
const uint8_t* shiftedSrcData = srcData + bStart * srcBatchStride;
|
||||
const int32_t* shiftedIndices = indices + bStart * idxBatchStride + cStart * sliceRank;
|
||||
uint8_t* shiftedDstData = dstData + bStart * dstBatchStride + cStart * dataLength;
|
||||
|
||||
for (size_t b = bStart; b < batchSize; b++) {
|
||||
for (size_t j = cStart; j < cycles; j++) {
|
||||
size_t dataIdx = 0lu;
|
||||
for (size_t i = 0; i < sliceRank; i++)
|
||||
dataIdx += srcShifts[i] * shiftedIndices[i];
|
||||
cpu_memcpy(shiftedDstData, &(shiftedSrcData[dataIdx]), dataLength);
|
||||
shiftedDstData += dataLength;
|
||||
shiftedIndices += sliceRank;
|
||||
if (++workCounter == end) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
cStart = 0;
|
||||
shiftedSrcData += srcBatchStride;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename dataType>
|
||||
void MKLDNNGatherNDNode::GatherNDExecutor::gatherElementwise(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
|
||||
const dataType* srcData = reinterpret_cast<const dataType*>(srcMemPtr->GetPtr());
|
||||
const int32_t* indices = reinterpret_cast<const int32_t*>(idxMemPtr->GetPtr());
|
||||
dataType* dstData = reinterpret_cast<dataType*>(dstMemPtr->GetPtr());
|
||||
|
||||
parallel_nt(0, [&](const int ithr, const int nthr) {
|
||||
size_t start(0lu), end(0lu);
|
||||
splitter(workAmount, nthr, ithr, start, end);
|
||||
if (start >= end)
|
||||
return;
|
||||
size_t bStart = start / cycles;
|
||||
size_t cStart = start % cycles;
|
||||
size_t workCounter = start;
|
||||
|
||||
const dataType* shiftedSrcData = srcData + bStart * srcBatchStride;
|
||||
const int32_t* shiftedIndices = indices + bStart * idxBatchStride + cStart * sliceRank;
|
||||
dataType* shiftedDstData = dstData + bStart * dstBatchStride + cStart * dataLength;
|
||||
|
||||
for (size_t b = bStart; b < batchSize; b++) {
|
||||
for (size_t j = cStart; j < cycles; j++) {
|
||||
size_t dataIdx = 0lu;
|
||||
for (size_t i = 0lu; i < sliceRank; i++)
|
||||
dataIdx += srcShifts[i] * shiftedIndices[i];
|
||||
shiftedDstData[0] = shiftedSrcData[dataIdx];
|
||||
shiftedDstData++;
|
||||
shiftedIndices += sliceRank;
|
||||
if (++workCounter == end) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
cStart = 0lu;
|
||||
shiftedSrcData += srcBatchStride;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void MKLDNNGatherNDNode::executeDynamicImpl(dnnl::stream strm) {
|
||||
|
@ -32,7 +32,7 @@ private:
|
||||
struct GatherNDAttributes {
|
||||
size_t batchDims = 0lu;
|
||||
size_t dataSize = 1lu;
|
||||
size_t dstSize = 0lu;
|
||||
size_t dstElementCount = 0lu;
|
||||
size_t sliceRank = 0lu;
|
||||
|
||||
VectorDims srcDims;
|
||||
@ -42,19 +42,38 @@ private:
|
||||
struct GatherNDExecutor {
|
||||
GatherNDExecutor(const GatherNDAttributes& attrs);
|
||||
~GatherNDExecutor() = default;
|
||||
void exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData);
|
||||
void exec(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr);
|
||||
|
||||
private:
|
||||
template <typename dataType>
|
||||
void gatherElementwise(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr);
|
||||
void gatherBlocks(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr);
|
||||
|
||||
size_t batchSize = 1lu;
|
||||
size_t cycles = 1lu;
|
||||
size_t dataLength = 1lu;
|
||||
size_t sliceRank = 0lu;
|
||||
size_t workAmount = 0lu;
|
||||
size_t dataSize = 1lu;
|
||||
|
||||
size_t srcBatchStride = 1lu;
|
||||
size_t idxBatchStride = 1lu;
|
||||
size_t dstBatchStride = 1lu;
|
||||
VectorDims srcShifts;
|
||||
|
||||
GatherNDAttributes attrs;
|
||||
// Argument bundle passed through OV_SWITCH to the selected GatherNDEmitter<T>.
// exec() brace-initializes it as { this, srcMemPtr, idxMemPtr, dstMemPtr }, so the
// field order below has to stay in sync with that initializer.
struct GatherNDContext {
|
||||
// Executor whose gatherElementwise<T>() the emitter calls.
GatherNDExecutor* executor;
|
||||
// Memory of the 'data' input (held by value, i.e. a copy of the caller's pointer object).
const MKLDNNMemoryPtr srcMemPtr;
|
||||
// Memory of the 'indices' input.
const MKLDNNMemoryPtr idxMemPtr;
|
||||
// Memory of the output; non-const because gatherElementwise takes it by non-const reference.
MKLDNNMemoryPtr dstMemPtr;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct GatherNDEmitter {
|
||||
void operator()(GatherNDContext& ctx) {
|
||||
ctx.executor->gatherElementwise<T>(ctx.srcMemPtr, ctx.idxMemPtr, ctx.dstMemPtr);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
static constexpr size_t GATHERND_DATA = 0lu;
|
||||
|
@ -41,7 +41,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set1, GatherNDLayerTest,
|
||||
::testing::Values<Config>({})),
|
||||
GatherNDLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherNDLayerTest,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherND8LayerTest,
|
||||
::testing::Combine(
|
||||
gatherNDArgsSubset1,
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
@ -67,7 +67,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set2, GatherNDLayerTest,
|
||||
::testing::Values<Config>({})),
|
||||
GatherNDLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherNDLayerTest,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherND8LayerTest,
|
||||
::testing::Combine(
|
||||
gatherNDArgsSubset2,
|
||||
::testing::ValuesIn(dPrecisions),
|
||||
|
Loading…
Reference in New Issue
Block a user