diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.cpp index 49a126e88c6..ae98c0e74ab 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.cpp @@ -41,8 +41,8 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr& op, if (inputShapes.size() != 2 && outputShapes.size() != 1) THROW_ERROR << "has invalid number of input/output edges."; - const size_t inputDataRank = getInputShapeAtPort(GATHERND_DATA).getRank(); - const size_t indicesDimsRank = getInputShapeAtPort(GATHERND_INDEXES).getRank(); + const size_t dataInputRank = getInputShapeAtPort(GATHERND_DATA).getRank(); + const size_t indicesInputRank = getInputShapeAtPort(GATHERND_INDEXES).getRank(); if (auto gatherNdOp = ngraph::as_type_ptr(op)) { attrs.batchDims = gatherNdOp->get_batch_dims(); @@ -51,7 +51,7 @@ MKLDNNGatherNDNode::MKLDNNGatherNDNode(const std::shared_ptr& op, } else { THROW_ERROR << "has support only opset5."; } - if (attrs.batchDims >= std::min(inputDataRank, indicesDimsRank)) + if (attrs.batchDims >= std::min(dataInputRank, indicesInputRank)) THROW_ERROR << "has invalid batch_dims attribute: " << attrs.batchDims; } @@ -60,12 +60,19 @@ void MKLDNNGatherNDNode::initSupportedPrimitiveDescriptors() { return; Precision inDataPrecision = getOriginalInputPrecisionAtPort(GATHERND_DATA); + if (!MKLDNNPlugin::one_of(inDataPrecision.size(), + sizeof(PrecisionTrait::value_type), + sizeof(PrecisionTrait::value_type), + sizeof(PrecisionTrait::value_type))) { + THROW_ERROR << "has unsupported 'data' input precision: " << inDataPrecision; + } + attrs.dataSize = inDataPrecision.size(); + Precision indicesPrecision = getOriginalInputPrecisionAtPort(GATHERND_INDEXES); if (!MKLDNNPlugin::one_of(indicesPrecision, Precision::I32, Precision::I64, Precision::I16, Precision::U16, Precision::I8, Precision::U8)) { THROW_ERROR << "has unsupported 'indices' input precision: " << indicesPrecision; } - attrs.dataSize = inDataPrecision.size(); addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision}, {LayoutType::ncsp, Precision::I32}}, @@ -96,50 +103,128 @@ void MKLDNNGatherNDNode::prepareParams() { attrs.srcDims = srcMemPtr->getStaticDims(); attrs.srcStrides = srcMemPtr->GetDescWithType()->getStrides(); - attrs.dstSize = dstMemPtr->GetSize(); + attrs.dstElementCount = dstMemPtr->GetShape().getElementsCount(); attrs.sliceRank = idxMemPtr->getStaticDims().back(); execPtr = std::make_shared(attrs); } -MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : attrs(attrs) { +MKLDNNGatherNDNode::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : dataSize(attrs.dataSize), sliceRank(attrs.sliceRank) { batchSize = std::accumulate(attrs.srcDims.begin(), attrs.srcDims.begin() + attrs.batchDims, 1lu, std::multiplies()); - dataLength = std::accumulate(attrs.srcDims.begin() + attrs.sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu, - std::multiplies()) * attrs.dataSize; - cycles = attrs.dstSize / (dataLength * batchSize); + dataLength = std::accumulate(attrs.srcDims.begin() + sliceRank + attrs.batchDims, attrs.srcDims.end(), 1lu, + std::multiplies()); + cycles = attrs.dstElementCount / (dataLength * batchSize); + workAmount = batchSize * cycles; srcBatchStride = std::accumulate(attrs.srcDims.begin() + attrs.batchDims, attrs.srcDims.end(), 1lu, - std::multiplies()) * attrs.dataSize; - idxBatchStride = cycles * attrs.sliceRank; + std::multiplies()); + idxBatchStride = cycles * sliceRank; dstBatchStride = cycles * dataLength; srcShifts.resize(attrs.sliceRank, 0); for (size_t i = 0; i < attrs.sliceRank ; i++) - srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * attrs.dataSize; -} + srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * (dataLength > 1 ? dataSize : 1); -void MKLDNNGatherNDNode::GatherNDExecutor::exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData) { - parallel_for2d(batchSize, cycles, [&](const size_t b, const size_t j) { - const size_t srcStride = b * srcBatchStride; - const size_t idxStride = b * idxBatchStride + j * attrs.sliceRank; - const size_t dstStride = b * dstBatchStride + j * dataLength; - - size_t dataIdx = 0lu; - for (size_t i = 0; i < attrs.sliceRank ; ++i) - dataIdx += srcShifts[i] * indices[idxStride + i]; - - cpu_memcpy(&dstData[dstStride], &srcData[srcStride + dataIdx], dataLength); - }); + // optimized implementation 'blocks' via memcpy + if (dataLength > 1) { + dataLength *= dataSize; + srcBatchStride *= dataSize; + dstBatchStride *= dataSize; + } } void MKLDNNGatherNDNode::execute(mkldnn::stream strm) { if (!execPtr) THROW_ERROR << "has not compiled executor."; - const uint8_t* srcData = reinterpret_cast(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr()->GetPtr()); - const int32_t* indices = reinterpret_cast(getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr()->GetPtr()); - uint8_t* dstData = reinterpret_cast(getChildEdgeAt(0)->getMemoryPtr()->GetPtr()); + execPtr->exec(getParentEdgeAt(GATHERND_DATA)->getMemoryPtr(), + getParentEdgeAt(GATHERND_INDEXES)->getMemoryPtr(), + getChildEdgeAt(0)->getMemoryPtr()); +} - execPtr->exec(srcData, indices, dstData); +void MKLDNNGatherNDNode::GatherNDExecutor::exec(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) { + if (dataLength > 1) { + gatherBlocks(srcMemPtr, idxMemPtr, dstMemPtr); + return; + } + + GatherNDContext ctx { this, srcMemPtr, idxMemPtr, dstMemPtr }; + OV_SWITCH(MKLDNNPlugin, GatherNDEmitter, ctx, dataSize, + OV_CASE(sizeof(PrecisionTrait::value_type), PrecisionTrait::value_type), + OV_CASE(sizeof(PrecisionTrait::value_type), PrecisionTrait::value_type), + OV_CASE(sizeof(PrecisionTrait::value_type), PrecisionTrait::value_type)); +} + +void MKLDNNGatherNDNode::GatherNDExecutor::gatherBlocks(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) { + const uint8_t* srcData = reinterpret_cast(srcMemPtr->GetPtr()); + const int32_t* indices = reinterpret_cast(idxMemPtr->GetPtr()); + uint8_t* dstData = reinterpret_cast(dstMemPtr->GetPtr()); + + parallel_nt(0, [&](const int ithr, const int nthr) { + size_t start(0lu), end(0lu); + splitter(workAmount, nthr, ithr, start, end); + if (start >= end) + return; + size_t bStart = start / cycles; + size_t cStart = start % cycles; + size_t workCounter = start; + + const uint8_t* shiftedSrcData = srcData + bStart * srcBatchStride; + const int32_t* shiftedIndices = indices + bStart * idxBatchStride + cStart * sliceRank; + uint8_t* shiftedDstData = dstData + bStart * dstBatchStride + cStart * dataLength; + + for (size_t b = bStart; b < batchSize; b++) { + for (size_t j = cStart; j < cycles; j++) { + size_t dataIdx = 0lu; + for (size_t i = 0; i < sliceRank; i++) + dataIdx += srcShifts[i] * shiftedIndices[i]; + cpu_memcpy(shiftedDstData, &(shiftedSrcData[dataIdx]), dataLength); + shiftedDstData += dataLength; + shiftedIndices += sliceRank; + if (++workCounter == end) { + return; + } + } + cStart = 0; + shiftedSrcData += srcBatchStride; + } + }); +} + +template +void MKLDNNGatherNDNode::GatherNDExecutor::gatherElementwise(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr) { + const dataType* srcData = reinterpret_cast(srcMemPtr->GetPtr()); + const int32_t* indices = reinterpret_cast(idxMemPtr->GetPtr()); + dataType* dstData = reinterpret_cast(dstMemPtr->GetPtr()); + + parallel_nt(0, [&](const int ithr, const int nthr) { + size_t start(0lu), end(0lu); + splitter(workAmount, nthr, ithr, start, end); + if (start >= end) + return; + size_t bStart = start / cycles; + size_t cStart = start % cycles; + size_t workCounter = start; + + const dataType* shiftedSrcData = srcData + bStart * srcBatchStride; + const int32_t* shiftedIndices = indices + bStart * idxBatchStride + cStart * sliceRank; + dataType* shiftedDstData = dstData + bStart * dstBatchStride + cStart * dataLength; + + for (size_t b = bStart; b < batchSize; b++) { + for (size_t j = cStart; j < cycles; j++) { + size_t dataIdx = 0lu; + for (size_t i = 0lu; i < sliceRank; i++) + dataIdx += srcShifts[i] * shiftedIndices[i]; + shiftedDstData[0] = shiftedSrcData[dataIdx]; + shiftedDstData++; + shiftedIndices += sliceRank; + if (++workCounter == end) { + return; + } + } + cStart = 0lu; + shiftedSrcData += srcBatchStride; + } + }); } void MKLDNNGatherNDNode::executeDynamicImpl(dnnl::stream strm) { diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.h index fefce441590..53661c4d342 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_gather_nd_node.h @@ -32,7 +32,7 @@ private: struct GatherNDAttributes { size_t batchDims = 0lu; size_t dataSize = 1lu; - size_t dstSize = 0lu; + size_t dstElementCount = 0lu; size_t sliceRank = 0lu; VectorDims srcDims; @@ -42,19 +42,38 @@ private: struct GatherNDExecutor { GatherNDExecutor(const GatherNDAttributes& attrs); ~GatherNDExecutor() = default; - void exec(const uint8_t* srcData, const int32_t* indices, uint8_t* dstData); + void exec(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr); private: + template + void gatherElementwise(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr); + void gatherBlocks(const MKLDNNMemoryPtr& srcMemPtr, const MKLDNNMemoryPtr& idxMemPtr, MKLDNNMemoryPtr& dstMemPtr); + size_t batchSize = 1lu; size_t cycles = 1lu; size_t dataLength = 1lu; + size_t sliceRank = 0lu; + size_t workAmount = 0lu; + size_t dataSize = 1lu; size_t srcBatchStride = 1lu; size_t idxBatchStride = 1lu; size_t dstBatchStride = 1lu; VectorDims srcShifts; - GatherNDAttributes attrs; + struct GatherNDContext { + GatherNDExecutor* executor; + const MKLDNNMemoryPtr srcMemPtr; + const MKLDNNMemoryPtr idxMemPtr; + MKLDNNMemoryPtr dstMemPtr; + }; + + template + struct GatherNDEmitter { + void operator()(GatherNDContext& ctx) { + ctx.executor->gatherElementwise(ctx.srcMemPtr, ctx.idxMemPtr, ctx.dstMemPtr); + } + }; }; static constexpr size_t GATHERND_DATA = 0lu; diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_nd.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_nd.cpp index b85c7f4e646..e9fcf3fcd20 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_nd.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather_nd.cpp @@ -41,7 +41,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set1, GatherNDLayerTest, ::testing::Values({})), GatherNDLayerTest::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherNDLayerTest, +INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set1, GatherND8LayerTest, ::testing::Combine( gatherNDArgsSubset1, ::testing::ValuesIn(dPrecisions), @@ -67,7 +67,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_GatherND5_Set2, GatherNDLayerTest, ::testing::Values({})), GatherNDLayerTest::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherNDLayerTest, +INSTANTIATE_TEST_SUITE_P(smoke_GatherND8_Set2, GatherND8LayerTest, ::testing::Combine( gatherNDArgsSubset2, ::testing::ValuesIn(dPrecisions),