[CPU] Fixed StridedSlice (#5219)

This commit is contained in:
Alexandra Sidorova 2021-04-14 10:37:10 +03:00 committed by GitHub
parent 5a111bfb27
commit 1c4428e945
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 6 deletions

View File

@ -57,6 +57,7 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
const SizeVector srcDims = inData->getTensorDesc().getDims(); const SizeVector srcDims = inData->getTensorDesc().getDims();
const SizeVector dstDims = stridedSliceLayer->outData[0]->getTensorDesc().getDims(); const SizeVector dstDims = stridedSliceLayer->outData[0]->getTensorDesc().getDims();
const size_t nSrcDims = srcDims.size(); const size_t nSrcDims = srcDims.size();
const size_t nDims = std::max(nSrcDims, dstDims.size());
if (getParentEdges().size() != 3 && getParentEdges().size() != 4) if (getParentEdges().size() != 3 && getParentEdges().size() != 4)
THROW_ERROR << "has incorrect number of input edges"; THROW_ERROR << "has incorrect number of input edges";
@ -90,7 +91,7 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
auto createMask = [&](const char* maskName, std::vector<int>& mask, const int bit = 0) { auto createMask = [&](const char* maskName, std::vector<int>& mask, const int bit = 0) {
mask = stridedSliceLayer->GetParamAsInts(maskName); mask = stridedSliceLayer->GetParamAsInts(maskName);
if (strcmp(maskName, "ellipsis_mask") != 0 || mask.size() == 0) { if (strcmp(maskName, "ellipsis_mask") != 0 || mask.size() == 0) {
for (size_t i = mask.size(); i < dstDims.size(); ++i) mask.push_back(bit); for (size_t i = mask.size(); i < nDims; ++i) mask.push_back(bit);
} }
}; };
@ -122,8 +123,8 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
const int *ptr = blob->cbuffer().as<const int *>() + blob->getTensorDesc().getBlockingDesc().getOffsetPadding(); const int *ptr = blob->cbuffer().as<const int *>() + blob->getTensorDesc().getBlockingDesc().getOffsetPadding();
parameter.assign(ptr, ptr + size); parameter.assign(ptr, ptr + size);
if (ellipsisMaskCounter == 0 && size < dstDims.size()) { if (ellipsisMaskCounter == 0 && size < nDims) {
for (size_t i = size; i < dstDims.size(); i++) parameter.push_back(value); for (size_t i = size; i < nDims; i++) parameter.push_back(value);
} }
}; };
@ -529,14 +530,15 @@ void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
if (!params.parametersAreConstant) { if (!params.parametersAreConstant) {
auto srcDims = getParentEdgeAt(DATA_ID)->getDims(); auto srcDims = getParentEdgeAt(DATA_ID)->getDims();
auto dstDims = getChildEdgeAt(0)->getDims(); auto dstDims = getChildEdgeAt(0)->getDims();
const size_t nDims = std::max(srcDims.size(), dstDims.size());
const size_t ellipsisMaskCounter = std::accumulate(ellipsisMask.begin(), ellipsisMask.end(), 0); const size_t ellipsisMaskCounter = std::accumulate(ellipsisMask.begin(), ellipsisMask.end(), 0);
auto fillingInParameters = [&](std::vector<int> &parameter, const size_t type, const size_t size, const int value) { auto fillingInParameters = [&](std::vector<int> &parameter, const size_t type, const size_t size, const int value) {
const int *ptr = reinterpret_cast<const int*>(this->getParentEdgeAt(type)->getMemoryPtr()->GetPtr()); const int *ptr = reinterpret_cast<const int*>(this->getParentEdgeAt(type)->getMemoryPtr()->GetPtr());
parameter.assign(ptr, ptr + size); parameter.assign(ptr, ptr + size);
if (ellipsisMaskCounter == 0 && size < dstDims.ndims()) { if (ellipsisMaskCounter == 0 && size < nDims) {
for (size_t i = size; i < dstDims.ndims(); i++) parameter.push_back(value); for (size_t i = size; i < nDims; i++) parameter.push_back(value);
} }
}; };

View File

@ -95,7 +95,15 @@ std::vector<StridedSliceSpecificParams> ss_only_test_cases = {
StridedSliceSpecificParams{ { 2, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 }, StridedSliceSpecificParams{ { 2, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 },
{ 0, 1, 1, 1}, { 1, 1, 1, 1}, {}, {}, {} }, { 0, 1, 1, 1}, { 1, 1, 1, 1}, {}, {}, {} },
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 }, StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
{ 1, 0, 1, 1, 1}, { 1, 0, 1, 1, 1}, {}, { 0, 1, 0, 0, 0}, {} }, { 1, 0, 1, 1, 1}, { 1, 0, 1, 1, 1 }, {}, { 0, 1, 0, 0, 0 }, {} },
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 0, 3, 0, 0 }, { 2, 3, 4, 3, 6 }, { 1, 1, 1, 1, 1 },
{ 1, 1, 0, 1, 1}, { 1, 1, 0, 0, 1 }, {}, { 0, 0, 1, 0, 0 }, {} },
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 0, 0, 0, 3 }, { 1, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
{ 0, 1, 1, 1, 0}, { 0, 1, 1, 1, 0 }, {}, { 1, 0, 0, 0, 1 }, {} },
StridedSliceSpecificParams{ { 2, 3, 4, 5 }, { 0, 0, 0, 0, 0 }, { 0, 2, 3, 4, 5 }, { 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1 }, { 1, 0, 0, 0, 0 }, {}, {} },
StridedSliceSpecificParams{ { 2, 3, 4, 5 }, { 0, 0, 0, 0, 0 }, { 0, 2, 3, 4, 5 }, { 1, 1, 1, 1, 1 },
{ 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1 }, { 0, 0, 1, 0, 0 }, {}, {} },
StridedSliceSpecificParams{ { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 }, StridedSliceSpecificParams{ { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } }, { 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
StridedSliceSpecificParams{ { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 }, StridedSliceSpecificParams{ { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },