[CPU] Use oneDNN 3.x weight/dest scale API to optimize perf (#16805)
* [LPT][CPU] Added callback for AddTransformation
* [WIP] Convolution scales fusion
* Force to use weight scale to test performance.
* Update on interface.
* Use weight scale to adapt to oneDNN 3.x API changes.
* Update the code.
* Update oneDNN fix for gemm_x8s8s32x_conv kernel.
* Fix the bug in oneDNN and deconvFusingScale.
* Fuse FC bias when having DQ scale.
* WR to perf regression on
* Update onednn version.
* Fix bug and clean code.
* FC fusing DQ scale bug fix.
* Add more comments and debug information.
* Fix CI issues.
* Merge oneDNN changes.
* Fix CI issues and bugs.
* Apply review comments.
* Update comments.
* Apply review comments.
* Avoid using LPT BiasAttribute RTInfo.
* Applied review comments.

Co-authored-by: Vladislav Golubev <vladislav.golubev@intel.com>
parent 25015f9790
commit 6aeb054e48
@@ -12,38 +12,70 @@ namespace ov {

namespace intel_cpu {

DnnlPostOpsComposer::DnnlPostOpsComposer(const dnnl::engine& engine,
dnnl::primitive_attr& attr,
dnnl::post_ops& ops,
std::unordered_map<int, MemoryPtr>& args,
const VectorDims& outputDims,
int indexOfOutputChannelDim,
bool isINT8)
dnnl::primitive_attr& attr,
dnnl::post_ops& ops,
std::unordered_map<int, MemoryPtr>& args,
const VectorDims& outputDims,
int indexOfOutputChannelDim,
bool isInt8,
const int weiScaleMaskPerChannel,
const std::vector<float>& DQScales,
bool hasBias)
: engine(engine),
attr(attr),
ops(ops),
args(args),
outputDims(outputDims),
idxOC(indexOfOutputChannelDim),
isINT8(isINT8) {
isINT8(isInt8),
weightScaleMaskPerChannel(weiScaleMaskPerChannel) {
IE_ASSERT(idxOC >= 0 && idxOC < outputDims.size());
OC = outputDims[idxOC];
dimsPerOC = dimsPerTensor = VectorDims(outputDims.size(), 1);
dimsPerOC[idxOC] = OC;
oscale_mask = 0;
oscale_values = {1.0f};

if (isINT8) {
wei_scale_values = DQScales.empty() ? std::vector<float>{1.0} : DQScales;
wei_scale_mask = wei_scale_values.size() > 1 ? weiScaleMaskPerChannel : 0;
dst_scale_val = 1.0;

// Set the DQScale into the attr weight scale before appending any post-ops.
updateWeiScales();
// If there is a bias, the attr weight scale can't be updated for further post-ops optimization.
// oneDNN 3.x quantization scheme: QuantizedInput * QuantizedWeight * DQScale + Bias.
weightScaleAvailable = !hasBias;
} else if (!DQScales.empty()) {
// DQ scale is fused but switching back to non-INT8 for execution in some cases.
DEBUG_LOG("Set DQ scales for None-INT8, scale size ", DQScales.size());
appendScale(DQScales, false, true);
}
}

void DnnlPostOpsComposer::updateOutputScales() {
if (oscale_mask == 0 && oscale_values[0] == 1.0f)
void DnnlPostOpsComposer::updateWeiScales() {
if (wei_scale_mask == 0 && wei_scale_values[0] == 1.0f)
return;

DEBUG_LOG("Set scales mask ", "DNNL_ARG: ", DNNL_ARG_DST, " mask: ", oscale_mask);
attr.set_scales_mask(DNNL_ARG_DST, oscale_mask);
DEBUG_LOG("Set weight scales mask ", "DNNL_ARG: ", DNNL_ARG_WEIGHTS, " mask: ", wei_scale_mask);
attr.set_scales_mask(DNNL_ARG_WEIGHTS, wei_scale_mask);

DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({oscale_values.size()}));
DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({wei_scale_values.size()}));
auto mem = std::make_shared<Memory>(engine);
mem->Create(memoryDesc);
memcpy(mem->GetPtr(), oscale_values.data(), oscale_values.size() * sizeof(float));
memcpy(mem->GetPtr(), wei_scale_values.data(), wei_scale_values.size() * sizeof(float));
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = mem;
}

void DnnlPostOpsComposer::updateDestScales() {
if (dst_scale_val == 1.0f)
return;

DEBUG_LOG("Set dest scale mask ", "DNNL_ARG: ", DNNL_ARG_DST, " mask: ", 0);
attr.set_scales_mask(DNNL_ARG_DST, 0);

DnnlBlockedMemoryDesc memoryDesc(InferenceEngine::Precision::FP32, Shape({1}));
auto mem = std::make_shared<Memory>(engine);
mem->Create(memoryDesc);
memcpy(mem->GetPtr(), &dst_scale_val, sizeof(float));
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST] = mem;
}

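For context, here is a minimal, self-contained sketch of the oneDNN 3.x per-argument scale mechanism that updateWeiScales()/updateDestScales() wrap, written against the plain dnnl C++ API rather than the plugin's Memory/DnnlBlockedMemoryDesc wrappers; the helper name set_runtime_weight_scales is illustrative and not a plugin or oneDNN function.

#include <oneapi/dnnl/dnnl.hpp>
#include <unordered_map>
#include <vector>

// Hypothetical helper mirroring updateWeiScales(): declare the scale mask on the
// primitive attr at build time, and stash a runtime f32 memory holding the scale
// values under DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS for execute() time.
inline void set_runtime_weight_scales(const dnnl::engine& eng,
                                      dnnl::primitive_attr& attr,
                                      std::unordered_map<int, dnnl::memory>& exec_args,
                                      std::vector<float>& scales,
                                      int mask) {
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);  // 0 = per-tensor, (1 << 0) = per-OC for [OC, IC, ...]
    dnnl::memory::desc md({static_cast<dnnl::memory::dim>(scales.size())},
                          dnnl::memory::data_type::f32,
                          dnnl::memory::format_tag::x);
    exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = dnnl::memory(md, eng, scales.data());
}

At execute() time these entries are passed alongside the regular DNNL_ARG_SRC/WEIGHTS/DST arguments; the scale buffer must stay alive until execution completes, which is why the composer stores it in the args map.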
@@ -77,27 +109,33 @@ void DnnlPostOpsComposer::appendRoundHTE() {

bool DnnlPostOpsComposer::appendScale(const std::vector<float>& scale, bool isLastPostOp, bool allowBinary) {
IE_ASSERT(scale.size() == OC || scale.size() == 1);
// there are so many possible optimizations that can be done, for example:
//
// we can switch the existing postOps' order to take
// advantage of the output scale if it's available:
//    relu(x)*scale = relu(x*scale)
// or we can fuse it into the previous one as long as they are
// compatible in shape:
//    x*A*s = x*(A*s)
// or even with add:
//    (x*A + B)*s = x*(A*s) + (B*s)
// or we can combine these two tricks:
//    relu(x*A)*s = relu(x*(A*s))
//
// we cannot implement all of them, so we just add the ones
// that we observed in real models.

// fuse into existing output scale (only when isINT8)
bool can_fuse_into_oscale = false;
if (isINT8 && isLastPostOp && scale.size() == 1) { // oneDNN v3.* limitation does not allow per-channel dst scales
if (ops.len() == 0)
can_fuse_into_oscale = true;
bool fuseIntoWeiScale = false;
// Use dest scale when the last post-op is per-tensor quantization.
if ((isINT8 && isLastPostOp && scale.size() == 1)) {
dst_scale_val = 1.0 / scale[0];
updateDestScales();
return true;
}
if (weightScaleAvailable) {
// oneDNN v3.* weight scale can also be used in the further optimization patterns.
// there are so many possible optimizations that can be done, for example:
//
// we can switch the existing postOps' order to take
// advantage of the output scale if it's available:
//    relu(x)*scale = relu(x*scale)
// or we can fuse it into the previous one as long as they are
// compatible in shape:
//    x*A*s = x*(A*s)
// or even with add:
//    (x*A + B)*s = x*(A*s) + (B*s)
// or we can combine these two tricks:
//    relu(x*A)*s = relu(x*(A*s))
//
// we cannot implement all of them, so we just add the ones
// that we observed in real models.
if ((ops.len() == 0))
fuseIntoWeiScale = true;

// relu(x)*s = relu(x*s)
// prelu(x)*s = prelu(x*s)
@@ -105,7 +143,7 @@ bool DnnlPostOpsComposer::appendScale(const std::vector<float>& scale, bool isLa
auto& cur_op = ops.get()->entry_[0];
if ((cur_op.kind == dnnl::impl::primitive_kind::eltwise && cur_op.eltwise.alg == dnnl_eltwise_relu) ||
(cur_op.kind == dnnl::impl::primitive_kind::binary && cur_op.binary.alg == dnnl_binary_prelu)) {
can_fuse_into_oscale = true;
fuseIntoWeiScale = true;
}
}

@@ -114,54 +152,32 @@ bool DnnlPostOpsComposer::appendScale(const std::vector<float>& scale, bool isLa
auto& cur_op = ops.get()->entry_.back();
if (cur_op.kind == dnnl::impl::primitive_kind::sum) {
cur_op.sum.scale *= scale[0];
can_fuse_into_oscale = true;
fuseIntoWeiScale = true;
}
}
}

if (can_fuse_into_oscale) {
if (fuseIntoWeiScale) {
if (scale.size() > 1) {
if (oscale_mask == 0)
oscale_values.resize(scale.size(), oscale_values[0]);
if (wei_scale_mask == 0)
wei_scale_values.resize(scale.size(), wei_scale_values[0]);
else
IE_ASSERT(oscale_values.size() == OC);
IE_ASSERT(wei_scale_values.size() == OC);

for (int j = 0; j < OC; j++)
oscale_values[j] *= 1 / scale[j];
wei_scale_values[j] *= scale[j];
} else {
for (int j = 0; j < oscale_values.size(); j++)
oscale_values[j] *= 1 / scale[0];
for (int j = 0; j < wei_scale_values.size(); j++)
wei_scale_values[j] *= scale[0];
}

if (oscale_values.size() == 1)
oscale_mask = 0;
if (wei_scale_values.size() == 1)
wei_scale_mask = 0;
else
oscale_mask = 1 << idxOC;
updateOutputScales();
wei_scale_mask = weightScaleMaskPerChannel;
updateWeiScales();
return true;
}

// (eltwise(x, scale, alpha, beta) + dst[:])*s = (eltwise(x, scale*s, alpha, beta) + s*dst[:])
if (scale.size() == 1 && ops.len() > 1) {
auto N = ops.len();
auto& cur_op = ops.get()->entry_[N-1];
auto& prev_op = ops.get()->entry_[N-2];
if (cur_op.kind == dnnl::impl::primitive_kind::sum && prev_op.is_eltwise()) {
cur_op.sum.scale *= scale[0];
prev_op.eltwise.scale *= scale[0];
return true;
}
}

// eltwise(x, scale, alpha, beta)*s = eltwise(x, (scale*s), alpha, beta)
if (scale.size() == 1 && ops.len() > 0) {
auto& cur_op = ops.get()->entry_.back();
if (cur_op.kind == dnnl::impl::primitive_kind::eltwise) {
cur_op.eltwise.scale *= scale[0];
return true;
}
}

// final fallback
if (scale.size() == 1) {
appendEltwise(dnnl::algorithm::eltwise_linear, scale[0], 0);
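The arithmetic behind the branches above is plain scale folding. A standalone sketch of the two cases appendScale() distinguishes, under the assumption that the names FoldedScales, fold_last_per_tensor_scale and fold_into_weight_scales are illustrative only and not part of the plugin:

#include <cassert>
#include <vector>

// Folding rules used above:
//  - a trailing per-tensor quantization scale s becomes a dst scale of 1/s
//    (oneDNN 3.x divides the result by the dst scale, so 1/s multiplies by s);
//  - otherwise the new scales are multiplied element-wise into the weight scales.
struct FoldedScales {
    std::vector<float> wei_scales{1.0f};  // starts as DQScales (or {1.0f})
    float dst_scale = 1.0f;
};

inline void fold_last_per_tensor_scale(FoldedScales& fs, float s) {
    fs.dst_scale = 1.0f / s;
}

inline void fold_into_weight_scales(FoldedScales& fs, const std::vector<float>& scale) {
    if (scale.size() > 1 && fs.wei_scales.size() == 1)
        fs.wei_scales.resize(scale.size(), fs.wei_scales[0]);  // broadcast per-tensor to per-OC
    assert(scale.size() == 1 || scale.size() == fs.wei_scales.size());
    for (size_t j = 0; j < fs.wei_scales.size(); ++j)
        fs.wei_scales[j] *= (scale.size() == 1 ? scale[0] : scale[j]);
}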
@@ -29,7 +29,10 @@ public:
std::unordered_map<int, MemoryPtr>& args,
const VectorDims& outputDims,
int indexOfOutputChannelDim,
bool isINT8);
bool isINT8,
int weiScaleMaskPerChannel,
const std::vector<float>& DQScales,
bool hasBias);

void appendBinary(const dnnl::algorithm alg, const std::vector<float>& data);
void appendEltwise(const dnnl::algorithm alg, float alpha, float beta);
@@ -50,14 +53,19 @@ private:
std::unordered_map<int, MemoryPtr>& args;
const VectorDims outputDims;
int idxOC;
const bool isINT8; // only INT8 primitives support scales
const int weightScaleMaskPerChannel;
bool weightScaleAvailable = false;

VectorDims dimsPerTensor;
VectorDims dimsPerOC;
Dim OC;
const bool isINT8; // only INT8 primitives support output scale
int oscale_mask;
std::vector<float> oscale_values;
int wei_scale_mask = -1;
std::vector<float> wei_scale_values;
float dst_scale_val;

void updateOutputScales();
void updateWeiScales();
void updateDestScales();
};

} // namespace intel_cpu
@@ -436,7 +436,6 @@ void Graph::InitDescriptors() {
if (inputNode)
inputNode->withMeanImage();
}

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.getSupportedDescriptors);
DEBUG_LOG("Get supported primitive descriptors for node: ", node->getName());
node->getSupportedDescriptors();
@@ -12,6 +12,7 @@
#include "nodes/reorder.h"
#include "nodes/conv.h"
#include "nodes/deconv.h"
#include "nodes/fullyconnected.h"
#include "nodes/bin_conv.h"
#include "nodes/fake_quantize.h"
#include "nodes/mvn.h"
@@ -27,6 +28,7 @@
#include <blob_factory.hpp>
#include "utils/general_utils.h"
#include "utils/cpu_utils.hpp"
#include "utils/debug_capabilities.h"

#include <ngraph/opsets/opset1.hpp>
#include <ie_ngraph_utils.hpp>
@@ -61,6 +63,9 @@ namespace intel_cpu {
GraphOptimizer::GraphOptimizer() {}

void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
FuseConvMatmulFCDeconvAndDQScales(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "ApplyCommonGraphOptimizations", "FuseConvolutionAndBias");
FuseConvolutionMatMulDeconvAndBias(graph);
graph.RemoveDroppedNodes();
@@ -177,6 +182,121 @@ void GraphOptimizer::ApplyImplSpecificGraphOptimizations(Graph &graph) {
graph.RemoveDroppedEdges();
}

void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
auto& graphNodes = graph.GetNodes();

auto isDQScaleGraphPattern = [](NodePtr node) {
if (node->getType() != Type::Eltwise || node->getAlgorithm() != Algorithm::EltwiseMultiply) {
return false;
}
auto parentNode = node->getParentEdgesAtPort(0)[0]->getParent();
auto scaleNode = node->getParentEdgesAtPort(1)[0]->getParent();
if (!(parentNode->getType() == Type::Convolution
|| parentNode->getType() == Type::MatMul
|| parentNode->getType() == Type::Deconvolution
|| parentNode->getType() == Type::FullyConnected))
return false;
if (!scaleNode->isConstant())
return false;
// Only fuse scales for INT8 precision.
if (parentNode->getOriginalInputPrecisionAtPort(0) != Precision::U8 && parentNode->getOriginalInputPrecisionAtPort(0) != Precision::I8)
return false;
if (parentNode->getOriginalInputPrecisionAtPort(1) != Precision::I8)
return false;

// Deconv has some heuristic limitations on using INT8 besides the input precision.
auto deconv = std::dynamic_pointer_cast<Deconvolution>(parentNode);
if (deconv && !deconv->canBeExecutedInInt8())
return false;
// FC bias has been fused into FC in the transformation phase.
// todo: Move the FC bias fusing into the graph optimizer.
const auto parentNodeInputEdges = parentNode->getParentEdges().size();

if (parentNodeInputEdges != 2) {
auto fcNode = std::dynamic_pointer_cast<FullyConnected>(parentNode);
if (!(parentNodeInputEdges == 3 && fcNode && fcNode->withBiasFused()))
return false;
}
return true;
};

auto scaleDimsCheck = [](NodePtr node, NodePtr scales) {
const auto nodeOutDims = node->getOutputShapeAtPort(0).getDims();
const auto channelAxis = node->getFusingAxis();
auto OC = nodeOutDims[channelAxis];

if (Shape::UNDEFINED_DIM == OC)
return false;
if (!node->getFusedWith().empty() || !scales->getFusedWith().empty())
return false;

const auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(),
nodeOutDims.size());
if (nodeOutDims.size() != scalesDims.size() || scalesDims.size() < 2)
return false;

if (!dimsEqualStrong(scalesDims[channelAxis], nodeOutDims[channelAxis]) && scalesDims[channelAxis] != 1)
return false;

for (size_t i = 0; i < scalesDims.size(); i++) {
if (scalesDims[i] != 1 && static_cast<int>(i) != channelAxis)
return false;
}
return true;
};

auto initializeDeQuantizedScales = [](NodePtr node, NodePtr scales) {
auto scalesConstant = dynamic_cast<node::Input*>(scales.get());
if (scalesConstant == nullptr)
IE_THROW() << "Cannot cast to Input node";

auto scalesBlob = scalesConstant->getMemoryPtr();
if (scalesBlob == nullptr)
IE_THROW() << "Cannot cast to TBlob internal scales blob";

auto scalesData = static_cast<const float*>(scalesBlob->GetPtr());
if (scalesData == nullptr)
IE_THROW() << "scalesBlob has not allocated buffer";
auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(),
node->getOutputShapeAtPort(0).getDims().size());
auto scaleSize = std::accumulate(scalesDims.begin(), scalesDims.end(), 1, std::multiplies<size_t>());
node->initializeDQScales(scalesData, scaleSize);
return true;
};

for (size_t i = 0; i < graphNodes.size(); i++) {
auto mul = graphNodes[i];
if (!isDQScaleGraphPattern(mul)) continue;

CPU_GRAPH_OPTIMIZER_SCOPE(FuseConvMatmulFCDeconvAndDQScales);

auto node = mul->getParentEdgesAtPort(0)[0]->getParent();
auto scales = mul->getParentEdgesAtPort(1)[0]->getParent();
if (!scaleDimsCheck(node, scales)) {
auto fcNode = std::dynamic_pointer_cast<FullyConnected>(node);
if (fcNode && fcNode->withBiasFused()) {
// For int8 FC, the BIAS has been fused into FC during the ngraph transformation, so the DQ fusing check fails here.
// Silently exiting here would cause an accuracy issue, because this Multiply would be appended after the BIAS.
// It is a bug. Assert to give more debugging information.
// todo: Remove this by moving the FullyConnected bias fusing into the graph optimizer from the ngraph transformation.
DEBUG_LOG("BUG in scaleDimsCheck##", scales->getName(), " into FullyConnect ##", node->getName(),
"Fusing axis: ", node->getFusingAxis());
DEBUG_LOG(*node);
DEBUG_LOG(*scales);
IE_THROW() << "BUG: IN8 FC bias fused, DQ scale can not fused in " << node->getName() << std::endl;
}
continue;
}

if (initializeDeQuantizedScales(node, scales)) {
node->addOriginalLayer(mul->getOriginalLayers());
auto p_edge = mul->getParentEdgesAtPort(1)[0];
graph.RemoveEdge(p_edge);
graph.DropNode(mul);
}
}
}

void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
auto& graphNodes = graph.GetNodes();
@@ -20,6 +20,7 @@ public:
void ApplyImplSpecificGraphOptimizations(Graph& graph);

private:
void FuseConvMatmulFCDeconvAndDQScales(Graph &graph);
void FuseConvolutionMatMulDeconvAndBias(Graph &graph);
void FuseDeconvolutionAndSimpleOperation(Graph &graph);
void FuseMultiplyAndAdd(Graph &graph);
@@ -1666,5 +1666,20 @@ void Node::addSupportedPrimDesc(const std::vector<PortConfigurator>& inPortConfi
supportedPrimitiveDescriptors.push_back({config, implType});
}

void Node::initializeDQScales(const float* scaleData, const size_t scaleSize) {
bool scalePerTensor;
if (!DQScales.empty() || !scaleSize)
IE_THROW() << "DQ scales is preset or scale size is 0, ##" << getName();
DQScales.reserve(scaleSize);
scalePerTensor = true;
for (size_t i = 0; i < scaleSize; i++) {
DQScales.push_back(scaleData[i]);
if (scaleData[i] != scaleData[0])
scalePerTensor = false;
}
if (scalePerTensor)
DQScales.resize(1);
}

} // namespace intel_cpu
} // namespace ov
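A quick illustration of the per-tensor collapse that Node::initializeDQScales() performs; the function below is a standalone restatement for clarity, not a plugin helper:

#include <vector>

// If every per-channel DQ scale is identical, keep a single per-tensor value so
// the post-ops composer can use a scale mask of 0 instead of a per-channel mask.
inline std::vector<float> collapse_dq_scales(const float* data, size_t n) {
    std::vector<float> scales(data, data + n);
    bool per_tensor = true;
    for (size_t i = 1; i < n; ++i)
        per_tensor = per_tensor && (data[i] == data[0]);
    if (per_tensor)
        scales.resize(1);
    return scales;
}

// Example: {0.5f, 0.5f, 0.5f} collapses to {0.5f} (per-tensor),
//          {0.5f, 0.25f, 0.5f} stays per-channel with 3 entries.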
@@ -540,6 +540,10 @@ public:
*/
std::pair<std::vector<float>, std::vector<float>> getScalesAndShifts(const Node *parentNode) const;

void initializeDQScales(const float* scaleData, const size_t scaleSize);
const std::vector<float>& getDQScales() const {
return DQScales;
}
/**
* @brief Appends new item into ops list with the information on how the node should be executed as post operation.
* Seed node should call this routine and pass its post operations list as parameter.
@@ -715,7 +719,8 @@ private:

enum LOOK { LOOK_UP = 1, LOOK_DOWN = 2 };
ConstantType checkConstant(LOOK look, std::vector<NodePtr>& checkNodes);

// Holds the dequantization scales
std::vector<float> DQScales;
// we cannot rely on per-NUMA weightCache for caching weights because:
// 1.it may not exist(in single stream configuration)
// 2.it only holds weak references, the life-cycle of cached item
@@ -616,8 +616,9 @@ void Convolution::setPostOps(dnnl::primitive_attr& attr,
dnnl::post_ops ops;
auto& args = convPostOpsArgs[useLegacyPostOps];
bool isINT8 = canBeExecutedInInt8();

DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8);
// Weight dims in non-group Conv: [OC, IC, KH, KW]; per-channel weight scale is applied on the OC dim, weiScaleMaskPerChannel = 1 << 0.
// Weight dims in group Conv: [Group, OC, IC, KH, KW]; per-channel weight scale is applied on the Group and OC dims, weiScaleMaskPerChannel = (1 << 0 | 1 << 1) = 0x03.
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8, isGrouped ? 3 : 1 << 0, getDQScales(), withBiases);

DEBUG_LOG(getName(), " useLegacyPostOps=", useLegacyPostOps, " initWeights=", initWeights);
@@ -83,6 +83,7 @@ private:
PerTensor,
PerChannel
};

class FusedSubgraph;
using FusedSubgraphPtr = std::shared_ptr<FusedSubgraph>;
using executorPtr = std::shared_ptr<DnnlExecutor>;
@@ -470,8 +470,19 @@ void Deconvolution::initPaddingR(const Shape &inShape, const Shape &outShape) {

void Deconvolution::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims) {
dnnl::post_ops ops;

DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, 1, isInt8);
// OC and IC are the convolution-forward output channel and input channel.
// According to the oneDNN API doc, the mask should be set on the corresponding weight index:
// for [OC, IC, KH, KW] the per-channel scale weight mask should be set on the IC dim (1 << 1) for non-group deconv;
// for [Group, OC, IC, KH, KW] on the IC and Group dims (1 << 0 | 1 << 2) for group deconv.
// The per-channel weight scale should be set on the IC dimension, not the OC dimension.
// But we have to set it on the OC dimension as below to make the weight scale work. It might be a oneDNN bug??
// Current per-channel mask setting:
// weight dims in non-group deconv: [OC, IC, KH, KW], per-channel weight scale applied on the OC dim,
// weiScaleMaskPerChannel = 1 << 0;
// weight dims in group deconv: [Group, OC, IC, KH, KW], per-channel weight scale applied on the Group and OC dims,
// weiScaleMaskPerChannel = (1 << 0 | 1 << 1) = 0x03.
// @todo: Clarify with oneDNN the deconvolution channel mask setting.
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, 1, isInt8, withGroups ? 3 : 1 << 0, getDQScales(), withBiases);

for (int i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];
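A compact sketch of how the weiScaleMaskPerChannel values passed to DnnlPostOpsComposer in the Conv, Deconv, FullyConnected and MatMul hunks of this patch map to weight layouts; the enum and helper names below are illustrative assumptions, not plugin APIs:

// Per-channel weight-scale masks as used in the constructor calls in this patch.
// A set bit selects a weight dimension that carries its own scale value.
enum class WeightLayout { ConvOIHW, ConvGOIHW, DeconvOIHW, DeconvGOIHW, FullyConnected, MatMul };

inline int per_channel_wei_scale_mask(WeightLayout layout, size_t dstRank = 0) {
    switch (layout) {
    case WeightLayout::ConvOIHW:       // [OC, IC, KH, KW]        -> OC dim
    case WeightLayout::DeconvOIHW:     // same mask is used for deconv in this patch
    case WeightLayout::FullyConnected: // [OC, IC]                -> OC dim
        return 1 << 0;
    case WeightLayout::ConvGOIHW:      // [G, OC, IC, KH, KW]     -> Group and OC dims
    case WeightLayout::DeconvGOIHW:
        return (1 << 0) | (1 << 1);    // == 0x03
    case WeightLayout::MatMul:         // per-channel scales follow the last dim of the output
        return 1 << (dstRank - 1);
    }
    return 0;
}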
@@ -237,6 +237,7 @@ void FullyConnected::getSupportedDescriptors() {
}
auto weightsDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(WEIGHTS_ID));

isINT8 = one_of(inputDataType, memory::data_type::u8, memory::data_type::s8) && weightsDataType == memory::data_type::s8;
// revert back outputDataType on special cases
if (inputDataType == memory::data_type::f32) {
// oneDNN only supports f32 output when the input is f32, even if an FQ is fused
@@ -534,10 +535,8 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
IE_THROW() << "Unexpected rank(" << dims_ext.size() << ") for output tensor of node: " << getName();
}

bool isINT8 = getOriginalInputPrecisionAtPort(WEIGHTS_ID) == Precision::U8 ||
getOriginalInputPrecisionAtPort(WEIGHTS_ID) == Precision::I8;

DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8);
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8, 1 << 0, getDQScales(), withBiases);

for (int i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];
@@ -58,6 +58,10 @@ public:

void setDynamicBatchLim(int lim) override;

bool withBiasFused() const {
return withBiases;
}

private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
const dnnl::memory::desc &outputDesc);
@@ -100,6 +104,7 @@ private:
float minSparseRate = 1.f;
float weiSparseRate = 0.f;
bool useSparseWeightsDecompression();
bool isINT8 = false;
};

} // namespace node
@@ -204,6 +204,14 @@ MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
}

bool MatMul::canFuse(const NodePtr& node) const {
// Consider the case when MatMul doesn't support execution in int8, but is getting fused with an FQ with int8 output.
// Then the MatMul will change its output precision to fp32. If the FQ is fused into the MatMul, a reorder would be inserted
// after the MatMul. In some BERT models, this reorder causes a large perf degradation.
// Todo: Remove this once the onednn primitive supports U8 output with floating-point input.
if (node->getType() == Type::FakeQuantize && one_of(node->getOriginalOutputPrecisionAtPort(0), Precision::I8, Precision::U8) &&
!canBeExecutedInInt8(getOriginalInputPrecisionAtPort(0), getOriginalInputPrecisionAtPort(1)) &&
getOriginalInputPrecisionAtPort(0) == InferenceEngine::Precision::FP32 )
return false;
return canFuseSimpleOperation(node);
}

@@ -217,7 +225,7 @@ void MatMul::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims, bool

bool isINT8 = canBeExecutedInInt8(getOriginalInputPrecisionAtPort(0), getOriginalInputPrecisionAtPort(1));

DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8);
DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, isINT8, 1 << (dims.size() - 1), getDQScales(), withBiases);

for (int i = 0; i < fusedWith.size(); ++i) {
auto& node = fusedWith[i];
@@ -13,7 +13,7 @@

#include "itt.hpp"

ov::intel_cpu::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
ov::intel_cpu::NonQuantizedFullyConnectedBiasFusion::NonQuantizedFullyConnectedBiasFusion() {
MATCHER_SCOPE(FullyConnectedBiasFusion);
auto input = ngraph::pattern::any_input();
auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
@@ -74,3 +74,81 @@ ov::intel_cpu::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
auto m = std::make_shared<ngraph::pattern::Matcher>(m_add, matcher_name);
this->register_matcher(m, callback);
}

// The CPU plugin configures LPT not to propagate the dequantization scale over the bias, to follow the oneDNN 3.x scheme.
// It is a little tricky now: for the pattern "FC + DQ + BIAS" the bias is fused first, not the DQ.
// todo: Will move the FullyConnected fusing into the CPU plugin and fuse the DQ and BIAS in topological order.
ov::intel_cpu::QuantizedFullyConnectedBiasFusion::QuantizedFullyConnectedBiasFusion() {
MATCHER_SCOPE(FullyConnectedBiasFusion);
auto input = ngraph::pattern::any_input();
auto weights = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_fc = ngraph::pattern::wrap_type<ov::intel_cpu::FullyConnectedNode>({ input, weights }, [](ngraph::Output<ngraph::Node> output) {
return ngraph::pattern::consumers_count(1)(output) && ngraph::pattern::has_static_rank()(output);
});

auto m_scale = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_mul = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>({m_fc, m_scale});

auto m_bias = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto m_add = ngraph::pattern::wrap_type<ngraph::opset1::Add>({m_mul, m_bias});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto mul = pattern_to_output[m_mul].get_node_shared_ptr();
auto scale = pattern_to_output[m_scale].get_node_shared_ptr();
auto add = pattern_to_output[m_add].get_node_shared_ptr();
auto bias = pattern_to_output[m_bias].get_node_shared_ptr();
auto fc = std::dynamic_pointer_cast<ov::intel_cpu::FullyConnectedNode>(pattern_to_output[m_fc].get_node_shared_ptr());
if (!fc || transformation_callback(fc)) {
return false;
}

if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(bias)) {
return false;
}

ngraph::Shape bias_shape(bias->get_shape());
ngraph::PartialShape output_shape(fc->get_output_partial_shape(0));
size_t bias_size = ngraph::shape_size(bias_shape);
auto rank = output_shape.rank().get_length();
if (rank == 0 || output_shape[rank - 1].is_dynamic()) {
return false;
}

const bool per_channel = std::count_if(bias_shape.begin(), bias_shape.end(), [](size_t x) { return x > 1; }) == 1;
if (ov::shape_size(bias_shape) != 1 && !per_channel)
return false;

if (bias_shape.empty() || bias_shape.back() != output_shape[rank - 1].get_length() || bias_shape.back() != bias_size) {
return false;
}

ngraph::NodeVector new_ops;

std::shared_ptr<ngraph::Node> final_bias = bias;
if (bias_shape.size() >= 2) {
auto reshape_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 1 }, { -1 });
final_bias = ov::op::util::make_try_fold<ngraph::opset1::Reshape>(final_bias, reshape_const, true);
new_ops.push_back(final_bias);
}

auto new_fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(fc->input_value(0),
fc->input_value(1),
final_bias,
fc->get_output_rank(),
fc->get_output_type());
new_ops.push_back(new_fc);

std::shared_ptr<ngraph::Node> final_scale = scale;
auto new_mul = std::make_shared<ngraph::opset1::Multiply>(new_fc, final_scale, mul->get_autob());
new_ops.push_back(new_mul);

new_mul->set_friendly_name(add->get_friendly_name());
ngraph::copy_runtime_info({fc, mul, add}, new_ops);
ngraph::replace_node(add, new_mul);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(m_add, matcher_name);
this->register_matcher(m, callback);
}
@@ -9,10 +9,25 @@
namespace ov {
namespace intel_cpu {

class FullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
class NonQuantizedFullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("NonQuantizedFullyConnectedBiasFusion", "0");
NonQuantizedFullyConnectedBiasFusion();
};

class QuantizedFullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("FullyConnectedDQBiasFusion", "0");
QuantizedFullyConnectedBiasFusion();
};

class FullyConnectedBiasFusion : public ngraph::pass::GraphRewrite {
public:
OPENVINO_RTTI("FullyConnectedBiasFusion", "0");
FullyConnectedBiasFusion();
FullyConnectedBiasFusion() {
add_matcher<NonQuantizedFullyConnectedBiasFusion>();
add_matcher<QuantizedFullyConnectedBiasFusion>();
}
};

} // namespace intel_cpu
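For orientation, a hedged usage sketch of how this GraphRewrite would typically be run inside a pass manager; the wrapper function and the commented header path are illustrative assumptions, since the plugin wires the pass up elsewhere in its transformation pipeline:

#include <memory>
#include <ngraph/pass/manager.hpp>
// #include "fc_bias_fusion.hpp"  // header path is illustrative

void run_fc_bias_fusion(const std::shared_ptr<ngraph::Function>& model) {
    ngraph::pass::Manager manager;
    // Registers both matchers: NonQuantizedFullyConnectedBiasFusion for FC + Bias,
    // QuantizedFullyConnectedBiasFusion for the FC + DQ(Multiply) + Bias pattern.
    manager.register_pass<ov::intel_cpu::FullyConnectedBiasFusion>();
    manager.run_passes(model);
}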
@@ -81,12 +81,14 @@
#include "utils/ngraph_transformation.hpp"

// LPT transformations
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "low_precision/convolution_backprop_data.hpp"
#include "low_precision/add.hpp"
#include "low_precision/convert_subtract_constant.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include "low_precision/convolution_backprop_data.hpp"
#include "low_precision/group_convolution.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/bias_attribute.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"

// CPU specific transformations
#include "transformations/cpu_opset/convert_to_cpu_specific_opset.hpp"
@@ -521,6 +523,11 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
},
ngraph::pass::low_precision::ConvolutionBackpropDataTransformation);

lptManager.get_pass_config()->set_callback<ngraph::pass::low_precision::AddTransformation>(
[](const_node_ptr& node) -> bool {
return ov::marked_as_bias(node);
});

CPU_DISABLE_PASS_COMMON(lptManager, ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation);

lptManager.run_passes(model);
@@ -43,7 +43,6 @@ protected:
std::tie(input_shape, layer_type) = GetParam();

targetDevice = CommonTestUtils::DEVICE_CPU;
fusedOps = std::vector<std::string>{"Add"};
std::tie(inFmts, outFmts, priority, selectedType) = CPUSpecificParams{{}, {}, {}, CPUTestsBase::any_type};
std::unordered_map<std::string, std::string> ngraph_type_to_plugin_type{
{"Convolution", "Convolution"},
@@ -53,6 +52,11 @@ protected:
{"MatMulWithConstant", "FullyConnected"},
};
node_type = ngraph_type_to_plugin_type[layer_type];
if (node_type == "FullyConnected")
// @todo: Recover the Multiply fusing check after moving the FC bias fusing into the CPU graph optimizer.
fusedOps = std::vector<std::string>{"Add"};
else
fusedOps = std::vector<std::string>{"Multiply", "Add"};

const auto shapes = layer_type == "MatMul" ? std::vector<InputShape>{input_shape, input_shape}
: std::vector<InputShape>{input_shape};
@@ -100,7 +104,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_FQLayerDQBias_4D_dynamic, FQLayerDQBias,
::testing::Combine(::testing::ValuesIn(input_shapes_4D_dynamic),
::testing::ValuesIn(layer_types_4D_dynamic)),
FQLayerDQBias::getTestCaseName);

const std::vector<InputShape> input_shapes_2D = {
{{-1, 768}, {{1, 768}}}
};
src/plugins/intel_cpu/thirdparty/onednn (vendored submodule)
@@ -1 +1 @@
Subproject commit f9127156d148393502d1d2254d9a48f564dc9adb
Subproject commit 67c84b1d76390ba1bf977a1c2d4bda53cf479c65