[CPU] Fixed StridedSlice (#5219)
This commit is contained in:
parent
5a111bfb27
commit
1c4428e945
@ -57,6 +57,7 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
|
||||
const SizeVector srcDims = inData->getTensorDesc().getDims();
|
||||
const SizeVector dstDims = stridedSliceLayer->outData[0]->getTensorDesc().getDims();
|
||||
const size_t nSrcDims = srcDims.size();
|
||||
const size_t nDims = std::max(nSrcDims, dstDims.size());
|
||||
|
||||
if (getParentEdges().size() != 3 && getParentEdges().size() != 4)
|
||||
THROW_ERROR << "has incorrect number of input edges";
|
||||
@ -90,7 +91,7 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
|
||||
auto createMask = [&](const char* maskName, std::vector<int>& mask, const int bit = 0) {
|
||||
mask = stridedSliceLayer->GetParamAsInts(maskName);
|
||||
if (strcmp(maskName, "ellipsis_mask") != 0 || mask.size() == 0) {
|
||||
for (size_t i = mask.size(); i < dstDims.size(); ++i) mask.push_back(bit);
|
||||
for (size_t i = mask.size(); i < nDims; ++i) mask.push_back(bit);
|
||||
}
|
||||
};
|
||||
|
||||
@ -122,8 +123,8 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
|
||||
const int *ptr = blob->cbuffer().as<const int *>() + blob->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
parameter.assign(ptr, ptr + size);
|
||||
|
||||
if (ellipsisMaskCounter == 0 && size < dstDims.size()) {
|
||||
for (size_t i = size; i < dstDims.size(); i++) parameter.push_back(value);
|
||||
if (ellipsisMaskCounter == 0 && size < nDims) {
|
||||
for (size_t i = size; i < nDims; i++) parameter.push_back(value);
|
||||
}
|
||||
};
|
||||
|
||||
@ -529,14 +530,15 @@ void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
|
||||
if (!params.parametersAreConstant) {
|
||||
auto srcDims = getParentEdgeAt(DATA_ID)->getDims();
|
||||
auto dstDims = getChildEdgeAt(0)->getDims();
|
||||
const size_t nDims = std::max(srcDims.size(), dstDims.size());
|
||||
const size_t ellipsisMaskCounter = std::accumulate(ellipsisMask.begin(), ellipsisMask.end(), 0);
|
||||
|
||||
auto fillingInParameters = [&](std::vector<int> ¶meter, const size_t type, const size_t size, const int value) {
|
||||
const int *ptr = reinterpret_cast<const int*>(this->getParentEdgeAt(type)->getMemoryPtr()->GetPtr());
|
||||
parameter.assign(ptr, ptr + size);
|
||||
|
||||
if (ellipsisMaskCounter == 0 && size < dstDims.ndims()) {
|
||||
for (size_t i = size; i < dstDims.ndims(); i++) parameter.push_back(value);
|
||||
if (ellipsisMaskCounter == 0 && size < nDims) {
|
||||
for (size_t i = size; i < nDims; i++) parameter.push_back(value);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -96,6 +96,14 @@ std::vector<StridedSliceSpecificParams> ss_only_test_cases = {
|
||||
{ 0, 1, 1, 1}, { 1, 1, 1, 1}, {}, {}, {} },
|
||||
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||
{ 1, 0, 1, 1, 1}, { 1, 0, 1, 1, 1 }, {}, { 0, 1, 0, 0, 0 }, {} },
|
||||
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 0, 3, 0, 0 }, { 2, 3, 4, 3, 6 }, { 1, 1, 1, 1, 1 },
|
||||
{ 1, 1, 0, 1, 1}, { 1, 1, 0, 0, 1 }, {}, { 0, 0, 1, 0, 0 }, {} },
|
||||
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 0, 0, 0, 3 }, { 1, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||
{ 0, 1, 1, 1, 0}, { 0, 1, 1, 1, 0 }, {}, { 1, 0, 0, 0, 1 }, {} },
|
||||
StridedSliceSpecificParams{ { 2, 3, 4, 5 }, { 0, 0, 0, 0, 0 }, { 0, 2, 3, 4, 5 }, { 1, 1, 1, 1, 1 },
|
||||
{ 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1 }, { 1, 0, 0, 0, 0 }, {}, {} },
|
||||
StridedSliceSpecificParams{ { 2, 3, 4, 5 }, { 0, 0, 0, 0, 0 }, { 0, 2, 3, 4, 5 }, { 1, 1, 1, 1, 1 },
|
||||
{ 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1 }, { 0, 0, 1, 0, 0 }, {}, {} },
|
||||
StridedSliceSpecificParams{ { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
|
||||
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
|
||||
StridedSliceSpecificParams{ { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },
|
||||
|
Loading…
Reference in New Issue
Block a user