[CPU] Use ONEDNN3.x weight/dest scale API to optimize perf (#16805)

* [LPT][CPU] Added callback for AddTransformation

* [WIP] Convolution scales fusion

* Force using weight scale to test performance.

* Update on interface.

* Use weight scale to adapt to ONEDNN 3.x API changes.

* Update the code.

* Update ONEDNN fix for gemm_x8s8s32x_conv kernel

* Fix the bug in ONEDNN and deconvFusingScale.

* Fuse FC Bias when having DQscale.

* Workaround for a perf regression.

* Update onednn version.

* Fix bug and clean code.

* FC fusing dq scale bug fix.

* Add more comments and debug information.

* Fix CI issues.

* Merge ONEDNN changes.

* Fix CI issues and bugs.

* Apply review comments.

* Update comments.

* Apply review comments.

* Avoid using LPT BiasAttribute RTInfo.

* Applied review comments.

---------

Co-authored-by: Vladislav Golubev <vladislav.golubev@intel.com>
Luwei Zhou 2023-04-14 01:02:48 +08:00 committed by GitHub
parent 25015f9790
commit 6aeb054e48
18 changed files with 387 additions and 95 deletions

View File

@ -12,38 +12,70 @@ namespace ov {
namespace intel_cpu {
DnnlPostOpsComposer::DnnlPostOpsComposer(const dnnl::engine& engine,
dnnl::primitive_attr& attr,
dnnl::post_ops& ops,
std::unordered_map<int, MemoryPtr>& args,
const VectorDims& outputDims,
int indexOfOutputChannelDim,
bool isINT8)
dnnl::primitive_attr& attr,
dnnl::post_ops& ops,
std::unordered_map<int, MemoryPtr>& args,
const VectorDims& outputDims,
int indexOfOutputChannelDim,
bool isInt8,
const int weiScaleMaskPerChannel,
const std::vector<float>& DQScales,
bool hasBias)
: engine(engine),
attr(attr),
ops(ops),
args(args),
outputDims(outputDims),
idxOC(indexOfOutputChannelDim),
isINT8(isINT8) {
isINT8(isInt8),
weightScaleMaskPerChannel(weiScaleMaskPerChannel) {
IE_ASSERT(idxOC >= 0 && idxOC < outputDims.size());
OC = outputDims[idxOC];
dimsPerOC = dimsPerTensor = VectorDims(outputDims.size(), 1);
dimsPerOC[idxOC] = OC;
oscale_mask = 0;
oscale_values = {1.0f};
if (isINT8) {
wei_scale_values = DQScales.empty() ? std::vector<float>{1.0} : DQScales;
wei_scale_mask = wei_scale_values.size() > 1 ? weiScaleMaskPerChannel : 0;
dst_scale_val = 1.0;
// Set the DQ scale into the attr weight scale before appending any post-ops.
updateWeiScales();
// If bias is present, the attr weight scale can't be updated for further post-ops optimization.
// ONEDNN 3.x quantization scheme: QuantizedInput * QuantizedWeight * DQScale + Bias.
weightScaleAvailable = !hasBias;
} else if (!DQScales.empty()) {
// DQ scale is fused, but in some cases execution switches back to non-INT8.
DEBUG_LOG("Set DQ scales for non-INT8, scale size ", DQScales.size());
appendScale(DQScales, false, true);
}
}
void DnnlPostOpsComposer::updateOutputScales() {
if (oscale_mask == 0 && oscale_values[0] == 1.0f)
void DnnlPostOpsComposer::updateWeiScales() {
if (wei_scale_mask == 0 && wei_scale_values[0] == 1.0f)
return;
DEBUG_LOG("Set scales mask ", "DNNL_ARG: ", DNNL_ARG_DST, " mask: ", oscale_mask);
attr.set_scales_mask(DNNL_ARG_DST, oscale_mask);
DEBUG_LOG("Set weight scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", wei_scale_mask);
attr.set_scales_mask(DNNL_ARG_WEIGHTS, wei_scale_mask);
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({oscale_values.size()}));
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({wei_scale_values.size()}));
auto mem = std::make_shared<Memory>(engine);
mem->Create(memoryDesc);
memcpy(mem->GetPtr(), oscale_values.data(), oscale_values.size() * sizeof(float));
memcpy(mem->GetPtr(), wei_scale_values.data(), wei_scale_values.size() * sizeof(float));
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = mem;
}
void DnnlPostOpsComposer::updateDestScales() {
if (dst_scale_val == 1.0f)
return;
DEBUG_LOG("Set dest scale mask ", "DNNL_ARG: ", DNNL_ARG_DST, " mask: ", 0);
attr.set_scales_mask(DNNL_ARG_DST, 0);
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({1}));
auto mem = std::make_shared<Memory>(engine);
mem->Create(memoryDesc);
memcpy(mem->GetPtr(), &dst_scale_val, sizeof(float));
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST] = mem;
}
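For context, here is a minimal standalone sketch (not the plugin code; primitive type, memories and helper name are placeholders) of how oneDNN 3.x runtime scales filled in this way are consumed: the mask goes into the primitive_attr at creation time, while the scale values are passed as extra f32 memories at execution time.

// Illustrative only: a oneDNN 3.x usage pattern matching the attr/args
// populated by DnnlPostOpsComposer above. All objects are assumed to be
// created elsewhere; matmul is just an example primitive.
#include <oneapi/dnnl/dnnl.hpp>
#include <unordered_map>

void execute_with_runtime_scales(dnnl::matmul& prim, dnnl::stream& strm,
                                 dnnl::memory& src, dnnl::memory& wei, dnnl::memory& dst,
                                 dnnl::memory& wei_scales,   // f32, OC values (per-channel) or one value (per-tensor)
                                 dnnl::memory& dst_scales) { // f32, single value (oneDNN 3.x has no per-channel dst scale)
    std::unordered_map<int, dnnl::memory> args{
        {DNNL_ARG_SRC, src},
        {DNNL_ARG_WEIGHTS, wei},
        {DNNL_ARG_DST, dst},
        // Runtime scale values matching attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask)
        // and attr.set_scales_mask(DNNL_ARG_DST, 0) set at primitive creation time.
        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales},
        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales},
    };
    prim.execute(strm, args);
}

Because oneDNN 3.x divides the result by the destination scale, appendScale below stores 1.0 / scale[0] in dst_scale_val to realize a multiplication.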
@ -77,27 +109,33 @@ void DnnlPostOpsComposer::appendRoundHTE() {
bool DnnlPostOpsComposer::appendScale(const std::vector<float>& scale, bool isLastPostOp, bool allowBinary) {
IE_ASSERT(scale.size() == OC || scale.size() == 1);
// there are many possible optimizations that could be done, for example:
//
// we can switch the existing postOps's order to take
// advantage of output scale if it's available:
// relu(x)*scale = relu(x*scale)
// or we can fuse it into previous one as long as they are
// compatible in shape
// x*A*s = x*(A*s)
// or even with add:
// (x*A + B)*s = x*(A*s) + (B*s)
// or we can combine these two tricks:
// relu(x*A)*s = relu(x*(A*s))
//
// we cannot implement all of them, so we just add the one
// that we observed in real models.
// fuse into existing output scale (only when isINT8)
bool can_fuse_into_oscale = false;
if (isINT8 && isLastPostOp && scale.size() == 1) { // oneDNN v3.* limitation does not allow per-channel dst scales
if (ops.len() == 0)
can_fuse_into_oscale = true;
bool fuseIntoWeiScale = false;
// Use dest scale when the last post-op is per-tensor quantization.
if ((isINT8 && isLastPostOp && scale.size() == 1)) {
dst_scale_val = 1.0 / scale[0];
updateDestScales();
return true;
}
if (weightScaleAvailable) {
// oneDNN v3.* weight scale can also be used in further optimization patterns.
// There are many possible optimizations that could be done, for example:
//
// we can switch the existing postOps's order to take
// advantage of output scale if it's available:
// relu(x)*scale = relu(x*scale)
// or we can fuse it into previous one as long as they are
// compatible in shape
// x*A*s = x*(A*s)
// or even with add:
// (x*A + B)*s = x*(A*s) + (B*s)
// or we can combine these two tricks:
// relu(x*A)*s = relu(x*(A*s))
//
// we cannot implement all of them, so we just add the one
// that we observed in real models.
if ((ops.len() == 0))
fuseIntoWeiScale = true;
// relu(x)*s = relu(x*s)
// prelu(x)*s = prelu(x*s)
@ -105,7 +143,7 @@ bool DnnlPostOpsComposer::appendScale(const std::vector<float>& scale, bool isLa
auto& cur_op = ops.get()->entry_[0];
if ((cur_op.kind == dnnl::impl::primitive_kind::eltwise && cur_op.eltwise.alg == dnnl_eltwise_relu) ||
(cur_op.kind == dnnl::impl::primitive_kind::binary && cur_op.binary.alg == dnnl_binary_prelu)) {
can_fuse_into_oscale = true;
fuseIntoWeiScale = true;
}
}
@ -114,54 +152,32 @@ bool DnnlPostOpsComposer::appendScale(const std::vector<float>& scale, bool isLa
auto& cur_op = ops.get()->entry_.back();
if (cur_op.kind == dnnl::impl::primitive_kind::sum) {
cur_op.sum.scale *= scale[0];
can_fuse_into_oscale = true;
fuseIntoWeiScale = true;
}
}
}
if (can_fuse_into_oscale) {
if (fuseIntoWeiScale) {
if (scale.size() > 1) {
if (oscale_mask == 0)
oscale_values.resize(scale.size(), oscale_values[0]);
if (wei_scale_mask == 0)
wei_scale_values.resize(scale.size(), wei_scale_values[0]);
else
IE_ASSERT(oscale_values.size() == OC);
IE_ASSERT(wei_scale_values.size() == OC);
for (int j = 0; j < OC; j++)
oscale_values[j] *= 1 / scale[j];
wei_scale_values[j] *= scale[j];
} else {
for (int j = 0; j < oscale_values.size(); j++)
oscale_values[j] *= 1 / scale[0];
for (int j = 0; j < wei_scale_values.size(); j++)
wei_scale_values[j] *= scale[0];
}
if (oscale_values.size() == 1)
oscale_mask = 0;
if (wei_scale_values.size() == 1)
wei_scale_mask = 0;
else
oscale_mask = 1 << idxOC;
updateOutputScales();
wei_scale_mask = weightScaleMaskPerChannel;
updateWeiScales();
return true;
}
// (eltwise(x, scale, alpha, beta) + dst[:])*s = (eltwise(x, scale*s, alpha, beta) + s*dst[:])
if (scale.size() == 1 && ops.len() > 1) {
auto N = ops.len();
auto& cur_op = ops.get()->entry_[N-1];
auto& prev_op = ops.get()->entry_[N-2];
if (cur_op.kind == dnnl::impl::primitive_kind::sum && prev_op.is_eltwise()) {
cur_op.sum.scale *= scale[0];
prev_op.eltwise.scale *= scale[0];
return true;
}
}
// eltwise(x, scale, alpha, beta)*s = eltwise(x, (scale*s), alpha, beta)
if (scale.size() == 1 && ops.len() > 0) {
auto& cur_op = ops.get()->entry_.back();
if (cur_op.kind == dnnl::impl::primitive_kind::eltwise) {
cur_op.eltwise.scale *= scale[0];
return true;
}
}
// final fallback
if (scale.size() == 1) {
appendEltwise(dnnl::algorithm::eltwise_linear, scale[0], 0);
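To make the weight-scale folding above concrete, here is a small standalone sketch (assumed names, not plugin code) of the arithmetic in the fuseIntoWeiScale branch: a later per-tensor or per-channel multiply is folded into the accumulated per-OC weight scales, relying on identities such as relu(x)*s = relu(x*s) for s > 0 and (x*A)*s = x*(A*s).

#include <cassert>
#include <vector>

// Fold a fused multiply `s` (one value or OC values) into the weight scales.
std::vector<float> fold_scale(std::vector<float> wei_scales, const std::vector<float>& s, size_t OC) {
    if (s.size() > 1) {                       // per-channel multiply
        if (wei_scales.size() == 1)
            wei_scales.resize(OC, wei_scales[0]);
        for (size_t j = 0; j < OC; ++j)
            wei_scales[j] *= s[j];
    } else {                                  // per-tensor multiply
        for (auto& w : wei_scales)
            w *= s[0];
    }
    return wei_scales;
}

int main() {
    auto r = fold_scale({0.5f, 0.25f}, {2.0f}, 2);
    assert(r[0] == 1.0f && r[1] == 0.5f);     // per-OC weight scales absorbed the per-tensor multiply
}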

View File

@ -29,7 +29,10 @@ public:
std::unordered_map<int, MemoryPtr>& args,
const VectorDims& outputDims,
int indexOfOutputChannelDim,
bool isINT8);
bool isINT8,
int weiScaleMaskPerChannel,
const std::vector<float>& DQScales,
bool hasBias);
void appendBinary(const dnnl::algorithm alg, const std::vector<float>& data);
void appendEltwise(const dnnl::algorithm alg, float alpha, float beta);
@ -50,14 +53,19 @@ private:
std::unordered_map<int, MemoryPtr>& args;
const VectorDims outputDims;
int idxOC;
const bool isINT8; // only INT8 primitive support scales
const int weightScaleMaskPerChannel;
bool weightScaleAvailable = false;
VectorDims dimsPerTensor;
VectorDims dimsPerOC;
Dim OC;
const bool isINT8; // only INT8 primitive support output scale
int oscale_mask;
std::vector<float> oscale_values;
int wei_scale_mask = -1;
std::vector<float> wei_scale_values;
float dst_scale_val;
void updateOutputScales();
void updateWeiScales();
void updateDestScales();
};
} // namespace intel_cpu

View File

@ -436,7 +436,6 @@ void Graph::InitDescriptors() {
if (inputNode)
inputNode->withMeanImage();
}
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.getSupportedDescriptors);
DEBUG_LOG("Get supported primitive descriptors for node: ", node->getName());
node->getSupportedDescriptors();

View File

@ -12,6 +12,7 @@
#include "nodes/reorder.h"
#include "nodes/conv.h"
#include "nodes/deconv.h"
#include "nodes/fullyconnected.h"
#include "nodes/bin_conv.h"
#include "nodes/fake_quantize.h"
#include "nodes/mvn.h"
@ -27,6 +28,7 @@
#include <blob_factory.hpp>
#include "utils/general_utils.h"
#include "utils/cpu_utils.hpp"
#include "utils/debug_capabilities.h"
#include <ngraph/opsets/opset1.hpp>
#include <ie_ngraph_utils.hpp>
@ -61,6 +63,9 @@ namespace intel_cpu {
GraphOptimizer::GraphOptimizer() {}
void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
FuseConvMatmulFCDeconvAndDQScales(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "ApplyCommonGraphOptimizations", "FuseConvolutionAndBias");
FuseConvolutionMatMulDeconvAndBias(graph);
graph.RemoveDroppedNodes();
@ -177,6 +182,121 @@ void GraphOptimizer::ApplyImplSpecificGraphOptimizations(Graph &graph) {
graph.RemoveDroppedEdges();
}
void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
auto& graphNodes = graph.GetNodes();
auto isDQScaleGraphPattern = [](NodePtr node) {
if (node->getType() != Type::Eltwise || node->getAlgorithm() != Algorithm::EltwiseMultiply) {
return false;
}
auto parentNode = node->getParentEdgesAtPort(0)[0]->getParent();
auto scaleNode = node->getParentEdgesAtPort(1)[0]->getParent();
if (!(parentNode->getType() == Type::Convolution
|| parentNode->getType() == Type::MatMul
|| parentNode->getType() == Type::Deconvolution
|| parentNode->getType() == Type::FullyConnected))
return false;
if (!scaleNode->isConstant())
return false;
// Only fuse scales for INT8 precision.
if (parentNode->getOriginalInputPrecisionAtPort(0) != Precision::U8 && parentNode->getOriginalInputPrecisionAtPort(0) != Precision::I8)
return false;
if (parentNode->getOriginalInputPrecisionAtPort(1) != Precision::I8)
return false;
// Deconv has some heuristic limitations on INT8 execution besides input precision.
auto deconv = std::dynamic_pointer_cast<Deconvolution>(parentNode);
if (deconv && !deconv->canBeExecutedInInt8())
return false;
// FC bias has been fused into FC in the transformation phase.
// todo: Move the FC bias fusing into the graph optimizer.
const auto parentNodeInputEdges = parentNode->getParentEdges().size();
if (parentNodeInputEdges != 2) {
auto fcNode = std::dynamic_pointer_cast<FullyConnected>(parentNode);
if (!(parentNodeInputEdges == 3 && fcNode && fcNode->withBiasFused()))
return false;
}
return true;
};
auto scaleDimsCheck = [](NodePtr node, NodePtr scales) {
const auto nodeOutDims = node->getOutputShapeAtPort(0).getDims();
const auto channelAxis = node->getFusingAxis();
auto OC = nodeOutDims[channelAxis];
if (Shape::UNDEFINED_DIM == OC)
return false;
if (!node->getFusedWith().empty() || !scales->getFusedWith().empty())
return false;
const auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(),
nodeOutDims.size());
if (nodeOutDims.size() != scalesDims.size() || scalesDims.size() < 2)
return false;
if (!dimsEqualStrong(scalesDims[channelAxis], nodeOutDims[channelAxis]) && scalesDims[channelAxis] != 1)
return false;
for (size_t i = 0; i < scalesDims.size(); i++) {
if (scalesDims[i] != 1 && static_cast<int>(i) != channelAxis)
return false;
}
return true;
};
auto initializeDeQuantizedScales = [](NodePtr node, NodePtr scales) {
auto scalesConstant = dynamic_cast<node::Input*>(scales.get());
if (scalesConstant == nullptr)
IE_THROW() << "Cannot cast to Input node";
auto scalesBlob = scalesConstant->getMemoryPtr();
if (scalesBlob == nullptr)
IE_THROW() << "Cannot cast to TBlob internal scales blob";
auto scalesData = static_cast<const float*>(scalesBlob->GetPtr());
if (scalesData == nullptr)
IE_THROW() << "scalesBlob has not allocated buffer";
auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(),
node->getOutputShapeAtPort(0).getDims().size());
auto scaleSize = std::accumulate(scalesDims.begin(), scalesDims.end(), 1, std::multiplies<size_t>());
node->initializeDQScales(scalesData, scaleSize);
return true;
};
for (size_t i = 0; i < graphNodes.size(); i++) {
auto mul = graphNodes[i];
if (!isDQScaleGraphPattern(mul)) continue;
CPU_GRAPH_OPTIMIZER_SCOPE(FuseConvMatmulFCDeconvAndDQScales);
auto node = mul->getParentEdgesAtPort(0)[0]->getParent();
auto scales = mul->getParentEdgesAtPort(1)[0]->getParent();
if (!scaleDimsCheck(node, scales)) {
auto fcNode = std::dynamic_pointer_cast<FullyConnected>(node);
if (fcNode && fcNode->withBiasFused()) {
// For INT8 FC, BIAS has been fused into FC during the ngraph transformation, so the DQ fusing check fails here.
// Silently exiting here would cause an accuracy issue, because this Multiply would be appended after BIAS.
// It is a bug. Assert to give more debugging information.
// todo: Remove this by moving the FC bias fusing from the ngraph transformation into the graph optimizer.
DEBUG_LOG("BUG in scaleDimsCheck##", scales->getName(), " into FullyConnected ##", node->getName(),
"Fusing axis: ", node->getFusingAxis());
DEBUG_LOG(*node);
DEBUG_LOG(*scales);
IE_THROW() << "BUG: IN8 FC bias fused, DQ scale can not fused in " << node->getName() << std::endl;
}
continue;
}
if (initializeDeQuantizedScales(node, scales)) {
node->addOriginalLayer(mul->getOriginalLayers());
auto p_edge = mul->getParentEdgesAtPort(1)[0];
graph.RemoveEdge(p_edge);
graph.DropNode(mul);
}
}
}
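A condensed sketch of the shape rule that scaleDimsCheck enforces (standalone helper with assumed names, for illustration): the scale constant must broadcast over every dimension except, optionally, the channel axis.

#include <cstddef>
#include <vector>

// Every dim of the scale constant must be 1, except the channel axis,
// which may be 1 (per-tensor) or equal to OC (per-channel).
bool dq_scale_shape_ok(const std::vector<size_t>& outDims,
                       const std::vector<size_t>& scaleDims,
                       size_t channelAxis) {
    if (outDims.size() != scaleDims.size() || scaleDims.size() < 2)
        return false;
    for (size_t i = 0; i < scaleDims.size(); ++i) {
        if (i == channelAxis) {
            if (scaleDims[i] != 1 && scaleDims[i] != outDims[i])
                return false;
        } else if (scaleDims[i] != 1) {
            return false;
        }
    }
    return true;
}
// e.g. dq_scale_shape_ok({1, 64, 56, 56}, {1, 64, 1, 1}, 1) == true
//      dq_scale_shape_ok({1, 64, 56, 56}, {1, 1, 56, 1}, 1) == false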
void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
auto& graphNodes = graph.GetNodes();

View File

@ -20,6 +20,7 @@ public:
void ApplyImplSpecificGraphOptimizations(Graph& graph);
private:
void FuseConvMatmulFCDeconvAndDQScales(Graph &graph);
void FuseConvolutionMatMulDeconvAndBias(Graph &graph);
void FuseDeconvolutionAndSimpleOperation(Graph &graph);
void FuseMultiplyAndAdd(Graph &graph);

View File

@ -1666,5 +1666,20 @@ void Node::addSupportedPrimDesc(const std::vector<PortConfigurator>& inPortConfi
supportedPrimitiveDescriptors.push_back({config, implType});
}
void Node::initializeDQScales(const float* scaleData, const size_t scaleSize) {
bool scalePerTensor;
if (!DQScales.empty() || !scaleSize)
IE_THROW() << "DQ scales is preset or scale size is 0, ##" << getName();
DQScales.reserve(scaleSize);
scalePerTensor = true;
for (size_t i = 0; i < scaleSize; i++) {
DQScales.push_back(scaleData[i]);
if (scaleData[i] != scaleData[0])
scalePerTensor = false;
}
if (scalePerTensor)
DQScales.resize(1);
}
} // namespace intel_cpu
} // namespace ov
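As a usage illustration (assumed standalone helper mirroring initializeDQScales above, not the actual Node API): per-channel DQ scales collapse to a single per-tensor value when every entry is equal, which later lets DnnlPostOpsComposer use a per-tensor weight scale mask.

#include <cstddef>
#include <vector>

// Collapse per-channel dequantization scales to one value when they are all equal.
std::vector<float> normalize_dq_scales(const float* scaleData, size_t scaleSize) {
    std::vector<float> scales(scaleData, scaleData + scaleSize);
    bool perTensor = true;
    for (size_t i = 1; i < scaleSize; ++i) {
        if (scales[i] != scales[0]) {
            perTensor = false;
            break;
        }
    }
    if (perTensor)
        scales.resize(1);
    return scales;
}
// e.g. float s[] = {0.1f, 0.1f, 0.1f}; normalize_dq_scales(s, 3).size() == 1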

View File

@ -540,6 +540,10 @@ public:
*/
std::pair<std::vector<float>, std::vector<float>> getScalesAndShifts(const Node *parentNode) const;
void initializeDQScales(const float* scaleData, const size_t scaleSize);
const std::vector<float>& getDQScales() const {
return DQScales;
}
/**
* @brief Appends new item into ops list with the information on how the node should be executed as post operation.
* Seed node should call this routine and pass its post operations list as parameter.
@ -715,7 +719,8 @@ private:
enum LOOK { LOOK_UP = 1, LOOK_DOWN = 2 };
ConstantType checkConstant(LOOK look, std::vector<NodePtr>& checkNodes);
// Hold output scales
std::vector<float> DQScales;
// we cannot rely on per-NUMA weightCache for caching weights because:
// 1.it may not exist(in single stream configuration)
// 2.it only holds weak references, the life-cycle of cached item

View File

@ -616,8 +616,9 @@ void Convolution::setPostOps(dnnl::primitive_attr& attr,
dnnl::post_ops ops;
auto& args = convPostOpsArgs[useLegacyPostOps];
bool isINT8 = canBeExecutedInInt8();
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8);
// Weight dims in NON-Group CONV: [OC, IC, KH, KW], perchannel weight scale applied on OC DIM, weiScaleMaskPerChannel = 1 << 0
// Weight dims in Group CONV:[Group, OC, IC, KH, KW], perchannel weight scale applied on GROUP and OC DIM, weiScaleMaskPerChannel = ( 1 << 0 | 1<< 1) = 0x03
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8, isGrouped ? 3 : 1 << 0, getDQScales(), withBiases);
DEBUG_LOG(getName(), " useLegacyPostOps=", useLegacyPostOps, " initWeights=", initWeights);
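The mask values passed above follow the oneDNN convention that bit i of the scales mask marks dimension i of the weights as having its own scale value; a minimal sketch of that arithmetic (illustrative helper with an assumed name, not plugin code):

#include <oneapi/dnnl/dnnl.hpp>

// Non-grouped conv weights [OC, IC, KH, KW]: per-OC scales    -> mask = 1 << 0
// Grouped conv weights [G, OC, IC, KH, KW]:  per-(G, OC) scales -> mask = (1 << 0) | (1 << 1) = 0x03
void set_conv_weight_scale_mask(dnnl::primitive_attr& attr, bool grouped) {
    const int mask = grouped ? ((1 << 0) | (1 << 1)) : (1 << 0);
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
}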

View File

@ -83,6 +83,7 @@ private:
PerTensor,
PerChannel
};
class FusedSubgraph;
using FusedSubgraphPtr = std::shared_ptr<FusedSubgraph>;
using executorPtr = std::shared_ptr<DnnlExecutor>;

View File

@ -470,8 +470,19 @@ void Deconvolution::initPaddingR(const Shape &inShape, const Shape &outShape) {
void Deconvolution::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims) {
dnnl::post_ops ops;
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, 1, isInt8);
// OC, IC are the convolution forward output channel and input channel.
// According to the ONEDNN API doc, the mask should be set on the corresponding index of the weight.
// For [OC, IC, KH, KW], the per-channel scale weight mask should be set on the IC dim (1 << 1) for non-group deconv;
// for [Group, OC, IC, KH, KW], on the IC and group dims (1 << 0 | 1 << 2) for group deconv.
// Per-channel weight scale should be set on the IC dimension, not the OC dimension.
// But we have to set it on the OC dimension as below to make the weight scale work. It could be an ONEDNN bug??
// Current perchannel mask setting.
// Weight dims in NON-Group deconv: [OC, IC, KH, KW], perchannel weight scale applied on OC DIM
// weiScaleMaskPerChannel = 1 << 0
// Weight dims in Group deconv: [Group, OC, IC, KH, KW], perchannel weight scale applied on GROUP and OC DIM,
// weiScaleMaskPerChannel = ( 1 << 0 | 1 << 1) = 0x03
// @todo: Clarify with ONEDNN about deconvolution channel mask setting.
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, 1, isInt8, withGroups ? 3 : 1 << 0, getDQScales(), withBiases);
for (int i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];

View File

@ -237,6 +237,7 @@ void FullyConnected::getSupportedDescriptors() {
}
auto weightsDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(WEIGHTS_ID));
isINT8 = one_of(inputDataType, memory::data_type::u8, memory::data_type::s8) && weightsDataType == memory::data_type::s8;
// revert back outputDataType on special cases
if (inputDataType == memory::data_type::f32) {
// oneDNN only support f32 output when input is f32, even if FQ is fused
@ -534,10 +535,8 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
IE_THROW() << "Unexpected rank(" << dims_ext.size() << ") for output tensor of node: " << getName();
}
bool isINT8 = getOriginalInputPrecisionAtPort(WEIGHTS_ID) == Precision::U8 ||
getOriginalInputPrecisionAtPort(WEIGHTS_ID) == Precision::I8;
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8);
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8, 1 << 0, getDQScales(), withBiases);
for (int i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];

View File

@ -58,6 +58,10 @@ public:
void setDynamicBatchLim(int lim) override;
bool withBiasFused() const {
return withBiases;
}
private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
const dnnl::memory::desc &outputDesc);
@ -100,6 +104,7 @@ private:
float minSparseRate = 1.f;
float weiSparseRate = 0.f;
bool useSparseWeightsDecompression();
bool isINT8 = false;
};
} // namespace node

View File

@ -204,6 +204,14 @@ MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
}
bool MatMul::canFuse(const NodePtr& node) const {
// Consider the case when MatMul doesn't support execution in INT8, but is getting fused with an FQ with INT8 output.
// Then the MatMul will change its output precision to FP32. If the FQ were fused into MatMul, a reorder would be inserted
// after MatMul. In some BERT models this reorder causes severe perf degradation.
// Todo: Remove this once the onednn primitive supports U8 output with floating-point input.
if (node->getType() == Type::FakeQuantize && one_of(node->getOriginalOutputPrecisionAtPort(0), Precision::I8, Precision::U8) &&
!canBeExecutedInInt8(getOriginalInputPrecisionAtPort(0), getOriginalInputPrecisionAtPort(1)) &&
getOriginalInputPrecisionAtPort(0) == InferenceEngine::Precision::FP32 )
return false;
return canFuseSimpleOperation(node);
}
@ -217,7 +225,7 @@ void MatMul::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims, bool
bool isINT8 = canBeExecutedInInt8(getOriginalInputPrecisionAtPort(0), getOriginalInputPrecisionAtPort(1));
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8);
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8, 1 << (dims.size() - 1), getDQScales(), withBiases);
for (int i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];

View File

@ -13,7 +13,7 @@
#include "itt.hpp"
ov::intel_cpu::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
ov::intel_cpu::NonQuantizedFullyConnectedBiasFusion::NonQuantizedFullyConnectedBiasFusion() {
MATCHER_SCOPE(FullyConnectedBiasFusion);
auto input = ngraph::pattern::any_input();
auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
@ -74,3 +74,81 @@ ov::intel_cpu::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
auto m = std::make_shared<ngraph::pattern::Matcher>(m_add, matcher_name);
this->register_matcher(m, callback);
}
// The CPU plugin configures LPT not to propagate the dequantization scale over bias, to follow the ONEDNN 3.x scheme.
// It is a little tricky for now: for the pattern "FC + DQ + BIAS", the bias is fused first, not the DQ.
// todo: Move the FullyConnected fusing into the CPU plugin and fuse the DQ and BIAS in topological order.
ov::intel_cpu::QuantizedFullyConnectedBiasFusion::QuantizedFullyConnectedBiasFusion() {
MATCHER_SCOPE(FullyConnectedBiasFusion);
auto input = ngraph::pattern::any_input();
auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_fc = ngraph::pattern::wrap_type<ov::intel_cpu::FullyConnectedNode>({ input, weights }, [](ngraph::Output<ngraph::Node> output) {
return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_rank()(output);
});
auto m_scale = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_mul = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>({m_fc, m_scale});
auto m_bias = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_add = ngraph::pattern::wrap_type<ngraph::opset1::Add>({m_mul, m_bias});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto mul = pattern_to_output[m_mul].get_node_shared_ptr();
auto scale = pattern_to_output[m_scale].get_node_shared_ptr();
auto add = pattern_to_output[m_add].get_node_shared_ptr();
auto bias = pattern_to_output[m_bias].get_node_shared_ptr();
auto fc = std::dynamic_pointer_cast<ov::intel_cpu::FullyConnectedNode>(pattern_to_output[m_fc].get_node_shared_ptr());
if (!fc || transformation_callback(fc)) {
return false;
}
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(bias)) {
return false;
}
ngraph::Shape bias_shape(bias->get_shape());
ngraph::PartialShape output_shape(fc->get_output_partial_shape(0));
size_t bias_size = ngraph::shape_size(bias_shape);
auto rank = output_shape.rank().get_length();
if (rank == 0 || output_shape[rank - 1].is_dynamic()) {
return false;
}
const bool per_channel = std::count_if(bias_shape.begin(), bias_shape.end(), [](size_t x) { return x > 1; }) == 1;
if (ov::shape_size(bias_shape) != 1 && !per_channel)
return false;
if (bias_shape.empty() || bias_shape.back() != output_shape[rank - 1].get_length() || bias_shape.back() != bias_size) {
return false;
}
ngraph::NodeVector new_ops;
std::shared_ptr<ngraph::Node> final_bias = bias;
if (bias_shape.size() >= 2) {
auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 1 }, { -1 });
final_bias = ov::op::util::make_try_fold<ngraph::opset1::Reshape>(final_bias, reshape_const, true);
new_ops.push_back(final_bias);
}
auto new_fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(fc->input_value(0),
fc->input_value(1),
final_bias,
fc->get_output_rank(),
fc->get_output_type());
new_ops.push_back(new_fc);
std::shared_ptr<ngraph::Node> final_scale = scale;
auto new_mul = std::make_shared<ngraph::opset1::Multiply>(new_fc, final_scale, mul->get_autob());
new_ops.push_back(new_mul);
new_mul->set_friendly_name(add->get_friendly_name());
ngraph::copy_runtime_info({fc, mul, add}, new_ops);
ngraph::replace_node(add, new_mul);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(m_add, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -9,10 +9,25 @@
namespace ov {
namespace intel_cpu {
class FullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
class NonQuantizedFullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("NonQuantizedFullyConnectedBiasFusion", "0");
NonQuantizedFullyConnectedBiasFusion();
};
class QuantizedFullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("FullyConnectedDQBiasFusion", "0");
QuantizedFullyConnectedBiasFusion();
};
class FullyConnectedBiasFusion : public ngraph::pass::GraphRewrite {
public:
OPENVINO_RTTI("FullyConnectedBiasFusion", "0");
FullyConnectedBiasFusion();
FullyConnectedBiasFusion() {
add_matcher<NonQuantizedFullyConnectedBiasFusion>();
add_matcher<QuantizedFullyConnectedBiasFusion>();
}
};
} // namespace intel_cpu
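For reference, a minimal usage sketch (assumed include paths and surrounding code): the split into non-quantized and quantized matchers stays hidden behind the same GraphRewrite, so callers still register it as a single pass.

#include <memory>
#include <openvino/core/model.hpp>
#include <openvino/pass/manager.hpp>
// #include "fc_bias_fusion.hpp"  // assumed: the header declaring the classes above

void apply_fc_bias_fusion(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::intel_cpu::FullyConnectedBiasFusion>();  // runs both matchers
    manager.run_passes(model);
}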

View File

@ -81,12 +81,14 @@
#include "utils/ngraph_transformation.hpp"
// LPT transformations
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "low_precision/convolution_backprop_data.hpp"
#include "low_precision/add.hpp"
#include "low_precision/convert_subtract_constant.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include "low_precision/convolution_backprop_data.hpp"
#include "low_precision/group_convolution.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/bias_attribute.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
// CPU specific transformations
#include "transformations/cpu_opset/convert_to_cpu_specific_opset.hpp"
@ -521,6 +523,11 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
},
ngraph::pass::low_precision::ConvolutionBackpropDataTransformation);
lptManager.get_pass_config()->set_callback<ngraph::pass::low_precision::AddTransformation>(
[](const_node_ptr& node) -> bool {
return ov::marked_as_bias(node);
});
CPU_DISABLE_PASS_COMMON(lptManager, ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation);
lptManager.run_passes(model);

View File

@ -43,7 +43,6 @@ protected:
std::tie(input_shape, layer_type) = GetParam();
targetDevice = CommonTestUtils::DEVICE_CPU;
fusedOps = std::vector<std::string>{"Add"};
std::tie(inFmts, outFmts, priority, selectedType) = CPUSpecificParams{{}, {}, {}, CPUTestsBase::any_type};
std::unordered_map<std::string, std::string> ngraph_type_to_plugin_type{
{"Convolution", "Convolution"},
@ -53,6 +52,11 @@ protected:
{"MatMulWithConstant", "FullyConnected"},
};
node_type = ngraph_type_to_plugin_type[layer_type];
if (node_type == "FullyConnected")
// @todo: Recover the Multiply fusing check after moving FC bias fusing into the CPU graph optimizer.
fusedOps = std::vector<std::string>{"Add"};
else
fusedOps = std::vector<std::string>{"Multiply", "Add"};
const auto shapes = layer_type == "MatMul" ? std::vector<InputShape>{input_shape, input_shape}
: std::vector<InputShape>{input_shape};
@ -100,7 +104,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_FQLayerDQBias_4D_dynamic, FQLayerDQBias,
::testing::Combine(::testing::ValuesIn(input_shapes_4D_dynamic),
::testing::ValuesIn(layer_types_4D_dynamic)),
FQLayerDQBias::getTestCaseName);
const std::vector<InputShape> input_shapes_2D = {
{{-1, 768}, {{1, 768}}}
};

@ -1 +1 @@
Subproject commit f9127156d148393502d1d2254d9a48f564dc9adb
Subproject commit 67c84b1d76390ba1bf977a1c2d4bda53cf479c65