[CPU] Added improvements for StridedSlice (#6658)
parent 1247b228db
commit 983bab8271
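In short: the node now caches the source and destination memory pointers and a precomputed source byte offset (params.srcShift) in its params struct, always builds the index tables at primitive-creation time when the slice parameters are constant, replaces the separate stridedSliceV() kernel with a single stridedSlice() path, and adds indicesCalculationForOptimized() to precompute byte offsets for the optimized case (params.nDimsForWork == 1 with a multi-dimensional destination).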
@@ -272,6 +272,8 @@ void MKLDNNStridedSliceNode::createPrimitive() {
     auto srcOrder = srcBlockingDesc.getOrder();
     params.srcDims = srcBlockingDesc.getBlockDims();
     params.dstDims = dstBlockingDesc.getBlockDims();
+    params.srcMemPtr = srcMemPtr;
+    params.dstMemPtr = dstMemPtr;
     params.dataSize = getSelectedPrimitiveDescriptor()->getConfig().inConfs[DATA_ID].desc->getPrecision().size();
 
     if (params.parametersAreConstant) {
@@ -282,9 +284,7 @@ void MKLDNNStridedSliceNode::createPrimitive() {
         SizeVector newSrcDims, newDstDims;
         dimsNormalization(newSrcDims, newDstDims);
         dimsGluing(realNDims, newSrcDims, newDstDims);
 
-        if (params.dstDims.size() == 1 || params.nDimsForWork != 1)
-            indicesCalculation();
+        indicesCalculation();
     }
 }
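Caching srcMemPtr and dstMemPtr here lets the execution path read the cached pointers instead of walking the parent/child edges on every inference, and with constant slice parameters the index tables are now built unconditionally at primitive creation rather than only when the work is not one-dimensional.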
@@ -510,14 +510,35 @@ void MKLDNNStridedSliceNode::dimsGluing(const size_t realNDims, const SizeVector
         if (params.dstDims.size() > 2)
             params.lastDstDim /= newDstDims[secondDim.first];
     }
+
+    // some parameter calculations for common execution
+    params.isOptimized = params.nDimsForWork == 1 && params.dstDims.size() > 1;
+    if (params.isOptimized) {
+        if (params.dstDims.size() == 2)
+            params.dstDims[1] = 1;
+
+        params.workAmount = params.dstDims[0] * params.dstDims[1];
+        params.srcShift = (begin[0] * params.srcStrides[0] + begin[1] * params.srcStrides[1]) * params.dataSize;
+    } else {
+        params.srcShift = stride.back() == 1 && stride.size() > 1 ?
+                begin[params.nDimsForWork] * params.srcStrides[params.nDimsForWork] * params.dataSize : 0;
+    }
 }
 
 void MKLDNNStridedSliceNode::indicesCalculation() {
     // indices calculation before execution for the best performance
-    params.nThreads = parallel_get_max_threads();
     params.srcIndices.resize(params.workAmount, 0);
     params.dstIndices.resize(params.workAmount, 0);
 
+    // should choose more optimal thread count
+    const size_t nthr = parallel_get_max_threads();
+    params.nThreads = nthr > params.workAmount ? params.workAmount : nthr;
+
+    if (params.isOptimized) {
+        indicesCalculationForOptimized();
+        return;
+    }
+
     auto getSrcIdx = [this](const SizeVector& indexes){
         size_t srcIdx = 0;
         for (int i = 0; i < params.nDimsForWork; ++i)
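To make the byte arithmetic of the new optimized branch concrete, here is a small self-contained sketch; all shapes, strides and slice parameters are made up for illustration and are not taken from the commit:

// Mirrors the optimized-path parameter computation for a 2D destination,
// using hypothetical row-major float data.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    std::vector<size_t> srcStrides = {16, 1};         // element strides of the source
    std::vector<size_t> dstDims    = {4, 8};          // 4 rows of 8 elements to produce
    std::vector<size_t> begin      = {2, 4};          // slice start per dimension
    const size_t dataSize   = sizeof(float);          // precision size in bytes
    const size_t lastDstDim = dstDims[1] * dataSize;  // bytes one memcpy would copy here

    // For a 2D destination the inner dimension is collapsed, so each work
    // item stands for one whole output row.
    if (dstDims.size() == 2)
        dstDims[1] = 1;

    const size_t workAmount = dstDims[0] * dstDims[1];
    const size_t srcShift   = (begin[0] * srcStrides[0] + begin[1] * srcStrides[1]) * dataSize;

    std::printf("workAmount = %zu, srcShift = %zu bytes, row size = %zu bytes\n",
                workAmount, srcShift, lastDstDim);
    // Prints: workAmount = 4, srcShift = 144 bytes, row size = 32 bytes
    return 0;
}

Each of those four work items then corresponds to one contiguous row copy, which is what the index tables built below are for.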
@@ -542,10 +563,10 @@ void MKLDNNStridedSliceNode::indicesCalculation() {
             if (coords[k] < params.dstDims[k]) {
                 srcIdx += stride[k] * params.srcStrides[k] * params.dataSize;
                 break;
-            } else {
-                coords[k] = 0;
-                out = true;
             }
+
+            coords[k] = 0;
+            out = true;
         }
 
         if (out)
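Dropping the else is behaviour-preserving: the if branch ends in break, so the reset of coords[k] and the out flag are only reached when the condition fails.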
@@ -554,6 +575,25 @@ void MKLDNNStridedSliceNode::indicesCalculation() {
         });
 }
 
+void MKLDNNStridedSliceNode::indicesCalculationForOptimized() {
+    const size_t dstIdx0 = params.dstStrides[0] * params.dataSize;
+    const size_t dstIdx1 = params.dstStrides[1] * params.dataSize;
+    const size_t srcIdx0 = stride[0] * params.srcStrides[0] * params.dataSize;
+    const size_t srcIdx1 = stride[1] * params.srcStrides[1] * params.dataSize;
+
+    for (size_t i0 = 0; i0 < params.dstDims[0]; i0++) {
+        const size_t idx = i0 * params.dstDims[1];
+
+        params.dstIndices[idx] = i0 * dstIdx0;
+        params.srcIndices[idx] = i0 * srcIdx0;
+
+        for (size_t i1 = 1; i1 < params.dstDims[1]; i1++) {
+            params.dstIndices[idx + i1] = params.dstIndices[idx] + i1 * dstIdx1;
+            params.srcIndices[idx + i1] = params.srcIndices[idx] + i1 * srcIdx1;
+        }
+    }
+}
+
 void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
     if (!params.parametersAreConstant) {
         auto srcDims = getParentEdgeAt(DATA_ID)->getShape().getStaticDims();
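The new function fills params.srcIndices and params.dstIndices with ready-to-use byte offsets so the copy loop can work from precomputed values. A self-contained illustration of the same arithmetic, again with made-up sizes and strides rather than values from the commit:

// Builds source/destination offset tables the way
// indicesCalculationForOptimized() does, for a hypothetical 3x2 work grid.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t dataSize = sizeof(float);
    const std::vector<size_t> dstDims    = {3, 2};   // work grid after dims gluing
    const std::vector<size_t> dstStrides = {2, 1};   // element strides
    const std::vector<size_t> srcStrides = {16, 1};
    const std::vector<size_t> stride     = {2, 3};   // slice strides per dimension

    // Per-step byte increments, as in the new function.
    const size_t dstIdx0 = dstStrides[0] * dataSize;
    const size_t dstIdx1 = dstStrides[1] * dataSize;
    const size_t srcIdx0 = stride[0] * srcStrides[0] * dataSize;
    const size_t srcIdx1 = stride[1] * srcStrides[1] * dataSize;

    std::vector<size_t> srcIndices(dstDims[0] * dstDims[1], 0);
    std::vector<size_t> dstIndices(dstDims[0] * dstDims[1], 0);

    for (size_t i0 = 0; i0 < dstDims[0]; i0++) {
        const size_t idx = i0 * dstDims[1];
        dstIndices[idx] = i0 * dstIdx0;
        srcIndices[idx] = i0 * srcIdx0;
        for (size_t i1 = 1; i1 < dstDims[1]; i1++) {
            dstIndices[idx + i1] = dstIndices[idx] + i1 * dstIdx1;
            srcIndices[idx + i1] = srcIndices[idx] + i1 * srcIdx1;
        }
    }

    for (size_t i = 0; i < srcIndices.size(); i++)
        std::printf("work item %zu: src offset %zu, dst offset %zu\n", i, srcIndices[i], dstIndices[i]);
    return 0;
}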
@@ -586,42 +626,15 @@ void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
         SizeVector newSrcDims, newDstDims;
         dimsNormalization(newSrcDims, newDstDims);
         dimsGluing(dstDims.size(), newSrcDims, newDstDims);
 
-        if (params.dstDims.size() == 1 || params.nDimsForWork != 1)
-            indicesCalculation();
+        indicesCalculation();
     }
 
-    if (params.dstDims.size() > 1 && params.nDimsForWork == 1)
-        stridedSliceV();
-    else
-        stridedSlice();
+    stridedSlice();
 }
 
-void MKLDNNStridedSliceNode::stridedSliceV() {
-    const uint8_t* srcData = reinterpret_cast<const uint8_t*>(this->getParentEdgeAt(DATA_ID)->getMemoryPtr()->GetPtr()) +
-            (begin[0] * params.srcStrides[0] + begin[1] * params.srcStrides[1]) * params.dataSize;
-    uint8_t* dstData = reinterpret_cast<uint8_t*>(this->getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
-
-    const size_t dstIdx = params.dstStrides[0] * params.dataSize;
-    const size_t srcIdx = stride[0] * params.srcStrides[0] * params.dataSize;
-    const size_t dstShift = params.dstStrides[1] * params.dataSize;
-    const size_t srcShift = stride[1] * params.srcStrides[1] * params.dataSize;
-
-    if (params.dstDims.size() > 2) {
-        parallel_for2d(params.dstDims[0], params.dstDims[1], [&](const size_t i, const size_t j) {
-            cpu_memcpy(&dstData[i * dstIdx + j * dstShift], &srcData[i * srcIdx + j * srcShift], params.lastDstDim);
-        });
-    } else {
-        parallel_for(params.dstDims[0], [&](const size_t i) {
-            cpu_memcpy(&dstData[i * dstIdx], &srcData[i * srcIdx], params.lastDstDim);
-        });
-    }
-}
-
-void MKLDNNStridedSliceNode::stridedSlice() {
-    const uint8_t* srcData = reinterpret_cast<const uint8_t*>(this->getParentEdgeAt(DATA_ID)->getMemoryPtr()->GetPtr()) +
-            (stride.back() == 1 && stride.size() > 1 ? begin[params.nDimsForWork] * params.srcStrides[params.nDimsForWork] * params.dataSize : 0);
-    uint8_t* dstData = reinterpret_cast<uint8_t*>(this->getChildEdgeAt(0)->getMemoryPtr()->GetPtr());
+inline void MKLDNNStridedSliceNode::stridedSlice() {
+    const uint8_t* srcData = reinterpret_cast<const uint8_t*>(params.srcMemPtr->GetPtr()) + params.srcShift;
+    uint8_t* dstData = reinterpret_cast<uint8_t*>(params.dstMemPtr->GetPtr());
 
     parallel_nt(params.nThreads, [&](const int ithr, const int nthr) {
         size_t start = 0, end = 0;
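The unified kernel now just dereferences the cached memory pointers and applies the precomputed shift; the rest of the loop body is cut off in this hunk, but it presumably walks the precomputed index tables and copies one block per work item. A hedged, self-contained sketch of that general pattern, with a serial loop standing in for parallel_nt and made-up buffers and offsets:

// Indexed-copy pattern: each work item copies one contiguous block of
// lastDstDim bytes from a precomputed source offset to a precomputed
// destination offset. Everything here is illustrative, not the node's code.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

void indexedCopy(const uint8_t* srcData, uint8_t* dstData,
                 const std::vector<size_t>& srcIndices,
                 const std::vector<size_t>& dstIndices,
                 size_t lastDstDim) {
    // The real node splits the work range across params.nThreads; a serial
    // loop keeps this sketch self-contained.
    for (size_t iwork = 0; iwork < srcIndices.size(); ++iwork)
        std::memcpy(dstData + dstIndices[iwork], srcData + srcIndices[iwork], lastDstDim);
}

int main() {
    std::vector<uint8_t> src(1024, 1), dst(64, 0);
    const std::vector<size_t> srcIndices = {0, 12, 128, 140, 256, 268};  // byte offsets
    const std::vector<size_t> dstIndices = {0, 4, 8, 12, 16, 20};
    indexedCopy(src.data(), dst.data(), srcIndices, dstIndices, 4 /* lastDstDim */);
    return 0;
}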
@@ -27,14 +27,14 @@ public:
     static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;
 
 private:
-    void stridedSliceV();
-    void stridedSlice();
+    inline void stridedSlice();
 
     void addHiddenDims(const size_t nSrcDims);
     void orderParametersByLayouts();
     void dimsNormalization(InferenceEngine::SizeVector& newSrcDims, InferenceEngine::SizeVector& newDstDims);
     void dimsGluing(const size_t realNDims, const InferenceEngine::SizeVector& newSrcDims, const InferenceEngine::SizeVector& newDstDims);
     void indicesCalculation();
+    void indicesCalculationForOptimized();
 
     const size_t DATA_ID = 0;
     const size_t BEGIN_ID = 1;
@@ -56,6 +56,8 @@ private:
     InferenceEngine::SizeVector strideDims;
 
     struct {
+        MKLDNNMemoryPtr srcMemPtr = nullptr;
+        MKLDNNMemoryPtr dstMemPtr = nullptr;
         InferenceEngine::SizeVector srcDims;
         InferenceEngine::SizeVector dstDims;
         InferenceEngine::SizeVector srcStrides;
@@ -69,6 +71,8 @@ private:
         size_t workAmount = 0;
        size_t lastDstDim = 0;
        size_t dataSize = 0;
+        size_t srcShift = 0;
+        bool isOptimized = false;
        bool equalDims = false;
        bool parametersAreConstant = true;
    } params;