[CPU] Fixed StridedSlice (#5219)
This commit is contained in:
parent
5a111bfb27
commit
1c4428e945
@ -57,6 +57,7 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
|
|||||||
const SizeVector srcDims = inData->getTensorDesc().getDims();
|
const SizeVector srcDims = inData->getTensorDesc().getDims();
|
||||||
const SizeVector dstDims = stridedSliceLayer->outData[0]->getTensorDesc().getDims();
|
const SizeVector dstDims = stridedSliceLayer->outData[0]->getTensorDesc().getDims();
|
||||||
const size_t nSrcDims = srcDims.size();
|
const size_t nSrcDims = srcDims.size();
|
||||||
|
const size_t nDims = std::max(nSrcDims, dstDims.size());
|
||||||
|
|
||||||
if (getParentEdges().size() != 3 && getParentEdges().size() != 4)
|
if (getParentEdges().size() != 3 && getParentEdges().size() != 4)
|
||||||
THROW_ERROR << "has incorrect number of input edges";
|
THROW_ERROR << "has incorrect number of input edges";
|
||||||
@ -90,7 +91,7 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
|
|||||||
auto createMask = [&](const char* maskName, std::vector<int>& mask, const int bit = 0) {
|
auto createMask = [&](const char* maskName, std::vector<int>& mask, const int bit = 0) {
|
||||||
mask = stridedSliceLayer->GetParamAsInts(maskName);
|
mask = stridedSliceLayer->GetParamAsInts(maskName);
|
||||||
if (strcmp(maskName, "ellipsis_mask") != 0 || mask.size() == 0) {
|
if (strcmp(maskName, "ellipsis_mask") != 0 || mask.size() == 0) {
|
||||||
for (size_t i = mask.size(); i < dstDims.size(); ++i) mask.push_back(bit);
|
for (size_t i = mask.size(); i < nDims; ++i) mask.push_back(bit);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -122,8 +123,8 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
|
|||||||
const int *ptr = blob->cbuffer().as<const int *>() + blob->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
const int *ptr = blob->cbuffer().as<const int *>() + blob->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
parameter.assign(ptr, ptr + size);
|
parameter.assign(ptr, ptr + size);
|
||||||
|
|
||||||
if (ellipsisMaskCounter == 0 && size < dstDims.size()) {
|
if (ellipsisMaskCounter == 0 && size < nDims) {
|
||||||
for (size_t i = size; i < dstDims.size(); i++) parameter.push_back(value);
|
for (size_t i = size; i < nDims; i++) parameter.push_back(value);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -529,14 +530,15 @@ void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
|
|||||||
if (!params.parametersAreConstant) {
|
if (!params.parametersAreConstant) {
|
||||||
auto srcDims = getParentEdgeAt(DATA_ID)->getDims();
|
auto srcDims = getParentEdgeAt(DATA_ID)->getDims();
|
||||||
auto dstDims = getChildEdgeAt(0)->getDims();
|
auto dstDims = getChildEdgeAt(0)->getDims();
|
||||||
|
const size_t nDims = std::max(srcDims.size(), dstDims.size());
|
||||||
const size_t ellipsisMaskCounter = std::accumulate(ellipsisMask.begin(), ellipsisMask.end(), 0);
|
const size_t ellipsisMaskCounter = std::accumulate(ellipsisMask.begin(), ellipsisMask.end(), 0);
|
||||||
|
|
||||||
auto fillingInParameters = [&](std::vector<int> ¶meter, const size_t type, const size_t size, const int value) {
|
auto fillingInParameters = [&](std::vector<int> ¶meter, const size_t type, const size_t size, const int value) {
|
||||||
const int *ptr = reinterpret_cast<const int*>(this->getParentEdgeAt(type)->getMemoryPtr()->GetPtr());
|
const int *ptr = reinterpret_cast<const int*>(this->getParentEdgeAt(type)->getMemoryPtr()->GetPtr());
|
||||||
parameter.assign(ptr, ptr + size);
|
parameter.assign(ptr, ptr + size);
|
||||||
|
|
||||||
if (ellipsisMaskCounter == 0 && size < dstDims.ndims()) {
|
if (ellipsisMaskCounter == 0 && size < nDims) {
|
||||||
for (size_t i = size; i < dstDims.ndims(); i++) parameter.push_back(value);
|
for (size_t i = size; i < nDims; i++) parameter.push_back(value);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -95,7 +95,15 @@ std::vector<StridedSliceSpecificParams> ss_only_test_cases = {
|
|||||||
StridedSliceSpecificParams{ { 2, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 },
|
StridedSliceSpecificParams{ { 2, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 },
|
||||||
{ 0, 1, 1, 1}, { 1, 1, 1, 1}, {}, {}, {} },
|
{ 0, 1, 1, 1}, { 1, 1, 1, 1}, {}, {}, {} },
|
||||||
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||||
{ 1, 0, 1, 1, 1}, { 1, 0, 1, 1, 1}, {}, { 0, 1, 0, 0, 0}, {} },
|
{ 1, 0, 1, 1, 1}, { 1, 0, 1, 1, 1 }, {}, { 0, 1, 0, 0, 0 }, {} },
|
||||||
|
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 0, 3, 0, 0 }, { 2, 3, 4, 3, 6 }, { 1, 1, 1, 1, 1 },
|
||||||
|
{ 1, 1, 0, 1, 1}, { 1, 1, 0, 0, 1 }, {}, { 0, 0, 1, 0, 0 }, {} },
|
||||||
|
StridedSliceSpecificParams{ { 2, 3, 4, 5, 6 }, { 0, 0, 0, 0, 3 }, { 1, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||||
|
{ 0, 1, 1, 1, 0}, { 0, 1, 1, 1, 0 }, {}, { 1, 0, 0, 0, 1 }, {} },
|
||||||
|
StridedSliceSpecificParams{ { 2, 3, 4, 5 }, { 0, 0, 0, 0, 0 }, { 0, 2, 3, 4, 5 }, { 1, 1, 1, 1, 1 },
|
||||||
|
{ 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1 }, { 1, 0, 0, 0, 0 }, {}, {} },
|
||||||
|
StridedSliceSpecificParams{ { 2, 3, 4, 5 }, { 0, 0, 0, 0, 0 }, { 0, 2, 3, 4, 5 }, { 1, 1, 1, 1, 1 },
|
||||||
|
{ 1, 1, 1, 1, 1 }, { 1, 1, 1, 1, 1 }, { 0, 0, 1, 0, 0 }, {}, {} },
|
||||||
StridedSliceSpecificParams{ { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
|
StridedSliceSpecificParams{ { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
|
||||||
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
|
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
|
||||||
StridedSliceSpecificParams{ { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },
|
StridedSliceSpecificParams{ { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },
|
||||||
|
Loading…
Reference in New Issue
Block a user