From dd8a073aa487ec56be9fdae97877c0f0be28985f Mon Sep 17 00:00:00 2001 From: Maxim Andronov Date: Tue, 18 Jan 2022 09:14:29 +0300 Subject: [PATCH] [CPU] FullyConnected leftovers (#9621) --- .../src/nodes/mkldnn_fullyconnected_node.cpp | 65 +++++++++---------- .../src/nodes/mkldnn_fullyconnected_node.h | 3 + 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.cpp b/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.cpp index c2fb6acf419..5a1d52a2d49 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.cpp +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.cpp @@ -129,6 +129,33 @@ std::vector MKLDNNFullyConnectedNode::getAvailableFormatsFor return {memory::format_tag::any}; } +VectorDims MKLDNNFullyConnectedNode::makeDummyInputDims() const { + const auto& inShape = getInputShapeAtPort(DATA_ID); + const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims(); + + auto inMinDims = inShape.getMinDims(); + auto inMaxDims = inShape.getMaxDims(); + + if (inMinDims.size() == 3) { + inMinDims.back() = weightDims.back(); + inMaxDims.back() = weightDims.back(); + } else { + for (size_t i = 1; i < inMinDims.size(); i++) { + inMinDims[i] = weightDims[i]; + inMaxDims[i] = weightDims[i]; + } + } + return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims(); +} + +VectorDims MKLDNNFullyConnectedNode::makeDummyOutputDims(const VectorDims& inDims) const { + std::vector inShapes = {Shape(inDims), getInputShapeAtPort(WEIGHTS_ID)}; + if (inputShapes.size() > 2) { + inShapes.emplace_back(getInputShapeAtPort(BIAS_ID)); + } + return shapeInferGeneric(inShapes).front(); +} + void MKLDNNFullyConnectedNode::getSupportedDescriptors() { if (getParentEdges().size() != 2 && getParentEdges().size() != 3) IE_THROW() << errorPrefix << " has incorrect number of input edges"; @@ -163,32 +190,8 @@ void MKLDNNFullyConnectedNode::getSupportedDescriptors() { outputDataType = memory::data_type::bf16; } - const auto& inShape = getInputShapeAtPort(DATA_ID); - inDims = inShape.getDims(); - const auto& weightDims = getInputShapeAtPort(WEIGHTS_ID).getStaticDims(); - outDims = getOutputShapeAtPort(0).getDims(); - - if (isDynamicNode()) { - auto inMinDims = inShape.getMinDims(); - auto inMaxDims = inShape.getMaxDims(); - - if (inMinDims.size() == 3) { - inMinDims.back() = weightDims.back(); - inMaxDims.back() = weightDims.back(); - } else { - for (size_t i = 1; i < inMinDims.size(); i++) { - inMinDims[i] = weightDims[i]; - inMaxDims[i] = weightDims[i]; - } - } - inDims = MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims(); - - std::vector inShapes = {Shape(inDims), Shape(weightDims)}; - if (inputShapes.size() > 2) { - inShapes.emplace_back(getInputShapeAtPort(BIAS_ID)); - } - outDims = shapeInferGeneric(inShapes).front(); - } + inDims = isDynamicNode() ? makeDummyInputDims() : getInputShapeAtPort(DATA_ID).getStaticDims(); + outDims = isDynamicNode() ? makeDummyOutputDims(inDims) : getOutputShapeAtPort(0).getStaticDims(); for (auto format : getAvailableFormatsForDims(getInputShapeAtPort(0))) { auto in_candidate = mkldnn::memory::desc(MKLDNNExtensionUtils::convertToDnnlDims(inDims), inputDataType, format); @@ -219,14 +222,8 @@ void MKLDNNFullyConnectedNode::prepareParams() { if (selected_pd == nullptr) IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << "."; - auto initPrimitiveAttr = [&]() { - mkldnn::primitive_attr attr; - setPostOps(attr, dstMemPtr->getStaticDims()); - - return std::make_shared(std::move(attr)); - }; - - AttrPtr attr = initPrimitiveAttr(); + AttrPtr attr = std::make_shared(); + setPostOps(*attr, dstMemPtr->getStaticDims()); DnnlMemoryDescCPtr weightDesc = wghMemPtr->GetDescWithType(); DnnlMemoryDescCPtr biasDesc = nullptr; diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.h b/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.h index a42442f7543..26e45a3b79f 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.h +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_fullyconnected_node.h @@ -55,6 +55,9 @@ private: void createDescriptorInternal(const mkldnn::memory::desc &inputDesc, const mkldnn::memory::desc &outputDesc); + VectorDims makeDummyInputDims() const; + VectorDims makeDummyOutputDims(const VectorDims& inDims) const; + VectorDims inDims; VectorDims outDims;