diff --git a/src/plugins/intel_cpu/src/nodes/strided_slice.cpp b/src/plugins/intel_cpu/src/nodes/strided_slice.cpp index 5b3c9c7ec02..3402b069b38 100644 --- a/src/plugins/intel_cpu/src/nodes/strided_slice.cpp +++ b/src/plugins/intel_cpu/src/nodes/strided_slice.cpp @@ -11,8 +11,6 @@ #include -#define THROW_ERROR IE_THROW() << NameFromType(getType()) << " node with name '" << getName() << "' " - using namespace dnnl; using namespace InferenceEngine; using namespace InferenceEngine::details; @@ -21,29 +19,18 @@ namespace ov { namespace intel_cpu { namespace node { -static inline size_t parallel_init(size_t start, size_t nDims, const VectorDims& dims, VectorDims& indexes) { - for (int j = nDims - 1; j >= 0; j--) { - indexes[j] = start % dims[j]; - start = start / dims[j]; - } - return start; -} - bool StridedSlice::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { if (!ov::is_type(op) && !ov::is_type(op)) { errorMessage = "Only StridedSlice from opset1 and Slice from opset8 operations are supported."; return false; - } - if (!ov::is_type(op->get_input_node_ptr(BEGIN_ID)) || - !ov::is_type(op->get_input_node_shared_ptr(END_ID)) || - (op->get_input_size() > STRIDE_ID && !ov::is_type(op->get_input_node_ptr(STRIDE_ID))) || - (op->get_input_size() > AXES_ID && !ov::is_type(op->get_input_node_ptr(AXES_ID)))) { - // TODO: Support begin, end, stride, axis inputs for dynamic shapes. - errorMessage = "Only Constant 'begin', 'end', 'stride' and 'axis' inputs are supported."; - return false; + if (op->get_input_size() > AXES_ID && !ov::is_type(op->get_input_node_ptr(AXES_ID))) { + // TODO: all required modifications are completed on the node level. More functional tests have to be implemented to resolve the limitation. + errorMessage = "Only constant 'axis' input is supported."; + return false; + } } } catch (...) { return false; @@ -52,60 +39,49 @@ bool StridedSlice::isSupportedOperation(const std::shared_ptr& o } StridedSlice::StridedSlice(const std::shared_ptr& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache) : - Node(op, eng, cache, NgraphShapeInferFactory(op, EMPTY_PORT_MASK)) { + Node(op, eng, cache, NgraphShapeInferFactory(op, PortMask(1, 2, 3, 4))) { std::string errorMessage; if (!isSupportedOperation(op, errorMessage)) { IE_THROW(NotImplemented) << errorMessage; } + errorPrefix = NameFromType(getType()) + " node with name '" + getName() + "' "; - isStridedSliceOp = ov::is_type(op); + attrs.isStridedSliceOp = ov::is_type(op); - if ((isStridedSliceOp && (inputShapes.size() < 3 || inputShapes.size() > 4)) || - (!isStridedSliceOp && (inputShapes.size() < 4 || inputShapes.size() > 5))) { - THROW_ERROR << "has incorrect number of input edges"; + if ((attrs.isStridedSliceOp && (inputShapes.size() < 3 || inputShapes.size() > 4)) || + (!attrs.isStridedSliceOp && (inputShapes.size() < 4 || inputShapes.size() > 5))) { + IE_THROW() << errorPrefix << "has incorrect number of input edges"; } if (outputShapes.size() != 1) { - THROW_ERROR << "has incorrect number of output edges"; + IE_THROW() << errorPrefix << "has incorrect number of output edges"; } - for (size_t i = 0lu; i < op->get_input_size(); i++) { - isConstantInput[i] = ov::is_type(op->inputs()[i].get_node()); - } - - attrs.beginDims = getInputShapeAtPort(BEGIN_ID).getStaticDims(); - attrs.endDims = getInputShapeAtPort(END_ID).getStaticDims(); - if (attrs.beginDims.size() != 1) - THROW_ERROR << "should have begin vector with 1 dimension"; - if (attrs.endDims.size() != 1) - THROW_ERROR << "should have end vector with 1 dimension"; - if (attrs.beginDims[0] != attrs.endDims[0]) - THROW_ERROR << "should have begin vector with size equal to end vector size"; if (inputShapes.size() > STRIDE_ID) { isStrideSpecified = true; - attrs.strideDims = getInputShapeAtPort(STRIDE_ID).getStaticDims(); - if (attrs.strideDims.size() > 1) - THROW_ERROR << "should have stride vector with 1 dimension"; - if (attrs.beginDims[0] != attrs.strideDims[0]) - THROW_ERROR << "should have stride vector with size equal to begin vector size"; } if (inputShapes.size() > AXES_ID) { isAxesSpecified = true; - attrs.axesDims = inputShapes[AXES_ID].getStaticDims(); - if (attrs.axesDims.size() != 1) - THROW_ERROR << "should have axes vector with 1 dimension."; - if (attrs.beginDims[0] != attrs.axesDims[0]) - THROW_ERROR << "should have axes vector with size equal to begin vector size."; } - if (isStridedSliceOp) { + for (size_t i = 0lu; i < op->get_input_size(); i++) { + isConstantInput[i] = ov::is_type(op->get_input_node_shared_ptr(i)); + + if (!isConstantInput[i] && one_of(i, 1, 2, 3)) { + shapeHasDataDependency = true; + } + } + hasConstAttrInputs = !shapeHasDataDependency; + if (isAxesSpecified) + hasConstAttrInputs &= isConstantInput[AXES_ID]; + + const size_t inputRank = getInputShapeAtPort(DATA_ID).getRank(); + const size_t outputRank = getOutputShapeAtPort(0).getRank(); + const size_t nDims = std::max(inputRank, outputRank); + + if (attrs.isStridedSliceOp) { auto ss = ov::as_type_ptr(op); - const size_t inputRank = getInputShapeAtPort(DATA_ID).getRank(); - const size_t outputRank = getOutputShapeAtPort(0).getRank(); - - const size_t nDims = std::max(inputRank, outputRank); - auto createMask = [&](const std::vector &origMask, const int bit = 0, bool needReverse = false) { std::vector mask(origMask.size()); for (size_t i = 0; i < mask.size(); i++) { @@ -145,22 +121,14 @@ StridedSlice::StridedSlice(const std::shared_ptr& op, const dnnl::engi attrs.shrinkAxisMask = std::vector(length, 0); attrs.ellipsisMask = std::vector(length, 0); } -} -void StridedSlice::getSupportedDescriptors() { - const size_t inputRank = getInputShapeAtPort(DATA_ID).getRank(); - const size_t outputRank = getOutputShapeAtPort(0).getRank(); - const size_t nDims = std::max(inputRank, outputRank); - - int ellipsisMaskCounter = 0; - int ellipsisPos1 = -1; - if (isStridedSliceOp) { + if (attrs.isStridedSliceOp) { for (size_t i = 0; i < attrs.ellipsisMask.size(); i++) { - ellipsisMaskCounter += attrs.ellipsisMask[i]; - ellipsisPos1 = attrs.ellipsisMask[i] == 1 && ellipsisPos1 == -1 ? i : ellipsisPos1; + attrs.ellipsisMaskCounter += attrs.ellipsisMask[i]; + attrs.ellipsisPos1 = attrs.ellipsisMask[i] == 1 && attrs.ellipsisPos1 == -1 ? i : attrs.ellipsisPos1; } - if (ellipsisMaskCounter > 1) - THROW_ERROR << "has incorrect 'Ellipsis_mask'. Only one non-zero bit is allowed"; + if (attrs.ellipsisMaskCounter > 1) + IE_THROW() << errorPrefix << "has incorrect 'Ellipsis_mask'. Only one non-zero bit is allowed"; int newAxis = std::accumulate(attrs.newAxisMask.begin(), attrs.newAxisMask.end(), 0); int shrinkAxis = std::accumulate(attrs.shrinkAxisMask.begin(), attrs.shrinkAxisMask.end(), 0); @@ -169,30 +137,32 @@ void StridedSlice::getSupportedDescriptors() { attrs.equalDims = true; } - auto fillingInParameters = [&](std::vector ¶meter, const size_t type, const size_t size, const int value) { - const auto constNode = std::dynamic_pointer_cast(getParentEdgesAtPort(type)[0]->getParent()); - if (!constNode) { - THROW_ERROR << "can't cast node on " << type << " port to Input"; - } - auto blob = constNode->getMemoryPtr(); - if (blob->GetDataType() != dnnl::memory::data_type::s32) - THROW_ERROR << "supports only parameters input with precision I32"; - const int *ptr = static_cast(blob->GetPtr()); - parameter.assign(ptr, ptr + size); + auto fillingInParameters = [&](std::vector ¶meter, const size_t type, const int value) { + if (!isConstantInput[type]) + return; - if (type != AXES_ID && ellipsisMaskCounter == 0 && size < nDims) { + const auto constNode = ov::as_type_ptr(op->get_input_node_shared_ptr(type)); + parameter = constNode->cast_vector(); + + auto size = constNode->get_shape()[0]; + if (type != AXES_ID && attrs.ellipsisMaskCounter == 0 && size < nDims) { for (size_t i = size; i < nDims; i++) parameter.push_back(value); } }; - if (attrs.beginDims.size()) - fillingInParameters(attrs.begin, BEGIN_ID, attrs.beginDims[0], 0); - if (attrs.endDims.size()) - fillingInParameters(attrs.end, END_ID, attrs.endDims[0], 0); - if (attrs.strideDims.size()) - fillingInParameters(attrs.stride, STRIDE_ID, attrs.strideDims[0], 1); - if (attrs.axesDims.size()) { - fillingInParameters(attrs.axes, AXES_ID, attrs.axesDims[0], 0); + fillingInParameters(attrs.begin, BEGIN_ID, 0); + fillingInParameters(attrs.end, END_ID, 0); + if (inputShapes.size() > STRIDE_ID) + fillingInParameters(attrs.stride, STRIDE_ID, 1); + if (inputShapes.size() > AXES_ID) + fillingInParameters(attrs.axes, AXES_ID, 0); +} + +void StridedSlice::getSupportedDescriptors() { +} + +static void addHiddenDims(StridedSlice::StridedSliceAttributes& attrs, const size_t inputRank, const size_t outputRank, bool withAxis) { + if (withAxis) { std::vector beginTmp(outputRank, 0); std::vector endTmp(outputRank, -1); std::vector strideTmp(outputRank, 1); @@ -211,36 +181,32 @@ void StridedSlice::getSupportedDescriptors() { attrs.stride = strideTmp; } - if (inputRank > 3 && attrs.equalDims && ellipsisMaskCounter == 1) - addHiddenDims(inputRank, ellipsisPos1); -} + if (inputRank > 3 && attrs.equalDims && attrs.ellipsisMaskCounter == 1) { + // all masks and input parameters are for planar layouts. So if we use blocked or per channel layout and + // there is ellipsis should to add default values in hidden dimensions to know real order of mask or parameter values + size_t afterDims = attrs.begin.size() - attrs.ellipsisPos1 - 1; + size_t ellipsisPos2 = inputRank - afterDims - 1; + auto addHiddenDims = [&](std::vector& data, const int bit = 0) { + std::vector temp; + for (size_t i = 0; i < attrs.ellipsisPos1; i++) + temp.push_back(data[i]); + for (size_t i = attrs.ellipsisPos1; i < ellipsisPos2 + 1; i++) + temp.push_back(bit); + for (size_t i = 1; i < inputRank - ellipsisPos2; i++) + temp.push_back(data[i + attrs.ellipsisPos1]); + data = temp; + }; -void StridedSlice::addHiddenDims(const size_t nSrcDims, int ellipsisPos1) { - // all masks and input parameters are for planar layouts. So if we use blocked or per channel layout and - // there is ellipsis should to add default values in hidden dimensions to know real order of mask or parameter values - size_t afterDims = attrs.begin.size() - ellipsisPos1 - 1; - size_t ellipsisPos2 = nSrcDims - afterDims - 1; - - auto addHiddenDims = [&](std::vector& data, const int bit = 0) { - std::vector temp; - for (size_t i = 0; i < ellipsisPos1; i++) - temp.push_back(data[i]); - for (size_t i = ellipsisPos1; i < ellipsisPos2 + 1; i++) - temp.push_back(bit); - for (size_t i = 1; i < nSrcDims - ellipsisPos2; i++) - temp.push_back(data[i + ellipsisPos1]); - data = temp; - }; - - addHiddenDims(attrs.begin); - addHiddenDims(attrs.end); - addHiddenDims(attrs.stride, 1); - addHiddenDims(attrs.beginMask); - addHiddenDims(attrs.endMask); - addHiddenDims(attrs.ellipsisMask); - addHiddenDims(attrs.newAxisMask); - addHiddenDims(attrs.shrinkAxisMask); + addHiddenDims(attrs.begin); + addHiddenDims(attrs.end); + addHiddenDims(attrs.stride, 1); + addHiddenDims(attrs.beginMask); + addHiddenDims(attrs.endMask); + addHiddenDims(attrs.ellipsisMask); + addHiddenDims(attrs.newAxisMask); + addHiddenDims(attrs.shrinkAxisMask); + } } void StridedSlice::initSupportedPrimitiveDescriptors() { @@ -274,20 +240,27 @@ void StridedSlice::initSupportedPrimitiveDescriptors() { std::vector supportedTypes; if (nDims > 2 && attrs.equalDims) { - auto canUseBlocked = [&](const size_t blockSize) { + auto canUseBlocked = [&](StridedSliceAttributes& tmpAttrs, const size_t blockSize) { + if (!isConstantInput[BEGIN_ID]) + return false; const auto& srcDims = getInputShapeAtPort(DATA_ID).getDims(); if (srcDims[1] == Shape::UNDEFINED_DIM) return false; - auto channelBeginNormalized = attrs.begin[1] > 0 ? attrs.begin[1] : attrs.begin[1] + static_cast(srcDims[1]); - return srcDims[1] % blockSize == 0 && abs(attrs.stride[1]) == 1 && - (channelBeginNormalized > srcDims[1] || channelBeginNormalized % blockSize == 0 || channelBeginNormalized < 0 || attrs.beginMask[1] == 0); + auto channelBeginNormalized = tmpAttrs.begin[1] > 0 ? tmpAttrs.begin[1] : tmpAttrs.begin[1] + static_cast(srcDims[1]); + return srcDims[1] % blockSize == 0 && abs(tmpAttrs.stride[1]) == 1 && + (channelBeginNormalized > srcDims[1] || channelBeginNormalized % blockSize == 0 || channelBeginNormalized < 0 || tmpAttrs.beginMask[1] == 0); }; supportedTypes.push_back(LayoutType::nspc); - if (canUseBlocked(8lu)) - supportedTypes.push_back(LayoutType::nCsp8c); - if (canUseBlocked(16lu)) - supportedTypes.push_back(LayoutType::nCsp16c); + + if (hasConstAttrInputs) { + auto tmpAttrs = attrs; + addHiddenDims(tmpAttrs, getInputShapeAtPort(DATA_ID).getRank(), getOutputShapeAtPort(0).getRank(), isAxesSpecified); + if (canUseBlocked(tmpAttrs, 8lu)) + supportedTypes.push_back(LayoutType::nCsp8c); + if (canUseBlocked(tmpAttrs, 16lu)) + supportedTypes.push_back(LayoutType::nCsp16c); + } } supportedTypes.push_back(LayoutType::ncsp); auto creators = BlockedDescCreator::getCommonCreators(); @@ -312,51 +285,89 @@ bool StridedSlice::isExecutable() const { } void StridedSlice::createPrimitive() { - if (!isExecutable()) { - return; - } - auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); - auto& srcMemPtr = getParentEdgeAt(DATA_ID)->getMemoryPtr(); - if (!dstMemPtr || !dstMemPtr->isAllocated()) - THROW_ERROR << "has not allocated destination memory."; - if (!srcMemPtr || !srcMemPtr->isAllocated()) - THROW_ERROR << "has not allocated input memory."; - if (getSelectedPrimitiveDescriptor() == nullptr) - THROW_ERROR << "has unidentified preferable primitive descriptor."; - - if (!srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) - orderParametersByLayouts(srcMemPtr); - - if (inputShapesDefined()) { - prepareParams(); + if (inputShapesDefined() && isExecutable() && !shapeHasDataDependency) { + if (needPrepareParams()) { + prepareParams(); + } updateLastInputDims(); } } -void StridedSlice::orderParametersByLayouts(const MemoryPtr& srcMemPtr) { +bool StridedSlice::needPrepareParams() const { + return true; +} + +void StridedSlice::prepareParams() { + updateLastInputDims(); + + if (srcMemory.empty()) { + for (int i = 0; i < getOriginalInputsNumber(); i++) { + srcMemory.push_back(getParentEdgeAt(i)->getMemoryPtr()); + } + } + if (dstMemory.empty()) { + for (int i = 0; i < getOriginalOutputsNumber(); i++) { + dstMemory.push_back(getChildEdgeAt(i)->getMemoryPtr()); + } + } + + execPtr = std::make_shared(attrs, srcMemory, dstMemory, errorPrefix); +} + +bool StridedSlice::needShapeInfer() const { + return Node::inputShapesModified() || shapeHasDataDependency; +} + +void StridedSlice::execute(dnnl::stream strm) { + if (!execPtr) + IE_THROW() << errorPrefix << "doesn't have compiled executor!"; + + execPtr->exec(srcMemory, dstMemory); +} + +void StridedSlice::executeDynamicImpl(dnnl::stream strm) { + execute(strm); +} + +bool StridedSlice::created() const { + return getType() == Type::StridedSlice; +} + +StridedSlice::StridedSliceCommonExecutor::StridedSliceCommonExecutor(const StridedSliceAttributes& attrs, + const std::vector& srcMemory, + const std::vector& dstMemory, + const std::string& errorPrefix) + : StridedSliceExecutor(attrs, srcMemory, dstMemory, errorPrefix) { + paramsInitialization(attrs, srcMemory, dstMemory); + dimsNormalization(); + dimsGluing(); + indicesCalculation(); +} + +void StridedSlice::StridedSliceCommonExecutor::orderParametersByLayouts(const BlockedMemoryDescCPtr& blockedMemoryDesc) { size_t blk = 1; bool isBlockedLayout = false; - if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nCsp16c)) { + if (blockedMemoryDesc->hasLayoutType(LayoutType::nCsp16c)) { isBlockedLayout = true; blk = 16; - } else if (srcMemPtr->getDesc().hasLayoutType(LayoutType::nCsp8c)) { + } else if (blockedMemoryDesc->hasLayoutType(LayoutType::nCsp8c)) { isBlockedLayout = true; blk = 8; } - const bool isPerChannelLayout = srcMemPtr->getDesc().hasLayoutType(LayoutType::nspc); - auto srcOrder = srcMemPtr->GetDescWithType()->getOrder(); + const bool isPerChannelLayout = blockedMemoryDesc->hasLayoutType(LayoutType::nspc); + auto srcOrder = blockedMemoryDesc->getOrder(); if (isBlockedLayout) { - attrs.begin[1] = attrs.begin[1] / blk; - attrs.end[1] = ceil(attrs.end[1] / static_cast(blk)); - attrs.begin.push_back(0); - attrs.end.push_back(0); - attrs.stride.push_back(1); - attrs.beginMask.push_back(0); - attrs.endMask.push_back(0); - attrs.ellipsisMask.push_back(0); - attrs.newAxisMask.push_back(0); - attrs.shrinkAxisMask.push_back(0); + params.attrs.begin[1] = params.attrs.begin[1] / blk; + params.attrs.end[1] = ceil(params.attrs.end[1] / static_cast(blk)); + params.attrs.begin.push_back(0); + params.attrs.end.push_back(0); + params.attrs.stride.push_back(1); + params.attrs.beginMask.push_back(0); + params.attrs.endMask.push_back(0); + params.attrs.ellipsisMask.push_back(0); + params.attrs.newAxisMask.push_back(0); + params.attrs.shrinkAxisMask.push_back(0); } else if (isPerChannelLayout) { auto sortByOrder = [&](std::vector& data) { std::vector temp(srcOrder.size()); @@ -365,40 +376,86 @@ void StridedSlice::orderParametersByLayouts(const MemoryPtr& srcMemPtr) { data = temp; }; - sortByOrder(attrs.begin); - sortByOrder(attrs.end); - sortByOrder(attrs.stride); - sortByOrder(attrs.beginMask); - sortByOrder(attrs.endMask); - if (isStridedSliceOp) { - sortByOrder(attrs.ellipsisMask); - sortByOrder(attrs.newAxisMask); - sortByOrder(attrs.shrinkAxisMask); + sortByOrder(params.attrs.begin); + sortByOrder(params.attrs.end); + sortByOrder(params.attrs.stride); + sortByOrder(params.attrs.beginMask); + sortByOrder(params.attrs.endMask); + if (params.attrs.isStridedSliceOp) { + sortByOrder(params.attrs.ellipsisMask); + sortByOrder(params.attrs.newAxisMask); + sortByOrder(params.attrs.shrinkAxisMask); } } } -void StridedSlice::prepareParams() { - execPtr = std::make_shared(attrs, - getParentEdgeAt(0)->getMemoryPtr()->GetDescWithType()->getBlockDims(), - getChildEdgeAt(0)->getMemoryPtr()->GetDescWithType()->getBlockDims()); -} +void StridedSlice::StridedSliceCommonExecutor::paramsInitialization(const StridedSliceAttributes& attrs, + const std::vector& srcMemory, + const std::vector& dstMemory) { + const auto srcBlockedMemoryDesc = srcMemory[0]->GetDescWithType(); + const auto dstBlockedMemoryDesc = dstMemory[0]->GetDescWithType(); -StridedSlice::StridedSliceExecutor::StridedSliceExecutor(const StridedSliceAttributes& attrs, - const VectorDims& srcBlockedDims, - const VectorDims& dstBlockedDims) { - StridedSliceParams params; - params.srcBlockedDims = srcBlockedDims; - params.dstBlockedDims = dstBlockedDims; params.attrs = attrs; + params.srcBlockedDims = srcBlockedMemoryDesc->getBlockDims(); + params.srcOrder = srcBlockedMemoryDesc->getOrder(); + params.dstBlockedDims = dstBlockedMemoryDesc->getBlockDims(); - size_t realNDims = params.dstBlockedDims.size(); - dimsNormalization(params); - dimsGluing(params, realNDims); - indicesCalculation(params); + const size_t inputRank = srcMemory[0]->GetShape().getRank(); + const size_t outputRank = dstMemory[0]->GetShape().getRank(); + const size_t nDims = std::max(inputRank, outputRank); + + auto fillingInParameters = [&](std::vector ¶meter, const size_t type, const size_t size, const int value) { + const int *ptr = reinterpret_cast(srcMemory[type]->GetPtr()); + parameter.assign(ptr, ptr + size); + + if (type != AXES_ID && params.attrs.ellipsisMaskCounter == 0 && size < nDims) { + for (size_t i = size; i < nDims; i++) parameter.push_back(value); + } + }; + + params.attrs.beginDims = srcMemory[BEGIN_ID]->GetShape().getStaticDims(); + params.attrs.endDims = srcMemory[END_ID]->GetShape().getStaticDims(); + if (params.attrs.beginDims.size() != 1) + IE_THROW() << errorPrefix << "should have begin vector with 1 dimension"; + if (params.attrs.endDims.size() != 1) + IE_THROW() << errorPrefix << "should have end vector with 1 dimension"; + if (params.attrs.beginDims[0] != params.attrs.endDims[0]) + IE_THROW() << errorPrefix << "should have begin vector with size equal to end vector size"; + + if (params.attrs.begin.empty()) + fillingInParameters(params.attrs.begin, BEGIN_ID, params.attrs.beginDims[0], 0); + if (params.attrs.end.empty()) + fillingInParameters(params.attrs.end, END_ID, params.attrs.endDims[0], 0); + + if (srcMemory.size() > STRIDE_ID) { + params.attrs.strideDims = srcMemory[STRIDE_ID]->GetShape().getStaticDims(); + if (params.attrs.strideDims.size() > 1) + IE_THROW() << errorPrefix << "should have stride vector with 1 dimension"; + if (params.attrs.beginDims[0] != params.attrs.strideDims[0]) + IE_THROW() << errorPrefix << "should have stride vector with size equal to begin vector size"; + + if (params.attrs.stride.empty()) + fillingInParameters(params.attrs.stride, STRIDE_ID, params.attrs.strideDims[0], 1); + } + + if (srcMemory.size() > AXES_ID) { + params.attrs.axesDims = srcMemory[AXES_ID]->GetShape().getStaticDims(); + if (params.attrs.axesDims.size() != 1) + IE_THROW() << errorPrefix << "should have axes vector with 1 dimension."; + if (params.attrs.beginDims[0] != params.attrs.axesDims[0]) + IE_THROW() << errorPrefix << "should have axes vector with size equal to begin vector size."; + + if (params.attrs.axes.empty()) + fillingInParameters(params.attrs.axes, AXES_ID, params.attrs.axesDims[0], 0); + } + + addHiddenDims(params.attrs, inputRank, outputRank, srcMemory.size() > AXES_ID); + + if (!srcBlockedMemoryDesc->hasLayoutType(LayoutType::ncsp)) + orderParametersByLayouts(srcBlockedMemoryDesc); } -void StridedSlice::StridedSliceExecutor::dimsNormalization(StridedSliceParams& params) { +void StridedSlice::StridedSliceCommonExecutor::dimsNormalization() { // creating new src and dst dimensions and parameters of the same size using masks // // example 1: before srcDims = [5, 6, 8, 3, 2], begin = [1, 0], end = [4, 0], stride = [1, 1] @@ -507,11 +564,13 @@ void StridedSlice::StridedSliceExecutor::dimsNormalization(StridedSliceParams& p } } -void StridedSlice::StridedSliceExecutor::dimsGluing(StridedSliceParams& params, const size_t realNDims) { +void StridedSlice::StridedSliceCommonExecutor::dimsGluing() { // gluing of dimensions if there aren't begin, end and stride != 1 on this axis // example: before gluing srcDims = [5, 6, 8, 3, 2], begin = [1, 0, 0, 0, 0], stride = [1, 1, 2, 1, 1], dstDims = [4, 6, 4, 3, 2] // after gluing srcDims = [30, 8, 6], begin = [6, 0, 0], stride = [1, 2, 1], dstDims = [24, 4, 6] + size_t realNDims = params.dstBlockedDims.size(); + std::pair secondDim = { 0, params.attrs.begin.size() }; VectorDims indexes(1, 0); for (int idx = 0; idx < params.attrs.begin.size(); idx++) { @@ -601,7 +660,15 @@ void StridedSlice::StridedSliceExecutor::dimsGluing(StridedSliceParams& params, } } -void StridedSlice::StridedSliceExecutor::indicesCalculation(const StridedSliceParams& params) { +static inline size_t parallel_init(size_t start, size_t nDims, const VectorDims& dims, VectorDims& indexes) { + for (int j = nDims - 1; j >= 0; j--) { + indexes[j] = start % dims[j]; + start = start / dims[j]; + } + return start; +} + +void StridedSlice::StridedSliceCommonExecutor::indicesCalculation() { // indices calculation before execution for the best performance srcIndices.resize(workAmount, 0); dstIndices.resize(workAmount, 0); @@ -611,7 +678,7 @@ void StridedSlice::StridedSliceExecutor::indicesCalculation(const StridedSlicePa nThreads = nthr > workAmount ? workAmount : nthr; if (params.isOptimized) { - indicesCalculationForOptimized(params); + indicesCalculationForOptimized(); return; } @@ -651,7 +718,7 @@ void StridedSlice::StridedSliceExecutor::indicesCalculation(const StridedSlicePa }); } -void StridedSlice::StridedSliceExecutor::indicesCalculationForOptimized(const StridedSliceParams& params) { +void StridedSlice::StridedSliceCommonExecutor::indicesCalculationForOptimized() { const size_t dstIdx0 = params.dstStrides[0] * params.attrs.dataSize; const size_t dstIdx1 = params.dstStrides[1] * params.attrs.dataSize; const size_t srcIdx0 = params.attrs.stride[0] * params.srcStrides[0] * params.attrs.dataSize; @@ -670,7 +737,9 @@ void StridedSlice::StridedSliceExecutor::indicesCalculationForOptimized(const St } } -void StridedSlice::StridedSliceExecutor::exec(const uint8_t* srcData, uint8_t* dstData) { +void StridedSlice::StridedSliceCommonExecutor::exec(const std::vector& srcMemory, const std::vector& dstMemory) { + const uint8_t* srcData = reinterpret_cast(srcMemory[0]->GetPtr()); + uint8_t* dstData = reinterpret_cast(dstMemory[0]->GetPtr()); const uint8_t* srcShiftedData = srcData + srcShift; parallel_nt(nThreads, [&](const int ithr, const int nthr) { size_t start = 0, end = 0; @@ -681,22 +750,6 @@ void StridedSlice::StridedSliceExecutor::exec(const uint8_t* srcData, uint8_t* d }); } -void StridedSlice::execute(dnnl::stream strm) { - if (!execPtr) - THROW_ERROR << "doesn't have compiled executor!"; - const uint8_t* srcData = reinterpret_cast(getParentEdgeAt(0)->getMemory().GetPtr()); - uint8_t* dstData = reinterpret_cast(getChildEdgeAt(0)->getMemory().GetPtr()); - execPtr->exec(srcData, dstData); -} - -void StridedSlice::executeDynamicImpl(dnnl::stream strm) { - execute(strm); -} - -bool StridedSlice::created() const { - return getType() == Type::StridedSlice; -} - } // namespace node } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/strided_slice.h b/src/plugins/intel_cpu/src/nodes/strided_slice.h index b0e9d73ecf4..6ae6210d1cc 100644 --- a/src/plugins/intel_cpu/src/nodes/strided_slice.h +++ b/src/plugins/intel_cpu/src/nodes/strided_slice.h @@ -27,14 +27,7 @@ public: } bool isExecutable() const override; - -protected: - void prepareParams() override; - void executeDynamicImpl(dnnl::stream strm) override; - -private: - void addHiddenDims(const size_t nSrcDims, int ellipsisPos1); - void orderParametersByLayouts(const MemoryPtr& srcMemPtr); + bool needShapeInfer() const override; struct StridedSliceAttributes { std::vector begin; @@ -55,17 +48,46 @@ private: bool equalDims = false; size_t dataSize = 1lu; + int ellipsisMaskCounter = 0; + bool isStridedSliceOp = true; + int ellipsisPos1 = -1; + bool hasConstInputs = false; } attrs; - struct StridedSliceExecutor { - StridedSliceExecutor(const StridedSliceAttributes& attrs, const VectorDims& srcBlockedDims, const VectorDims& dstBlockedDims); - void exec(const uint8_t* srcData, uint8_t* dstData); - ~StridedSliceExecutor() = default; +protected: + bool needPrepareParams() const override; + void prepareParams() override; + void executeDynamicImpl(dnnl::stream strm) override; + +private: + class StridedSliceExecutor { + public: + StridedSliceExecutor(const StridedSliceAttributes& attrs, + const std::vector& srcMemory, + const std::vector& dstMemory, + const std::string& errorPrefix) : errorPrefix(errorPrefix) {} + virtual void exec(const std::vector& srcMemory, + const std::vector& dstMemory) = 0; + virtual ~StridedSliceExecutor() = default; + + protected: + const std::string errorPrefix; + }; + + class StridedSliceCommonExecutor : public StridedSliceExecutor { + public: + StridedSliceCommonExecutor(const StridedSliceAttributes& attrs, + const std::vector& srcMemory, + const std::vector& dstMemory, + const std::string& errorPrefix); + void exec(const std::vector& srcMemory, + const std::vector& dstMemory) override; private: struct StridedSliceParams { StridedSliceAttributes attrs; VectorDims srcBlockedDims; + VectorDims srcOrder; VectorDims dstBlockedDims; VectorDims srcStrides; VectorDims dstStrides; @@ -73,11 +95,16 @@ private: bool isOptimized = false; }; - void dimsNormalization(StridedSliceParams& params); - void dimsGluing(StridedSliceParams& params, const size_t realNDims); - void indicesCalculation(const StridedSliceParams& params); - void indicesCalculationForOptimized(const StridedSliceParams& params); + void paramsInitialization(const StridedSliceAttributes& attrs, + const std::vector& srcMemory, + const std::vector& dstMemory); + void dimsNormalization(); + void dimsGluing(); + void indicesCalculation(); + void indicesCalculationForOptimized(); + void orderParametersByLayouts(const BlockedMemoryDescCPtr& blockedMemoryDesc); + StridedSliceParams params; VectorDims srcIndices; VectorDims dstIndices; size_t nThreads = 0lu; @@ -88,7 +115,6 @@ private: using executorPtr = std::shared_ptr; executorPtr execPtr = nullptr; - bool isStridedSliceOp = true; bool isStrideSpecified = false; bool isAxesSpecified = false; @@ -99,6 +125,13 @@ private: static constexpr size_t AXES_ID = 4; bool isConstantInput[AXES_ID + 1] = {false}; + bool shapeHasDataDependency = false; + bool hasConstAttrInputs = true; + + std::vector srcMemory; + std::vector dstMemory; + + std::string errorPrefix; }; } // namespace node diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/strided_slice.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/strided_slice.cpp index 749b917ded7..c1f3ef47da8 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/strided_slice.cpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/strided_slice.cpp @@ -6,7 +6,7 @@ #include "ngraph_functions/builders.hpp" #include "test_utils/cpu_test_utils.hpp" #include "shared_test_classes/base/ov_subgraph.hpp" - +#include using namespace InferenceEngine; using namespace CPUTestUtils; @@ -30,6 +30,7 @@ struct StridedSliceParams { typedef std::tuple< InputShape, // Input shapes StridedSliceParams, + ngraph::helpers::InputLayerType, // Secondary input types ElementType, // Element type CPUSpecificParams> StridedSliceLayerCPUTestParamSet; @@ -39,9 +40,10 @@ public: static std::string getTestCaseName(testing::TestParamInfo obj) { InputShape shapes; StridedSliceParams params; - ElementType elementType; + ngraph::helpers::InputLayerType secondaryInputType; + ElementType dataType; CPUSpecificParams cpuParams; - std::tie(shapes, params, elementType, cpuParams) = obj.param; + std::tie(shapes, params, secondaryInputType, dataType, cpuParams) = obj.param; std::ostringstream results; results << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_"; @@ -49,7 +51,8 @@ public: for (const auto& item : shapes.second) { results << CommonTestUtils::vec2str(item) << "_"; } - results << "netPRC=" << elementType << "_"; + results << "secondaryInputType=" << secondaryInputType << "_"; + results << "netPRC=" << dataType << "_"; results << "begin=" << CommonTestUtils::vec2str(params.begin) << "_"; results << "end=" << CommonTestUtils::vec2str(params.end) << "_"; results << "stride=" << CommonTestUtils::vec2str(params.strides) << "_"; @@ -62,23 +65,67 @@ public: return results.str(); } + protected: + void generate_inputs(const std::vector& targetInputStaticShapes) override { + std::vector inputValues = {ssParams.begin.data(), ssParams.end.data(), ssParams.strides.data()}; + + inputs.clear(); + const auto& funcInputs = function->inputs(); + for (int i = 0; i < funcInputs.size(); ++i) { + const auto& funcInput = funcInputs[i]; + ov::Tensor tensor; + if (i == 0) { + tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 10, 1, 1); + } else { + tensor = ov::Tensor{ov::element::i64, targetInputStaticShapes[i], inputValues[i-1]}; + } + inputs.insert({funcInput.get_node_shared_ptr(), tensor}); + } + } + void SetUp() override { InputShape shapes; - StridedSliceParams ssParams; + ngraph::helpers::InputLayerType secondaryInputType; CPUSpecificParams cpuParams; - std::tie(shapes, ssParams, inType, cpuParams) = this->GetParam(); + ov::element::Type dataType; + std::tie(shapes, ssParams, secondaryInputType, dataType, cpuParams) = this->GetParam(); std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; - selectedType = makeSelectedTypeStr("ref", inType); + selectedType = makeSelectedTypeStr("ref", dataType); targetDevice = CommonTestUtils::DEVICE_CPU; - init_input_shapes({shapes}); + std::vector input_shapes = {shapes}; - auto params = ngraph::builder::makeDynamicParams(inType, inputDynamicShapes); - auto ss = ngraph::builder::makeStridedSlice(params[0], ssParams.begin, ssParams.end, ssParams.strides, inType, ssParams.beginMask, - ssParams.endMask, ssParams.newAxisMask, ssParams.shrinkAxisMask, ssParams.ellipsisAxisMask); + init_input_shapes({input_shapes}); + for (auto& targetShapes : targetStaticShapes) { + targetShapes.push_back({ssParams.begin.size()}); + targetShapes.push_back({ssParams.end.size()}); + targetShapes.push_back({ssParams.strides.size()}); + } + + auto params = ngraph::builder::makeDynamicParams(dataType, inputDynamicShapes); + std::shared_ptr ss; + if (secondaryInputType == ngraph::helpers::InputLayerType::PARAMETER) { + ov::Shape inShape = {ssParams.begin.size()}; + + auto beginNode = std::make_shared(ov::element::i64, inShape); + auto endNode = std::make_shared(ov::element::i64, inShape); + auto strideNode = std::make_shared(ov::element::i64, inShape); + + params.push_back(std::dynamic_pointer_cast(beginNode)); + params.push_back(std::dynamic_pointer_cast(endNode)); + params.push_back(std::dynamic_pointer_cast(strideNode)); + + ss = ngraph::builder::makeStridedSlice(params[0], beginNode, endNode, strideNode, inType, ssParams.beginMask, + ssParams.endMask, ssParams.newAxisMask, ssParams.shrinkAxisMask, ssParams.ellipsisAxisMask); + } else { + ss = ngraph::builder::makeStridedSlice(params[0], ssParams.begin, ssParams.end, ssParams.strides, inType, ssParams.beginMask, + ssParams.endMask, ssParams.newAxisMask, ssParams.shrinkAxisMask, ssParams.ellipsisAxisMask); + } function = makeNgraphFunction(inType, params, ss, "StridedSlice"); } + + StridedSliceParams ssParams; }; TEST_P(StridedSliceLayerCPUTest, CompareWithRefs) { @@ -106,6 +153,11 @@ const std::vector inputPrecisions = { ElementType::i8 }; +const std::vector inputLayerTypes = { + ngraph::helpers::InputLayerType::CONSTANT, + ngraph::helpers::InputLayerType::PARAMETER +}; + const std::vector inputShapesDynamic2D = { {{-1, -1}, {{32, 20}, {16, 16}, {24, 16}}}, @@ -118,16 +170,18 @@ const std::vector inputShapesDynamic2D = { }; const std::vector paramsPlain2D = { - StridedSliceParams{ { 0, 10 }, { 16, 16 }, { 1, 1 }, { 0, 0 }, { 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 2, 5 }, { 16, 8 }, { 1, 1 }, { 0, 0 }, { 0, 0 }, { }, { }, { } }, - StridedSliceParams{ { 2, 5 }, { 16, 16 }, { 1, 2 }, { 0, 1 }, { 1, 0 }, { }, { }, { } }, - StridedSliceParams{ { 0, 0 }, { 16, 16 }, { 2, 1 }, { 0, 0 }, { 1, 0 }, { }, { }, { } }, + StridedSliceParams{ { -10, -11 }, { -2, -3 }, { 1, 1 }, { 0, 0 }, { 0, 0 }, { }, { }, { } }, + StridedSliceParams{ { 2, 44 }, { 55, -2 }, { 2, 3 }, { 0, 1 }, { 0, 0 }, { }, { }, { } }, + StridedSliceParams{ { 2, -7 }, { 1, -2 }, { 2, 3 }, { 1, 0 }, { 1, 0 }, { }, { }, { } }, + StridedSliceParams{ { 2 }, { 22 }, { 2 }, { 0 }, { 0 }, { }, { }, { } }, }; INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Plain_Static_2D, StridedSliceLayerCPUTest, ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation({{32, 20}})), ::testing::ValuesIn(paramsPlain2D), + ::testing::ValuesIn(inputLayerTypes), ::testing::ValuesIn(inputPrecisions), ::testing::Values(emptyCPUSpec)), StridedSliceLayerCPUTest::getTestCaseName); @@ -136,22 +190,20 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Plain_Dynamic_2D, StridedSliceLay ::testing::Combine( ::testing::ValuesIn(inputShapesDynamic2D), ::testing::ValuesIn(paramsPlain2D), + ::testing::ValuesIn(inputLayerTypes), ::testing::ValuesIn(inputPrecisions), ::testing::Values(emptyCPUSpec)), StridedSliceLayerCPUTest::getTestCaseName); const std::vector testCasesCommon4D = { StridedSliceParams{ { 0, 2, 5, 4 }, { 1, 4, 28, 27 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, - StridedSliceParams{ { 0, 1, 0, 0 }, { 1, 3, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 10, 0 }, { 1, 3, 20, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 1, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 20, 20 }, { 1, 5, 25, 26 }, { 1, 1, 1, 2 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 0, 20 }, { 1, 2, 30, 30 }, { 1, 1, 2, 1 }, { 0, 0, 0, 1 }, { 0, 1, 0, 1 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 2, 10 }, { 1, 3, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 1, 1 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 1, 0, 10 }, { 1, 5, 32, 30 }, { 1, 1, 1, 1 }, { 0, 1, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, - StridedSliceParams{ { 0, 1, 2, 10 }, { 1, 5, 32, 18 }, { 1, 1, 1, 2 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 2, 10 }, { 1, 8, 32, 18 }, { 1, 2, 1, 2 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 10 }, { 0, 32, 18 }, { 1, 1, 1 }, { 1, 1, 0 }, { 1, 1, 0 }, { }, { }, { 1, 0, 0 } }, - StridedSliceParams{ { 0, 0, 10 }, { 1, 0, 20 }, { 1, 1, 1 }, { 1, 1, 0 }, { 0, 1, 1 }, { }, { }, { 0, 1, 0 } }, StridedSliceParams{ { 0, 4, 10 }, { 1, 8, 0 }, { 1, 1, 1 }, { 1, 0, 1 }, { 1, 1, 1 }, { }, { }, { 0, 0, 1 } } }; @@ -179,6 +231,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Static_4D, StridedSliceLay ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic4D)), ::testing::ValuesIn(testCasesCommon4D), + ::testing::ValuesIn(inputLayerTypes), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsCommon4D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -187,6 +240,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Dynamic_4D, StridedSliceLa ::testing::Combine( ::testing::ValuesIn(inputShapesDynamic4D), ::testing::ValuesIn(testCasesCommon4D), + ::testing::ValuesIn(inputLayerTypes), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsCommon4D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -201,15 +255,12 @@ const std::vector testCasesBlocked4DSubset1 = { const std::vector testCasesBlocked4DSubset2 = { StridedSliceParams{ { 0, 0, 5, 4 }, { 1, 16, 28, 27 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 16, 0, 0 }, { 1, 32, 10, 10 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, - StridedSliceParams{ { 0, 0, 10, 0 }, { 1, 16, 20, 10 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 1 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 20, 20 }, { 1, 32, 25, 25 }, { 1, 1, 1, 1 }, { 0, 1, 0, 0 }, { 0, 1, 0, 0 }, { }, { }, { } }, - StridedSliceParams{ { 0, 16, 0, 20 }, { 1, 32, 32, 30 }, { 1, 1, 1, 2 }, { 1, 0, 1, 0 }, { 1, 0, 1, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 16, 2, 10 }, { 1, 32, 32, 20 }, { 1, 1, 2, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 16, 0, 0 }, { 2, 64, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 32, 0, 0 }, { 2, 50, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 0, 0, 0 }, { 2, 12, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, -16, 0, 10 }, { 2, 100, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, - StridedSliceParams{ { 0, -16, 0, 0 }, { 2, -4, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, -32, 0, 0 }, { 2, -12, 32, 20 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } }, StridedSliceParams{ { 0, 10 }, { 0, 20 }, { 1, 1 }, { 1, 0 }, { 1, 0 }, { }, { }, { 1, 0 } }, StridedSliceParams{ { 0, 16, 0 }, { 2, 32, 0 }, { 1, 1, 1 }, { 1, 0, 1 }, { 1, 1, 1 }, { }, { }, { 0, 0, 1 } }, @@ -244,10 +295,15 @@ const std::vector CPUParamsBlocked4D = { cpuParams_nChw8c, }; +const std::vector inputLayerTypesBlocked = { + ngraph::helpers::InputLayerType::CONSTANT, +}; + INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Static_4D_Subset1, StridedSliceLayerCPUTest, ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesBlockedStatic4DSubset1)), ::testing::ValuesIn(testCasesBlocked4DSubset1), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked4D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -256,6 +312,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Dynamic_4D_Subset1, Stride ::testing::Combine( ::testing::ValuesIn(inputShapesBlockedDynamic4DSubset1), ::testing::ValuesIn(testCasesBlocked4DSubset1), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked4D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -264,6 +321,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Static_4D_Subset2, Strided ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesBlockedStatic4DSubset2)), ::testing::ValuesIn(testCasesBlocked4DSubset2), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked4D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -272,6 +330,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Dynamic_4D_Subset2, Stride ::testing::Combine( ::testing::ValuesIn(inputShapesBlockedDynamic4DSubset2), ::testing::ValuesIn(testCasesBlocked4DSubset2), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked4D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -311,6 +370,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Static_5D, StridedSliceLay ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic5D)), ::testing::ValuesIn(testCasesCommon5D), + ::testing::ValuesIn(inputLayerTypes), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsCommon5D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -319,6 +379,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Dynamic_5D, StridedSliceLa ::testing::Combine( ::testing::ValuesIn(inputShapesDynamic5D), ::testing::ValuesIn(testCasesCommon5D), + ::testing::ValuesIn(inputLayerTypes), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsCommon5D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -380,6 +441,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Static_5D_Subset1, Strided ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesBlockedStatic5DSubset1)), ::testing::ValuesIn(testCasesBlocked5DSubset1), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked5D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -388,6 +450,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Dynamic_5D_Subset1, Stride ::testing::Combine( ::testing::ValuesIn(inputShapesBlockedDynamic5DSubset1), ::testing::ValuesIn(testCasesBlocked5DSubset1), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked5D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -396,6 +459,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Static_5D_Subset2, Strided ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesBlockedStatic4DSubset2)), ::testing::ValuesIn(testCasesBlocked4DSubset2), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked4D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -404,6 +468,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Dynamic_5D_Subset2, Stride ::testing::Combine( ::testing::ValuesIn(inputShapesBlockedDynamic5DSubset2), ::testing::ValuesIn(testCasesBlocked5DSubset2), + ::testing::ValuesIn(inputLayerTypesBlocked), ::testing::ValuesIn(inputPrecisions), ::testing::ValuesIn(CPUParamsBlocked5D)), StridedSliceLayerCPUTest::getTestCaseName); @@ -434,6 +499,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_StridedSliceLayerDescriptorCPUTest, StridedSliceL ::testing::Combine( ::testing::ValuesIn(inputShapesDescriptors), ::testing::ValuesIn(testCasesDescriptors), + ::testing::Values(ngraph::helpers::InputLayerType::CONSTANT), ::testing::Values(ElementType::f32), ::testing::Values(cpuParams_nChw8c)), StridedSliceLayerDescriptorCPUTest::getTestCaseName); diff --git a/src/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp b/src/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp index 8e651f57246..c4f1661ab6c 100644 --- a/src/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp +++ b/src/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp @@ -289,6 +289,17 @@ std::shared_ptr makeStridedSlice(const ngraph::Output &in, const std::vector &shrink_mask = std::vector{}, const std::vector &ellipsis_mask = std::vector{}); +std::shared_ptr makeStridedSlice(const ov::Output &in, + const ov::Output &beginNode, + const ov::Output &endNode, + const ov::Output &strideNode, + const element::Type &type, + const std::vector &begin_mask, + const std::vector &end_mask, + const std::vector &new_axis_mask = std::vector{}, + const std::vector &shrink_mask = std::vector{}, + const std::vector &ellipsis_mask = std::vector{}); + std::shared_ptr makeSlice(const ngraph::Output &in, const std::vector &begin, const std::vector &end, diff --git a/src/tests/ngraph_helpers/ngraph_functions/src/strided_slice.cpp b/src/tests/ngraph_helpers/ngraph_functions/src/strided_slice.cpp index c3e249d2466..a6c3ea3b5de 100644 --- a/src/tests/ngraph_helpers/ngraph_functions/src/strided_slice.cpp +++ b/src/tests/ngraph_helpers/ngraph_functions/src/strided_slice.cpp @@ -7,15 +7,15 @@ namespace ngraph { namespace builder { std::shared_ptr makeStridedSlice(const ov::Output &in, - const std::vector &begin, - const std::vector &end, - const std::vector &stride, - const element::Type &type, - const std::vector &begin_mask, - const std::vector &end_mask, - const std::vector &new_axis_mask, - const std::vector &shrink_mask, - const std::vector &ellipsis_mask) { + const std::vector &begin, + const std::vector &end, + const std::vector &stride, + const element::Type &type, + const std::vector &begin_mask, + const std::vector &end_mask, + const std::vector &new_axis_mask, + const std::vector &shrink_mask, + const std::vector &ellipsis_mask) { ov::Shape constShape = {begin.size()}; auto beginNode = std::make_shared(ov::element::i64, constShape, begin.data()); auto endNode = std::make_shared(ov::element::i64, constShape, end.data()); @@ -25,6 +25,21 @@ std::shared_ptr makeStridedSlice(const ov::Output &in, return ssNode; } +std::shared_ptr makeStridedSlice(const ov::Output &in, + const ov::Output &beginNode, + const ov::Output &endNode, + const ov::Output &strideNode, + const element::Type &type, + const std::vector &begin_mask, + const std::vector &end_mask, + const std::vector &new_axis_mask, + const std::vector &shrink_mask, + const std::vector &ellipsis_mask) { + auto ssNode = std::make_shared(in, beginNode, endNode, strideNode, begin_mask, end_mask, + new_axis_mask, shrink_mask, ellipsis_mask); + return ssNode; +} + std::shared_ptr makeSlice(const ov::Output &in, const std::vector &begin, const std::vector &end,