From 6aeb054e4807ab49e71f8235a15fd289a31d4153 Mon Sep 17 00:00:00 2001 From: Luwei Zhou Date: Fri, 14 Apr 2023 01:02:48 +0800 Subject: [PATCH] [CPU] Use ONEDNN3.x weight/dest scale API to optimize perf (#16805) * [LPT][CPU] Added callback for AddTransformation * [WIP] Convolution scales fusion * Force to use weight sclae to test performance. * Update on interface. * Use weight scale to adapt to ONEDNN 3.x API changes. * Update the code. * Update ONEDNN fix for gemm_x8s8s32x_conv kernel * Fix the bug in ONEDNN and deconvFusingScale. * Fuse FC Bias when having DQscale. * WR to perf regression on * Update onednn version. * Fix bug and clean code. * FC fusing dq scale bug fix. * Add more comments and debug information. * Fix CI issues. * Merge ONEDNN changes. * Fix CI issues and bugs. * Apply review comments. * Update comments. * Apply reveiw comments. * Avoid using LPT BiasAttribute RTInfo. * Applied review comments. --------- Co-authored-by: Vladislav Golubev --- .../intel_cpu/src/dnnl_postops_composer.cpp | 156 ++++++++++-------- .../intel_cpu/src/dnnl_postops_composer.h | 18 +- src/plugins/intel_cpu/src/graph.cpp | 1 - src/plugins/intel_cpu/src/graph_optimizer.cpp | 120 ++++++++++++++ src/plugins/intel_cpu/src/graph_optimizer.h | 1 + src/plugins/intel_cpu/src/node.cpp | 15 ++ src/plugins/intel_cpu/src/node.h | 7 +- src/plugins/intel_cpu/src/nodes/conv.cpp | 5 +- src/plugins/intel_cpu/src/nodes/conv.h | 1 + src/plugins/intel_cpu/src/nodes/deconv.cpp | 15 +- .../intel_cpu/src/nodes/fullyconnected.cpp | 5 +- .../intel_cpu/src/nodes/fullyconnected.h | 5 + src/plugins/intel_cpu/src/nodes/matmul.cpp | 10 +- .../cpu_opset/common/pass/fc_bias_fusion.cpp | 80 ++++++++- .../cpu_opset/common/pass/fc_bias_fusion.hpp | 19 ++- .../transformation_pipeline.cpp | 15 +- .../subgraph_tests/src/fq_layer_dq_bias.cpp | 7 +- src/plugins/intel_cpu/thirdparty/onednn | 2 +- 18 files changed, 387 insertions(+), 95 deletions(-) diff --git a/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp b/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp index 6321ea1cac0..5a502f99855 100644 --- a/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp +++ b/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp @@ -12,38 +12,70 @@ namespace ov { namespace intel_cpu { DnnlPostOpsComposer::DnnlPostOpsComposer(const dnnl::engine& engine, - dnnl::primitive_attr& attr, - dnnl::post_ops& ops, - std::unordered_map& args, - const VectorDims& outputDims, - int indexOfOutputChannelDim, - bool isINT8) + dnnl::primitive_attr& attr, + dnnl::post_ops& ops, + std::unordered_map& args, + const VectorDims& outputDims, + int indexOfOutputChannelDim, + bool isInt8, + const int weiScaleMaskPerChannel, + const std::vector& DQScales, + bool hasBias) : engine(engine), attr(attr), ops(ops), args(args), outputDims(outputDims), idxOC(indexOfOutputChannelDim), - isINT8(isINT8) { + isINT8(isInt8), + weightScaleMaskPerChannel(weiScaleMaskPerChannel) { IE_ASSERT(idxOC >= 0 && idxOC < outputDims.size()); OC = outputDims[idxOC]; dimsPerOC = dimsPerTensor = VectorDims(outputDims.size(), 1); dimsPerOC[idxOC] = OC; - oscale_mask = 0; - oscale_values = {1.0f}; + + if (isINT8) { + wei_scale_values = DQScales.empty() ? std::vector{1.0} : DQScales; + wei_scale_mask = wei_scale_values.size() > 1 ? weiScaleMaskPerChannel : 0; + dst_scale_val = 1.0; + + //set the DQscale into attr weight scale before appending any post-ops. + updateWeiScales(); + //If having the bias, attr weight scale can't be updated for further ops-ops optimization. 
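Why a bias blocks further weight-scale folding under the oneDNN 3.x scheme out = (src·wei)·wei_scale + bias: a standalone toy program (plain C++, arbitrary numbers, not part of the patch) showing that a later per-tensor scale s can only be absorbed into the weight scale when no bias sits in between, which is what weightScaleAvailable = !hasBias encodes.

#include <cassert>
#include <cmath>

int main() {
    // Toy values: one activation, one weight, a dequantization scale and a bias.
    const float x = 7.0f, w = -3.0f, dq = 0.05f, bias = 1.25f;
    const float s = 0.5f;  // a later per-tensor scale we would like to fold away

    // Reference: the extra scale applies to the biased result.
    const float reference_with_bias = ((x * w) * dq + bias) * s;

    // Folding s into the weight scale silently drops it from the bias term.
    const float folded_with_bias = (x * w) * (dq * s) + bias;
    assert(std::fabs(reference_with_bias - folded_with_bias) > 1e-6f);  // not equivalent

    // Without a bias the fold is exact, hence weightScaleAvailable = !hasBias.
    const float reference_no_bias = ((x * w) * dq) * s;
    const float folded_no_bias = (x * w) * (dq * s);
    assert(std::fabs(reference_no_bias - folded_no_bias) < 1e-6f);
    return 0;
}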
+ //ONEDNN 3.x quantization for scheme: QuantizedInput * QuantizedWeight * DQScale + Bias. + weightScaleAvailable = !hasBias; + } else if (!DQScales.empty()) { + // DQ scale is fused but swiching back to non-INT8 for execution in some cases. + DEBUG_LOG("Set DQ scales for None-INT8, scale size ", DQScales.size()); + appendScale(DQScales, false, true); + } } -void DnnlPostOpsComposer::updateOutputScales() { - if (oscale_mask == 0 && oscale_values[0] == 1.0f) +void DnnlPostOpsComposer::updateWeiScales() { + if (wei_scale_mask == 0 && wei_scale_values[0] == 1.0f) return; - DEBUG_LOG("Set scales mask ", "DNNL_ARG: ", DNNL_ARG_DST, " mask: ", oscale_mask); - attr.set_scales_mask(DNNL_ARG_DST, oscale_mask); + DEBUG_LOG("Set weight scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", wei_scale_mask); + attr.set_scales_mask(DNNL_ARG_WEIGHTS, wei_scale_mask); - DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({oscale_values.size()})); + DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({wei_scale_values.size()})); auto mem = std::make_shared(engine); mem->Create(memoryDesc); - memcpy(mem->GetPtr(), oscale_values.data(), oscale_values.size() * sizeof(float)); + memcpy(mem->GetPtr(), wei_scale_values.data(), wei_scale_values.size() * sizeof(float)); + args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = mem; +} + +void DnnlPostOpsComposer::updateDestScales() { + if (dst_scale_val == 1.0f) + return; + + DEBUG_LOG("Set dest scale mask ", "DNNL_ARG: ", DNNL_ARG_DST, " mask: ", 0); + attr.set_scales_mask(DNNL_ARG_DST, 0); + + DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({1})); + auto mem = std::make_shared(engine); + mem->Create(memoryDesc); + memcpy(mem->GetPtr(), &dst_scale_val, sizeof(float)); args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST] = mem; } @@ -77,27 +109,33 @@ void DnnlPostOpsComposer::appendRoundHTE() { bool DnnlPostOpsComposer::appendScale(const std::vector& scale, bool isLastPostOp, bool allowBinary) { IE_ASSERT(scale.size() == OC || scale.size() == 1); - // there are so many possible optimizations can be done, for example: - // - // we can switch the existing postOps's order to take - // advantage of output scale if it's available: - // relu(x)*scale = relu(x*scale) - // or we can fuse it into previous one as long as they are - // compatible in shape - // x*A*s = x*(A*s) - // or even with add: - // (x*A + B)*s = x*(A*s) + (B*s) - // or we can combine these two tricks: - // relu(x*A)*s = relu(x*(A*s)) - // - // we cannot implement all of them, so we just add the one - // that we observed in real models. - // fuse into existing output scale (only when isINT8) - bool can_fuse_into_oscale = false; - if (isINT8 && isLastPostOp && scale.size() == 1) { // oneDNN v3.* limitation does not allow per-channel dst scales - if (ops.len() == 0) - can_fuse_into_oscale = true; + bool fuseIntoWeiScale = false; + // Use dest scale when last post-ops is per-tensor quantization. + if ((isINT8 && isLastPostOp && scale.size() == 1)) { + dst_scale_val = 1.0 / scale[0]; + updateDestScales(); + return true; + } + if (weightScaleAvailable) { + //oneDNN v3.* weight scale can also be used in the further optimization patterns. 
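A few lines above, appendScale turns a trailing per-tensor quantization scale into dst_scale_val = 1.0 / scale[0]. The toy program below (plain C++, no oneDNN calls) emulates the oneDNN 3.x convention as I understand it, namely that the f32 result is divided by the destination scale when dst is produced, which is why the reciprocal is the right value to program.

#include <cassert>
#include <cmath>

// Emulates the scheme this patch relies on: dst = (acc * wei_scale + bias) / dst_scale.
float emulate_primitive(float acc, float wei_scale, float bias, float dst_scale) {
    return (acc * wei_scale + bias) / dst_scale;
}

int main() {
    const float acc = 42.0f;       // int8 x int8 accumulation
    const float wei_scale = 0.02f; // dequantization scale fused as a weight scale
    const float bias = 0.5f;
    const float s = 0.25f;         // requantization factor of the last per-tensor post-op

    // What the graph asks for: multiply the biased f32 result by s.
    const float expected = (acc * wei_scale + bias) * s;

    // What the composer programs: dst_scale = 1 / s.
    const float produced = emulate_primitive(acc, wei_scale, bias, 1.0f / s);

    assert(std::fabs(expected - produced) < 1e-6f);
    return 0;
}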
+ // there are so many possible optimizations can be done, for example: + // + // we can switch the existing postOps's order to take + // advantage of output scale if it's available: + // relu(x)*scale = relu(x*scale) + // or we can fuse it into previous one as long as they are + // compatible in shape + // x*A*s = x*(A*s) + // or even with add: + // (x*A + B)*s = x*(A*s) + (B*s) + // or we can combine these two tricks: + // relu(x*A)*s = relu(x*(A*s)) + // + // we cannot implement all of them, so we just add the one + // that we observed in real models. + if ((ops.len() == 0)) + fuseIntoWeiScale = true; // relu(x)*s = relu(x*s) // prelu(x)*s = prelu(x*s) @@ -105,7 +143,7 @@ bool DnnlPostOpsComposer::appendScale(const std::vector& scale, bool isLa auto& cur_op = ops.get()->entry_[0]; if ((cur_op.kind == dnnl::impl::primitive_kind::eltwise && cur_op.eltwise.alg == dnnl_eltwise_relu) || (cur_op.kind == dnnl::impl::primitive_kind::binary && cur_op.binary.alg == dnnl_binary_prelu)) { - can_fuse_into_oscale = true; + fuseIntoWeiScale = true; } } @@ -114,54 +152,32 @@ bool DnnlPostOpsComposer::appendScale(const std::vector& scale, bool isLa auto& cur_op = ops.get()->entry_.back(); if (cur_op.kind == dnnl::impl::primitive_kind::sum) { cur_op.sum.scale *= scale[0]; - can_fuse_into_oscale = true; + fuseIntoWeiScale = true; } } } - - if (can_fuse_into_oscale) { + if (fuseIntoWeiScale) { if (scale.size() > 1) { - if (oscale_mask == 0) - oscale_values.resize(scale.size(), oscale_values[0]); + if (wei_scale_mask == 0) + wei_scale_values.resize(scale.size(), wei_scale_values[0]); else - IE_ASSERT(oscale_values.size() == OC); + IE_ASSERT(wei_scale_values.size() == OC); for (int j = 0; j < OC; j++) - oscale_values[j] *= 1 / scale[j]; + wei_scale_values[j] *= scale[j]; } else { - for (int j = 0; j < oscale_values.size(); j++) - oscale_values[j] *= 1 / scale[0]; + for (int j = 0; j < wei_scale_values.size(); j++) + wei_scale_values[j] *= scale[0]; } - if (oscale_values.size() == 1) - oscale_mask = 0; + if (wei_scale_values.size() == 1) + wei_scale_mask = 0; else - oscale_mask = 1 << idxOC; - updateOutputScales(); + wei_scale_mask = weightScaleMaskPerChannel; + updateWeiScales(); return true; } - // (eltwise(x, scale, alpha, beta) + dst[:])*s = (eltwise(x, scale*s, alpha, beta) + s*dst[:]) - if (scale.size() == 1 && ops.len() > 1) { - auto N = ops.len(); - auto& cur_op = ops.get()->entry_[N-1]; - auto& prev_op = ops.get()->entry_[N-2]; - if (cur_op.kind == dnnl::impl::primitive_kind::sum && prev_op.is_eltwise()) { - cur_op.sum.scale *= scale[0]; - prev_op.eltwise.scale *= scale[0]; - return true; - } - } - - // eltwise(x, scale, alpha, beta)*s = eltwise(x, (scale*s), alpha, beta) - if (scale.size() == 1 && ops.len() > 0) { - auto& cur_op = ops.get()->entry_.back(); - if (cur_op.kind == dnnl::impl::primitive_kind::eltwise) { - cur_op.eltwise.scale *= scale[0]; - return true; - } - } - // final fallback if (scale.size() == 1) { appendEltwise(dnnl::algorithm::eltwise_linear, scale[0], 0); diff --git a/src/plugins/intel_cpu/src/dnnl_postops_composer.h b/src/plugins/intel_cpu/src/dnnl_postops_composer.h index f9008751fcc..4ddd2b7ca79 100644 --- a/src/plugins/intel_cpu/src/dnnl_postops_composer.h +++ b/src/plugins/intel_cpu/src/dnnl_postops_composer.h @@ -29,7 +29,10 @@ public: std::unordered_map& args, const VectorDims& outputDims, int indexOfOutputChannelDim, - bool isINT8); + bool isINT8, + int weiScaleMaskPerChannel, + const std::vector& DQScales, + bool hasBias); void appendBinary(const dnnl::algorithm alg, 
const std::vector& data); void appendEltwise(const dnnl::algorithm alg, float alpha, float beta); @@ -50,14 +53,19 @@ private: std::unordered_map& args; const VectorDims outputDims; int idxOC; + const bool isINT8; // only INT8 primitive support scales + const int weightScaleMaskPerChannel; + bool weightScaleAvailable = false; + VectorDims dimsPerTensor; VectorDims dimsPerOC; Dim OC; - const bool isINT8; // only INT8 primitive support output scale - int oscale_mask; - std::vector oscale_values; + int wei_scale_mask = -1; + std::vector wei_scale_values; + float dst_scale_val; - void updateOutputScales(); + void updateWeiScales(); + void updateDestScales(); }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp index 3c1b32b8ca0..4383667c0a5 100644 --- a/src/plugins/intel_cpu/src/graph.cpp +++ b/src/plugins/intel_cpu/src/graph.cpp @@ -436,7 +436,6 @@ void Graph::InitDescriptors() { if (inputNode) inputNode->withMeanImage(); } - OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.getSupportedDescriptors); DEBUG_LOG("Get supported primitive descriptors for node: ", node->getName()); node->getSupportedDescriptors(); diff --git a/src/plugins/intel_cpu/src/graph_optimizer.cpp b/src/plugins/intel_cpu/src/graph_optimizer.cpp index 18d91debc97..6aaa65ace8e 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.cpp +++ b/src/plugins/intel_cpu/src/graph_optimizer.cpp @@ -12,6 +12,7 @@ #include "nodes/reorder.h" #include "nodes/conv.h" #include "nodes/deconv.h" +#include "nodes/fullyconnected.h" #include "nodes/bin_conv.h" #include "nodes/fake_quantize.h" #include "nodes/mvn.h" @@ -27,6 +28,7 @@ #include #include "utils/general_utils.h" #include "utils/cpu_utils.hpp" +#include "utils/debug_capabilities.h" #include #include @@ -61,6 +63,9 @@ namespace intel_cpu { GraphOptimizer::GraphOptimizer() {} void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) { + FuseConvMatmulFCDeconvAndDQScales(graph); + graph.RemoveDroppedNodes(); + OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "ApplyCommonGraphOptimizations", "FuseConvolutionAndBias"); FuseConvolutionMatMulDeconvAndBias(graph); graph.RemoveDroppedNodes(); @@ -177,6 +182,121 @@ void GraphOptimizer::ApplyImplSpecificGraphOptimizations(Graph &graph) { graph.RemoveDroppedEdges(); } +void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) { + auto& graphNodes = graph.GetNodes(); + + auto isDQScaleGraphPattern = [](NodePtr node) { + if (node->getType() != Type::Eltwise || node->getAlgorithm() != Algorithm::EltwiseMultiply) { + return false; + } + auto parentNode = node->getParentEdgesAtPort(0)[0]->getParent(); + auto scaleNode = node->getParentEdgesAtPort(1)[0]->getParent(); + if (!(parentNode->getType() == Type::Convolution + || parentNode->getType() == Type::MatMul + || parentNode->getType() == Type::Deconvolution + || parentNode->getType() == Type::FullyConnected)) + return false; + if (!scaleNode->isConstant()) + return false; + //Only Fusing scales for INT8 precision. + if (parentNode->getOriginalInputPrecisionAtPort(0) != Precision::U8 && parentNode->getOriginalInputPrecisionAtPort(0) != Precision::I8) + return false; + if (parentNode->getOriginalInputPrecisionAtPort(1) != Precision::I8) + return false; + + //Deconv has some heuristic limitation to use INT8 besides input precision. 
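The whole FuseConvMatmulFCDeconvAndDQScales pass rests on the fact that a per-output-channel multiply commutes with the accumulation performed by Convolution/MatMul/FullyConnected/Deconvolution, so the Multiply can be folded into the weights' scales without changing the result. A toy dot-product check (standalone C++, arbitrary values):

#include <cassert>
#include <cmath>
#include <vector>

int main() {
    // One output channel o of a conv/FC: acc_o = sum_i x[i] * w[o][i].
    const std::vector<float> x = {1.f, -2.f, 3.f};
    const std::vector<float> w_o = {0.5f, 4.f, -1.f};
    const float dq_scale_o = 0.125f;  // per-channel dequantization scale for channel o

    float acc = 0.f, acc_scaled_weights = 0.f;
    for (size_t i = 0; i < x.size(); ++i) {
        acc += x[i] * w_o[i];
        acc_scaled_weights += x[i] * (w_o[i] * dq_scale_o);
    }

    // Multiply(node_output, scale) equals the node executed with the scale folded into
    // its weights, which is exactly the rewrite this graph optimization performs.
    assert(std::fabs(acc * dq_scale_o - acc_scaled_weights) < 1e-6f);
    return 0;
}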
+ auto deconv = std::dynamic_pointer_cast(parentNode); + if (deconv && !deconv->canBeExecutedInInt8()) + return false; + // FC bias has been fused into FC in transformation phase. + // todo: Move the FC fusing bias into graph optimizer. + const auto parentNodeInputEdges = parentNode->getParentEdges().size(); + + if (parentNodeInputEdges != 2) { + auto fcNode = std::dynamic_pointer_cast(parentNode); + if (!(parentNodeInputEdges == 3 && fcNode && fcNode->withBiasFused())) + return false; + } + return true; + }; + + auto scaleDimsCheck = [](NodePtr node, NodePtr scales) { + const auto nodeOutDims = node->getOutputShapeAtPort(0).getDims(); + const auto channelAxis = node->getFusingAxis(); + auto OC = nodeOutDims[channelAxis]; + + if (Shape::UNDEFINED_DIM == OC) + return false; + if (!node->getFusedWith().empty() || !scales->getFusedWith().empty()) + return false; + + const auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(), + nodeOutDims.size()); + if (nodeOutDims.size() != scalesDims.size() || scalesDims.size() < 2) + return false; + + if (!dimsEqualStrong(scalesDims[channelAxis], nodeOutDims[channelAxis]) && scalesDims[channelAxis] != 1) + return false; + + for (size_t i = 0; i < scalesDims.size(); i++) { + if (scalesDims[i] != 1 && static_cast(i) != channelAxis) + return false; + } + return true; + }; + + auto initializeDeQuantizedScales = [](NodePtr node, NodePtr scales) { + auto scalesConstant = dynamic_cast(scales.get()); + if (scalesConstant == nullptr) + IE_THROW() << "Cannot cast to Input node"; + + auto scalesBlob = scalesConstant->getMemoryPtr(); + if (scalesBlob == nullptr) + IE_THROW() << "Cannot cast to TBlob internal scales blob"; + + auto scalesData = static_cast(scalesBlob->GetPtr()); + if (scalesData == nullptr) + IE_THROW() << "scalesBlob has not allocated buffer"; + auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(), + node->getOutputShapeAtPort(0).getDims().size()); + auto scaleSize = std::accumulate(scalesDims.begin(), scalesDims.end(), 1, std::multiplies()); + node->initializeDQScales(scalesData, scaleSize); + return true; + }; + + for (size_t i = 0; i < graphNodes.size(); i++) { + auto mul = graphNodes[i]; + if (!isDQScaleGraphPattern(mul)) continue; + + CPU_GRAPH_OPTIMIZER_SCOPE(FuseConvMatmulFCDeconvAndDQScales); + + auto node = mul->getParentEdgesAtPort(0)[0]->getParent(); + auto scales = mul->getParentEdgesAtPort(1)[0]->getParent(); + if (!scaleDimsCheck(node, scales)) { + auto fcNode = std::dynamic_pointer_cast(node); + if (fcNode && fcNode->withBiasFused()) { + // For int8 FC, BIAS has been fused into FC during ngraph transformation. DQ fusing check fails here. + // Sliently exit here would cause accuracy issue, because this multiply would be append after BIAS. + // It is a bug. Assert to give more debugging information. + // todo: Remove this by moving the fullyconnect_bias fusing into graph optimizer from ngraph transformation. 
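The hard error a few lines below exists because, once QuantizedFullyConnectedBiasFusion has moved the bias under the Multiply, the intermediate graph is only numerically correct if this pass subsequently folds the Multiply into the FC as DQScales. A toy plain-C++ check of that chain, under the assumption used throughout this patch that oneDNN 3.x applies the weight scale before adding the f32 bias:

#include <cassert>
#include <cmath>

int main() {
    const float acc = 30.f;   // int8 x int8 accumulation of the FullyConnected
    const float dq = 0.1f;    // dequantization scale left by LPT
    const float bias = 2.f;
    const float original = acc * dq + bias;   // FC -> Multiply(dq) -> Add(bias)

    // After the ngraph-level bias fusion the IR reads Multiply(FC_with_bias, dq),
    // which taken literally is NOT equivalent ...
    const float unfused_ir = (acc + bias) * dq;
    assert(std::fabs(original - unfused_ir) > 1e-6f);

    // ... but once dq is folded into the FC as a weight scale and the Multiply is dropped,
    // oneDNN computes acc * wei_scale + bias, matching the original graph again.
    const float fused_execution = acc * dq + bias;
    assert(std::fabs(original - fused_execution) < 1e-6f);
    return 0;
}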
+ DEBUG_LOG("BUG in scaleDimsCheck##", scales->getName(), " into FullyConnect ##", node->getName(), + "Fusing axis: ", node->getFusingAxis()); + DEBUG_LOG(*node); + DEBUG_LOG(*scales); + IE_THROW() << "BUG: IN8 FC bias fused, DQ scale can not fused in " << node->getName() << std::endl; + } + continue; + } + + if (initializeDeQuantizedScales(node, scales)) { + node->addOriginalLayer(mul->getOriginalLayers()); + auto p_edge = mul->getParentEdgesAtPort(1)[0]; + graph.RemoveEdge(p_edge); + graph.DropNode(mul); + } + } +} + void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) { auto& graphNodes = graph.GetNodes(); diff --git a/src/plugins/intel_cpu/src/graph_optimizer.h b/src/plugins/intel_cpu/src/graph_optimizer.h index 9fe6e8df157..aa5da34f0fd 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.h +++ b/src/plugins/intel_cpu/src/graph_optimizer.h @@ -20,6 +20,7 @@ public: void ApplyImplSpecificGraphOptimizations(Graph& graph); private: + void FuseConvMatmulFCDeconvAndDQScales(Graph &graph); void FuseConvolutionMatMulDeconvAndBias(Graph &graph); void FuseDeconvolutionAndSimpleOperation(Graph &graph); void FuseMultiplyAndAdd(Graph &graph); diff --git a/src/plugins/intel_cpu/src/node.cpp b/src/plugins/intel_cpu/src/node.cpp index 2814d3d40d3..78c4463e232 100644 --- a/src/plugins/intel_cpu/src/node.cpp +++ b/src/plugins/intel_cpu/src/node.cpp @@ -1666,5 +1666,20 @@ void Node::addSupportedPrimDesc(const std::vector& inPortConfi supportedPrimitiveDescriptors.push_back({config, implType}); } +void Node::initializeDQScales(const float* scaleData, const size_t scaleSize) { + bool scalePerTensor; + if (!DQScales.empty() || !scaleSize) + IE_THROW() << "DQ scales is preset or scale size is 0, ##" << getName(); + DQScales.reserve(scaleSize); + scalePerTensor = true; + for (size_t i = 0; i < scaleSize; i++) { + DQScales.push_back(scaleData[i]); + if (scaleData[i] != scaleData[0]) + scalePerTensor = false; + } + if (scalePerTensor) + DQScales.resize(1); +} + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/node.h b/src/plugins/intel_cpu/src/node.h index bb10731ad2c..62e8aa81853 100644 --- a/src/plugins/intel_cpu/src/node.h +++ b/src/plugins/intel_cpu/src/node.h @@ -540,6 +540,10 @@ public: */ std::pair, std::vector> getScalesAndShifts(const Node *parentNode) const; + void initializeDQScales(const float* scaleData, const size_t scaleSize); + const std::vector& getDQScales() const { + return DQScales; + } /** * @brief Appends new item into ops list with the information on how the node should be executed as post operation. * Seed node should call this routine and pass its post operations list as parameter. 
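Node::initializeDQScales above collapses the scale vector to a single element when every entry is identical, so downstream code can use the cheaper per-tensor mask. A standalone sketch of the same idea (hypothetical free function, not the plugin class):

#include <algorithm>
#include <cassert>
#include <vector>

// Returns the dequantization scales, reduced to one element when they are uniform.
std::vector<float> normalize_dq_scales(const float* data, size_t size) {
    assert(size != 0);
    std::vector<float> scales(data, data + size);
    const bool per_tensor = std::all_of(scales.begin(), scales.end(),
                                        [&](float v) { return v == scales[0]; });
    if (per_tensor)
        scales.resize(1);
    return scales;
}

int main() {
    const float uniform[4] = {0.5f, 0.5f, 0.5f, 0.5f};
    const float per_channel[4] = {0.5f, 0.25f, 0.5f, 0.125f};
    assert(normalize_dq_scales(uniform, 4).size() == 1);      // per-tensor: mask 0
    assert(normalize_dq_scales(per_channel, 4).size() == 4);  // per-channel: mask on OC dim
    return 0;
}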
@@ -715,7 +719,8 @@ private: enum LOOK { LOOK_UP = 1, LOOK_DOWN = 2 }; ConstantType checkConstant(LOOK look, std::vector& checkNodes); - + // Hold output scales + std::vector DQScales; // we cannot rely on per-NUMA weightCache for caching weights because: // 1.it may not exist(in single stream configuration) // 2.it only holds weak references, the life-cycle of cached item diff --git a/src/plugins/intel_cpu/src/nodes/conv.cpp b/src/plugins/intel_cpu/src/nodes/conv.cpp index e4b1cd5bfe4..ffbdb4ed5e2 100644 --- a/src/plugins/intel_cpu/src/nodes/conv.cpp +++ b/src/plugins/intel_cpu/src/nodes/conv.cpp @@ -616,8 +616,9 @@ void Convolution::setPostOps(dnnl::primitive_attr& attr, dnnl::post_ops ops; auto& args = convPostOpsArgs[useLegacyPostOps]; bool isINT8 = canBeExecutedInInt8(); - - DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8); + // Weight dims in NON-Group CONV: [OC, IC, KH, KW], perchannel weight scale applied on OC DIM, weiScaleMaskPerChannel = 1 << 0 + // Weight dims in Group CONV:[Group, OC, IC, KH, KW], perchannel weight scale applied on GROUP and OC DIM, weiScaleMaskPerChannel = ( 1 << 0 | 1<< 1) = 0x03 + DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8, isGrouped ? 3 : 1 << 0, getDQScales(), withBiases); DEBUG_LOG(getName(), " useLegacyPostOps=", useLegacyPostOps, " initWeights=", initWeights); diff --git a/src/plugins/intel_cpu/src/nodes/conv.h b/src/plugins/intel_cpu/src/nodes/conv.h index 48b4b6b7058..d858197cae6 100644 --- a/src/plugins/intel_cpu/src/nodes/conv.h +++ b/src/plugins/intel_cpu/src/nodes/conv.h @@ -83,6 +83,7 @@ private: PerTensor, PerChannel }; + class FusedSubgraph; using FusedSubgraphPtr = std::shared_ptr; using executorPtr = std::shared_ptr; diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index afcd521acda..a081f2ec949 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -470,8 +470,19 @@ void Deconvolution::initPaddingR(const Shape &inShape, const Shape &outShape) { void Deconvolution::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims) { dnnl::post_ops ops; - - DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, 1, isInt8); + // OC, IC is the convolution forward output channel, input channel. + // According to ONEDNN API doc, mask whould be set on the corresponding index on weight. + // For [OC, IC, KH, KW] perchannel scale weight mask should set on IC dim( 1 << 1) for none group deconv; + // For [Group, OC, IC, KH, KW] IC and group dims ( 1 << 0 | 1<< 2) for group deconv. + // Perchannel weight should set on IC dimention not OC dimention. + // But we have to set on IC dimesion as following to make weight scale work. It should be ONEDNN bug?? + // Current perchannel mask setting. + // Weight dims in NON-Group deconv: [OC, IC, KH, KW], perchannel weight scale applied on OC DIM + // weiScaleMaskPerChannel = 1 << 0 + // Weight dims in Group deconv: [Group, OC, IC, KH, KW], perchannel weight scale applied on GROUP and OC DIM, + // weiScaleMaskPerChannel = ( 1 << 0 | 1 << 1) = 0x03 + // @todo: Clarify with ONEDNN about deconvolution channel mask setting. + DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, 1, isInt8, withGroups ? 
3 : 1 << 0, getDQScales(), withBiases); for (int i = 0; i < fusedWith.size(); ++i) { auto& node = fusedWith[i]; diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp index bc441ac79fa..4e0814cfee9 100644 --- a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp +++ b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp @@ -237,6 +237,7 @@ void FullyConnected::getSupportedDescriptors() { } auto weightsDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(WEIGHTS_ID)); + isINT8 = one_of(inputDataType, memory::data_type::u8, memory::data_type::s8) && weightsDataType == memory::data_type::s8; // revert back outputDataType on special cases if (inputDataType == memory::data_type::f32) { // oneDNN only support f32 output when input is f32, even if FQ is fused @@ -534,10 +535,8 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di IE_THROW() << "Unexpected rank(" << dims_ext.size() << ") for output tensor of node: " << getName(); } - bool isINT8 = getOriginalInputPrecisionAtPort(WEIGHTS_ID) == Precision::U8 || - getOriginalInputPrecisionAtPort(WEIGHTS_ID) == Precision::I8; - DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8); + DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8, 1 << 0, getDQScales(), withBiases); for (int i = 0; i < fusedWith.size(); ++i) { auto& node = fusedWith[i]; diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.h b/src/plugins/intel_cpu/src/nodes/fullyconnected.h index 8add77440fd..6898e2f6bc8 100644 --- a/src/plugins/intel_cpu/src/nodes/fullyconnected.h +++ b/src/plugins/intel_cpu/src/nodes/fullyconnected.h @@ -58,6 +58,10 @@ public: void setDynamicBatchLim(int lim) override; + bool withBiasFused() const { + return withBiases; + } + private: void createDescriptorInternal(const dnnl::memory::desc &inputDesc, const dnnl::memory::desc &outputDesc); @@ -100,6 +104,7 @@ private: float minSparseRate = 1.f; float weiSparseRate = 0.f; bool useSparseWeightsDecompression(); + bool isINT8 = false; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes/matmul.cpp b/src/plugins/intel_cpu/src/nodes/matmul.cpp index 1e2669579ae..e5802dc1223 100644 --- a/src/plugins/intel_cpu/src/nodes/matmul.cpp +++ b/src/plugins/intel_cpu/src/nodes/matmul.cpp @@ -204,6 +204,14 @@ MatMul::MatMul(const std::shared_ptr& op, const GraphContext::CPtr } bool MatMul::canFuse(const NodePtr& node) const { + // Consider the case when Matmul doesn't support execution in int8, but is getting fused with FQ with int8 output. + // Then the Matmul will change its output precision to fp32. If fusing FQ into matmul, there would be reorder inserted + // after matmul. In some bert model, this reorder causes great perf degradation. + // Todo: Remove this if onednn primitive support U8 output with floating input. 
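The weiScaleMaskPerChannel values passed to DnnlPostOpsComposer for Convolution, Deconvolution and FullyConnected above (and for MatMul just below) all follow the same oneDNN convention as I understand it: bit i of the mask marks dimension i of the WEIGHTS tensor as carrying its own scale, and the number of scale values supplied must be the product of the masked dimensions. A small standalone helper spelling out the masks used in this patch (illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

int64_t scale_count(const std::vector<int64_t>& dims, int mask) {
    int64_t count = 1;
    for (size_t i = 0; i < dims.size(); ++i)
        if (mask & (1 << i))
            count *= dims[i];
    return count;
}

int main() {
    // FullyConnected weights [OC, IC]: per-OC scales -> mask 1 << 0.
    assert(scale_count({64, 128}, 1 << 0) == 64);
    // Non-grouped Convolution weights [OC, IC, KH, KW]: per-OC scales -> mask 1 << 0.
    assert(scale_count({64, 16, 3, 3}, 1 << 0) == 64);
    // Grouped Convolution weights [G, OC, IC, KH, KW]: per-(G, OC) -> mask (1 << 0) | (1 << 1) == 3.
    assert(scale_count({4, 16, 16, 3, 3}, 3) == 64);
    // MatMul weights [B, K, N]: per-N scales -> mask 1 << (ndims - 1).
    assert(scale_count({1, 128, 64}, 1 << 2) == 64);
    // Per-tensor: mask 0 -> a single scale.
    assert(scale_count({64, 128}, 0) == 1);
    return 0;
}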
+ if (node->getType() == Type::FakeQuantize && one_of(node->getOriginalOutputPrecisionAtPort(0), Precision::I8, Precision::U8) && + !canBeExecutedInInt8(getOriginalInputPrecisionAtPort(0), getOriginalInputPrecisionAtPort(1)) && + getOriginalInputPrecisionAtPort(0) == InferenceEngine::Precision::FP32 ) + return false; return canFuseSimpleOperation(node); } @@ -217,7 +225,7 @@ void MatMul::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims, bool bool isINT8 = canBeExecutedInInt8(getOriginalInputPrecisionAtPort(0), getOriginalInputPrecisionAtPort(1)); - DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8); + DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8, 1 << (dims.size() - 1), getDQScales(), withBiases); for (int i = 0; i < fusedWith.size(); ++i) { auto& node = fusedWith[i]; diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.cpp index 993922b9ccf..fe84f9e9d30 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.cpp @@ -13,7 +13,7 @@ #include "itt.hpp" -ov::intel_cpu::FullyConnectedBiasFusion::FullyConnectedBiasFusion() { +ov::intel_cpu::NonQuantizedFullyConnectedBiasFusion::NonQuantizedFullyConnectedBiasFusion() { MATCHER_SCOPE(FullyConnectedBiasFusion); auto input = ngraph::pattern::any_input(); auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); @@ -74,3 +74,81 @@ ov::intel_cpu::FullyConnectedBiasFusion::FullyConnectedBiasFusion() { auto m = std::make_shared(m_add, matcher_name); this->register_matcher(m, callback); } + +//CPU plugin would config LPT not to propogate dequantization scale over bias to follow ONEDNN 3.x scheme. +//It is a little tricky now to first fuse bias not DQ for pattern "FC + DQ + BIAS". +//todo: Will move the FullyConnnect fusing into CPU and fuse the DQ and BIAS in topology order. 
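For reference, this is the kind of "FC-like -> DQ Multiply -> bias Add" subgraph the new QuantizedFullyConnectedBiasFusion matcher (defined just below) targets. The sketch uses ngraph::opset1::MatMul purely as a stand-in for the CPU-internal FullyConnected op, and the shapes and values are arbitrary; it is an illustration of the pattern, not code from the transformation.

#include <memory>
#include <vector>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>

std::shared_ptr<ngraph::Function> make_quantized_fc_dq_bias_pattern() {
    auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 16});
    auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{16, 8},
                                                    std::vector<float>(16 * 8, 0.1f));
    auto fc_like = std::make_shared<ngraph::opset1::MatMul>(input, weights);

    // Per-channel dequantization scale left behind by LPT, which no longer propagates it over the bias.
    auto dq_scale = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 8},
                                                     std::vector<float>(8, 0.05f));
    auto multiply = std::make_shared<ngraph::opset1::Multiply>(fc_like, dq_scale);

    auto bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 8},
                                                 std::vector<float>(8, 1.0f));
    auto add = std::make_shared<ngraph::opset1::Add>(multiply, bias);

    // The matcher rewrites this into FullyConnected(input, weights, bias) followed by the same
    // Multiply; the Multiply is then folded into the FC as DQScales by the CPU graph optimizer.
    return std::make_shared<ngraph::Function>(add, ngraph::ParameterVector{input});
}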
+ov::intel_cpu::QuantizedFullyConnectedBiasFusion::QuantizedFullyConnectedBiasFusion() { + MATCHER_SCOPE(FullyConnectedBiasFusion); + auto input = ngraph::pattern::any_input(); + auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto m_fc = ngraph::pattern::wrap_type({ input, weights }, [](ngraph::Output output) { + return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_rank()(output); + }); + + auto m_scale = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto m_mul = ngraph::pattern::wrap_type({m_fc, m_scale}); + + auto m_bias = ngraph::pattern::any_input(ngraph::pattern::has_static_shape()); + auto m_add = ngraph::pattern::wrap_type({m_mul, m_bias}); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto mul = pattern_to_output[m_mul].get_node_shared_ptr(); + auto scale = pattern_to_output[m_scale].get_node_shared_ptr(); + auto add = pattern_to_output[m_add].get_node_shared_ptr(); + auto bias = pattern_to_output[m_bias].get_node_shared_ptr(); + auto fc = std::dynamic_pointer_cast(pattern_to_output[m_fc].get_node_shared_ptr()); + if (!fc || transformation_callback(fc)) { + return false; + } + + if (!std::dynamic_pointer_cast(bias)) { + return false; + } + + ngraph::Shape bias_shape(bias->get_shape()); + ngraph::PartialShape output_shape(fc->get_output_partial_shape(0)); + size_t bias_size = ngraph::shape_size(bias_shape); + auto rank = output_shape.rank().get_length(); + if (rank == 0 || output_shape[rank - 1].is_dynamic()) { + return false; + } + + const bool per_channel = std::count_if(bias_shape.begin(), bias_shape.end(), [](size_t x) { return x > 1; }) == 1; + if (ov::shape_size(bias_shape) != 1 && !per_channel) + return false; + + if (bias_shape.empty() || bias_shape.back() != output_shape[rank - 1].get_length() || bias_shape.back() != bias_size) { + return false; + } + + ngraph::NodeVector new_ops; + + std::shared_ptr final_bias = bias; + if (bias_shape.size() >= 2) { + auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 1 }, { -1 }); + final_bias = ov::op::util::make_try_fold(final_bias, reshape_const, true); + new_ops.push_back(final_bias); + } + + auto new_fc = std::make_shared(fc->input_value(0), + fc->input_value(1), + final_bias, + fc->get_output_rank(), + fc->get_output_type()); + new_ops.push_back(new_fc); + + std::shared_ptr final_scale = scale; + auto new_mul = std::make_shared(new_fc, final_scale, mul->get_autob()); + new_ops.push_back(new_mul); + + new_mul->set_friendly_name(add->get_friendly_name()); + ngraph::copy_runtime_info({fc, mul, add}, new_ops); + ngraph::replace_node(add, new_mul); + return true; + }; + + auto m = std::make_shared(m_add, matcher_name); + this->register_matcher(m, callback); +} \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.hpp index 5abac038c2d..812597802a4 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.hpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/fc_bias_fusion.hpp @@ -9,10 +9,25 @@ namespace ov { namespace intel_cpu { -class FullyConnectedBiasFusion : public ngraph::pass::MatcherPass { +class NonQuantizedFullyConnectedBiasFusion : public ngraph::pass::MatcherPass { +public: + 
OPENVINO_RTTI("NonQuantizedFullyConnectedBiasFusion", "0"); + NonQuantizedFullyConnectedBiasFusion(); +}; + +class QuantizedFullyConnectedBiasFusion : public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("FullyConnectedDQBiasFusion", "0"); + QuantizedFullyConnectedBiasFusion(); +}; + +class FullyConnectedBiasFusion : public ngraph::pass::GraphRewrite { public: OPENVINO_RTTI("FullyConnectedBiasFusion", "0"); - FullyConnectedBiasFusion(); + FullyConnectedBiasFusion() { + add_matcher(); + add_matcher(); + } }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 942c41b3048..6f7bfd7d7a5 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -81,12 +81,14 @@ #include "utils/ngraph_transformation.hpp" // LPT transformations -#include "transformations/low_precision/mark_dequantization_subgraph.hpp" -#include "low_precision/convolution_backprop_data.hpp" +#include "low_precision/add.hpp" #include "low_precision/convert_subtract_constant.hpp" -#include "low_precision/network_helper.hpp" -#include "low_precision/multiply_to_group_convolution.hpp" +#include "low_precision/convolution_backprop_data.hpp" #include "low_precision/group_convolution.hpp" +#include "low_precision/multiply_to_group_convolution.hpp" +#include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/bias_attribute.hpp" +#include "transformations/low_precision/mark_dequantization_subgraph.hpp" // CPU specific transformations #include "transformations/cpu_opset/convert_to_cpu_specific_opset.hpp" @@ -521,6 +523,11 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vectorset_callback( + [](const_node_ptr& node) -> bool { + return ov::marked_as_bias(node); + }); + CPU_DISABLE_PASS_COMMON(lptManager, ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation); lptManager.run_passes(model); diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/fq_layer_dq_bias.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/fq_layer_dq_bias.cpp index 68e5fbdbf81..e33a0342b0a 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/fq_layer_dq_bias.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/fq_layer_dq_bias.cpp @@ -43,7 +43,6 @@ protected: std::tie(input_shape, layer_type) = GetParam(); targetDevice = CommonTestUtils::DEVICE_CPU; - fusedOps = std::vector{"Add"}; std::tie(inFmts, outFmts, priority, selectedType) = CPUSpecificParams{{}, {}, {}, CPUTestsBase::any_type}; std::unordered_map ngraph_type_to_plugin_type{ {"Convolution", "Convolution"}, @@ -53,6 +52,11 @@ protected: {"MatMulWithConstant", "FullyConnected"}, }; node_type = ngraph_type_to_plugin_type[layer_type]; + if (node_type == "FullyConnected") + // @todo: Recover the Multiply fusing check after moving FC bias fusing into CPUgraph optimizer. + fusedOps = std::vector{"Add"}; + else + fusedOps = std::vector{"Multiply", "Add"}; const auto shapes = layer_type == "MatMul" ? 
std::vector{input_shape, input_shape} : std::vector{input_shape}; @@ -100,7 +104,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_FQLayerDQBias_4D_dynamic, FQLayerDQBias, ::testing::Combine(::testing::ValuesIn(input_shapes_4D_dynamic), ::testing::ValuesIn(layer_types_4D_dynamic)), FQLayerDQBias::getTestCaseName); - const std::vector input_shapes_2D = { {{-1, 768}, {{1, 768}}} }; diff --git a/src/plugins/intel_cpu/thirdparty/onednn b/src/plugins/intel_cpu/thirdparty/onednn index f9127156d14..67c84b1d763 160000 --- a/src/plugins/intel_cpu/thirdparty/onednn +++ b/src/plugins/intel_cpu/thirdparty/onednn @@ -1 +1 @@ -Subproject commit f9127156d148393502d1d2254d9a48f564dc9adb +Subproject commit 67c84b1d76390ba1bf977a1c2d4bda53cf479c65
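Closing example: a standalone sketch of the oneDNN 3.x runtime-scale API this patch switches to, with per-OC weight scales carrying the dequantization factors and a per-tensor dst scale handling the final requantization. Shapes and values are arbitrary, error handling is omitted, and the signatures follow my reading of the oneDNN 3.x C++ API; treat it as a sketch, not a drop-in test.

#include <cstring>
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim N = 2, IC = 4, OC = 3;
    memory::desc src_md({N, IC}, memory::data_type::u8, memory::format_tag::nc);
    memory::desc wei_md({OC, IC}, memory::data_type::s8, memory::format_tag::oi);
    memory::desc bia_md({OC}, memory::data_type::f32, memory::format_tag::x);
    memory::desc dst_md({N, OC}, memory::data_type::u8, memory::format_tag::nc);

    primitive_attr attr;
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 0);  // per-OC dequantization scales
    attr.set_scales_mask(DNNL_ARG_DST, 0);           // single requantization scale

    inner_product_forward::primitive_desc pd(eng, prop_kind::forward_inference,
                                             src_md, wei_md, bia_md, dst_md, attr);
    inner_product_forward prim(pd);

    memory src(src_md, eng), wei(wei_md, eng), bia(bia_md, eng), dst(dst_md, eng);
    std::vector<uint8_t> src_data(N * IC, 2);
    std::vector<int8_t> wei_data(OC * IC, 3);
    std::vector<float> bia_data(OC, 1.f);
    std::memcpy(src.get_data_handle(), src_data.data(), src_data.size());
    std::memcpy(wei.get_data_handle(), wei_data.data(), wei_data.size());
    std::memcpy(bia.get_data_handle(), bia_data.data(), bia_data.size() * sizeof(float));

    // Scales are plain f32 memories passed at execution time, exactly like the
    // DNNL_ARG_ATTR_SCALES entries the composer stores in its args map.
    std::vector<float> wei_scales = {0.1f, 0.2f, 0.3f};
    float dst_scale = 2.0f;  // divides the f32 result, i.e. multiplies it by 0.5
    memory wei_scale_mem({{OC}, memory::data_type::f32, memory::format_tag::x}, eng);
    memory dst_scale_mem({{1}, memory::data_type::f32, memory::format_tag::x}, eng);
    std::memcpy(wei_scale_mem.get_data_handle(), wei_scales.data(), wei_scales.size() * sizeof(float));
    std::memcpy(dst_scale_mem.get_data_handle(), &dst_scale, sizeof(float));

    prim.execute(strm, {{DNNL_ARG_SRC, src},
                        {DNNL_ARG_WEIGHTS, wei},
                        {DNNL_ARG_BIAS, bia},
                        {DNNL_ARG_DST, dst},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scale_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_mem}});
    strm.wait();
    return 0;
}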