[CPU] Support using BF16 in INT8 models (#15663)
This commit is contained in:
@@ -139,13 +139,11 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
|
||||
if (val == PluginConfigParams::YES) {
|
||||
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) {
|
||||
enforceBF16 = true;
|
||||
manualEnforceBF16 = true;
|
||||
} else {
|
||||
IE_THROW() << "Platform doesn't support BF16 format";
|
||||
}
|
||||
} else if (val == PluginConfigParams::NO) {
|
||||
enforceBF16 = false;
|
||||
manualEnforceBF16 = false;
|
||||
} else {
|
||||
IE_THROW() << "Wrong value for property key " << PluginConfigParams::KEY_ENFORCE_BF16
|
||||
<< ". Expected only YES/NO";
|
||||
@@ -159,13 +157,11 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
|
||||
if (val == "bf16") {
|
||||
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) {
|
||||
enforceBF16 = true;
|
||||
manualEnforceBF16 = true;
|
||||
} else {
|
||||
IE_THROW() << "Platform doesn't support BF16 format";
|
||||
}
|
||||
} else if (val == "f32") {
|
||||
enforceBF16 = false;
|
||||
manualEnforceBF16 = false;
|
||||
} else {
|
||||
IE_THROW() << "Wrong value for property key " << ov::inference_precision.name()
|
||||
<< ". Supported values: bf16, f32";
|
||||
|
||||
@@ -52,12 +52,10 @@ struct Config {
|
||||
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
|
||||
LPTransformsMode lpTransformsMode = LPTransformsMode::On;
|
||||
bool enforceBF16 = true;
|
||||
bool manualEnforceBF16 = false;
|
||||
#else
|
||||
// Currently INT8 mode is not optimized on ARM / RISCV or other non-x86 platforms, fallback to FP32 mode.
|
||||
LPTransformsMode lpTransformsMode = LPTransformsMode::Off;
|
||||
bool enforceBF16 = false;
|
||||
bool manualEnforceBF16 = false;
|
||||
#endif
|
||||
|
||||
DenormalsOptMode denormalsOptMode = DenormalsOptMode::DO_Keep;
|
||||
|
||||
@@ -136,7 +136,7 @@ bool DnnlPostOpsComposer::appendScale(const std::vector<float>& scale, bool isLa
|
||||
if (oscale_values.size() == 1)
|
||||
oscale_mask = 0;
|
||||
else
|
||||
oscale_mask = 1 << 1; // it works for both Conv/Matmul
|
||||
oscale_mask = 1 << idxOC;
|
||||
updateOutputScales();
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1506,11 +1506,6 @@ bool Graph::InsertNode(NodePtr parent, NodePtr child, NodePtr node, int parentPo
|
||||
|
||||
// Set all non const data paths precision to BF16
|
||||
void Graph::EnforceBF16() {
|
||||
// Floating point parts of FP32 + INT8 or FP32 + BIN mixed precision models will be executed in BF16 precision
|
||||
// only if enforceBF16 flag was set manually because current performance is not good enough to enable it by default
|
||||
if (!implication(context->isGraphQuantized(), getConfig().manualEnforceBF16))
|
||||
return;
|
||||
|
||||
std::function<void(const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
|
||||
searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {
|
||||
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
|
||||
|
||||
@@ -261,7 +261,7 @@ void summary_perf(const Graph &graph) {
|
||||
}
|
||||
const std::string& summaryPerf = graph.getConfig().debugCaps.summaryPerf;
|
||||
|
||||
if (summaryPerf.empty())
|
||||
if (summaryPerf.empty() || !std::stoi(summaryPerf))
|
||||
return;
|
||||
|
||||
std::map<std::string, double> perf_by_type;
|
||||
@@ -308,7 +308,7 @@ void summary_perf(const Graph &graph) {
|
||||
std::stringstream ss;
|
||||
int percentage = static_cast<int>(it.second*100/total_avg);
|
||||
if (percentage == 0) break;
|
||||
ss << std::setw(10) << std::right << percentage << " % :" << it.first << std::endl;
|
||||
ss << std::setw(10) << std::right << percentage << " % : " << std::setw(8) << std::right << it.second << "(us) " << it.first << std::endl;
|
||||
std::cout << ss.str();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -734,21 +734,6 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
* @brief Reports whether childNode is a FakeQuantize node and BF16 appears among the original
* output precisions (port 0) of parentNode or childNode. Callers visible in this file skip the
* fusion (advance to the next parent) when this returns true.
* NOTE(review): assumes one_of(value, candidates...) compares value against each candidate;
* one_of is defined outside this view - confirm.
|
||||
* @todo FQ fusing was disabled for BF16 output since oneDNN primitives lack support
|
||||
* for bf16 depthwise postops.
|
||||
* This is not the case anymore, because after migration to oneDNN 2.3 FQ will be fused as
|
||||
* multiple binary post ops.
|
||||
* This check can already be removed for FC fusing, but should be kept for Convolution,
|
||||
* which still uses legacy depthwise postops for performance reasons.
|
||||
*/
|
||||
static bool BF16QuantizeNodeFusing(const NodePtr& parentNode, const NodePtr& childNode) {
|
||||
return childNode->getType() == Type::FakeQuantize &&
|
||||
one_of(Precision::BF16,
|
||||
parentNode->getOriginalOutputPrecisionAtPort(0),
|
||||
|
||||
childNode->getOriginalOutputPrecisionAtPort(0));
|
||||
}
|
||||
|
||||
void GraphOptimizer::FuseFullyConnectedAndSimpleOperation(Graph &graph) {
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
@@ -772,12 +757,6 @@ void GraphOptimizer::FuseFullyConnectedAndSimpleOperation(Graph &graph) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// BF16 Quantize Layer Fusing Disabling
|
||||
if (BF16QuantizeNodeFusing(parentNode, childNode)) {
|
||||
parent++;
|
||||
continue;
|
||||
}
|
||||
|
||||
childNode->fuseInto(parentNode);
|
||||
|
||||
if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) {
|
||||
@@ -1066,12 +1045,6 @@ void GraphOptimizer::FuseConvolutionAndSimpleOperation(Graph &graph) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// BF16 Quantize Layer Fusing Disabling
|
||||
if (BF16QuantizeNodeFusing(parentNode, childNode)) {
|
||||
parent++;
|
||||
continue;
|
||||
}
|
||||
|
||||
childNode->fuseInto(parentNode);
|
||||
|
||||
if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) {
|
||||
|
||||
@@ -503,11 +503,6 @@ void Convolution::getSupportedDescriptors() {
|
||||
|
||||
if (canBeExecutedInInt8()) {
|
||||
DEBUG_LOG(getName(), "Creating I8 descriptor");
|
||||
// We have to extend convolution_x8s8s32x from oneDNN to support BF16 output data type
|
||||
if (outputDataType == memory::data_type::bf16)
|
||||
outputDataType = memory::data_type::f32;
|
||||
if (eltwisePrecision == Precision::BF16)
|
||||
eltwisePrecision = Precision::FP32;
|
||||
// initTryBrgconvFlag depends on outputDataType, should be after outputDataType computed
|
||||
if (!enforceBrgconv)
|
||||
initTryBrgconvFlag();
|
||||
|
||||
@@ -232,29 +232,29 @@ void FullyConnected::getSupportedDescriptors() {
|
||||
auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID));
|
||||
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID));
|
||||
|
||||
if (inputDataType == memory::data_type::f32) {
|
||||
outputDataType = memory::data_type::f32;
|
||||
}
|
||||
|
||||
if (!fusedWith.empty()) {
|
||||
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0));
|
||||
}
|
||||
auto weightsDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(WEIGHTS_ID));
|
||||
|
||||
// We have to extend gemm_x8s8s32x_inner_product_fwd_t from oneDNN to support BF16 output data type
|
||||
if ((!one_of(inputDataType , memory::data_type::u8, memory::data_type::s8) || weightsDataType != memory::data_type::s8)
|
||||
&& inputDataType != memory::data_type::bf16) {
|
||||
inputDataType = outputDataType = memory::data_type::f32;
|
||||
}
|
||||
|
||||
if (one_of(inputDataType , memory::data_type::u8, memory::data_type::s8)
|
||||
&& outputDataType == memory::data_type::bf16) {
|
||||
// revert back outputDataType on special cases
|
||||
if (inputDataType == memory::data_type::f32) {
|
||||
// oneDNN only supports f32 output when input is f32, even if FQ is fused
|
||||
outputDataType = memory::data_type::f32;
|
||||
}
|
||||
|
||||
if (inputDataType == memory::data_type::bf16
|
||||
&& one_of(outputDataType , memory::data_type::u8, memory::data_type::s8)) {
|
||||
outputDataType = memory::data_type::bf16;
|
||||
} else if (inputDataType == memory::data_type::bf16) {
|
||||
// bf16 input only supports bf16/f32 output, even if FQ is fused as post-ops
|
||||
if (one_of(outputDataType , memory::data_type::u8, memory::data_type::s8)) {
|
||||
outputDataType = memory::data_type::bf16;
|
||||
}
|
||||
} else if (one_of(inputDataType, memory::data_type::u8, memory::data_type::s8)) {
|
||||
if (weightsDataType != memory::data_type::s8) {
|
||||
// weight has to be s8 for INT8 mode, otherwise fallback to
|
||||
// f32 mode
|
||||
inputDataType = outputDataType = memory::data_type::f32;
|
||||
}
|
||||
} else {
|
||||
// s32/u32/... unsupported input data types, fallback to f32
|
||||
inputDataType = outputDataType = memory::data_type::f32;
|
||||
}
|
||||
|
||||
inDims = isDynamicNode() ? makeDummyInputDims() : getInputShapeAtPort(DATA_ID).getStaticDims();
|
||||
|
||||
@@ -204,34 +204,6 @@ MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
|
||||
}
|
||||
|
||||
// Decides whether `node` may be fused into this MatMul.
// Returns false when:
//  - the output rank is > 2 and `node` is an Eltwise with one of the listed binary/power algorithms
//    whose broadcasting policy is not PerTensor, or a FakeQuantize whose policy is not PerTensor;
//  - `node` is a FakeQuantize with I8/U8 output while this MatMul cannot be executed in int8
//    for its original input precisions.
// Otherwise the decision is delegated to canFuseSimpleOperation(node).
bool MatMul::canFuse(const NodePtr& node) const {
|
||||
// per channel binary post op for rank > 2D is supported only by oneDNN reference implementation because of unusual MatMul channel axis (issue 6669)
|
||||
if (getOutputShapeAtPort(0).getRank() > 2) {
|
||||
if (const auto* eltwiseNode = dynamic_cast<Eltwise *>(node.get())) {
|
||||
if (one_of(eltwiseNode->getAlgorithm(), Algorithm::EltwiseAdd,
|
||||
Algorithm::EltwiseMultiply,
|
||||
Algorithm::EltwiseSubtract,
|
||||
Algorithm::EltwiseDivide,
|
||||
Algorithm::EltwisePrelu,
|
||||
Algorithm::EltwiseMulAdd,
|
||||
Algorithm::EltwisePowerStatic) &&
|
||||
eltwiseNode->getBroadcastingPolicy() != Eltwise::PerTensor) {
|
||||
return false;
|
||||
}
|
||||
} else if (const auto* fakeQuantizeNode = dynamic_cast<FakeQuantize *>(node.get())) {
|
||||
if (fakeQuantizeNode->getBroadcastingPolicy() != FakeQuantize::PerTensor) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Todo:
|
||||
// Consider the case when MatMul doesn't support execution in int8, but is getting fused with FQ with int8 output.
|
||||
// Then the MatMul will change its output precision to fp32, but the FQ child will still have the int8 input precision.
|
||||
// This information should be propagated! Note that we may need to propagate the updated precision to child fused nodes.
|
||||
if (node->getType() == Type::FakeQuantize &&
|
||||
one_of(node->getOriginalOutputPrecisionAtPort(0), Precision::I8, Precision::U8) &&
|
||||
!canBeExecutedInInt8(getOriginalInputPrecisionAtPort(0), getOriginalInputPrecisionAtPort(1)))
|
||||
return false;
|
||||
return canFuseSimpleOperation(node);
|
||||
}
|
||||
|
||||
@@ -344,12 +316,20 @@ void MatMul::getSupportedDescriptors() {
|
||||
outPortPrec = firstInPortPrec = secondInPortPrec = Precision::FP32;
|
||||
}
|
||||
|
||||
Precision postOpsPrec = outPortPrec;
|
||||
if (!fusedWith.empty()) {
|
||||
outPortPrec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
|
||||
postOpsPrec = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
|
||||
}
|
||||
|
||||
if (!canBeExecutedInInt8(firstInPortPrec, secondInPortPrec) && one_of(outPortPrec, Precision::U8, Precision::I8))
|
||||
outPortPrec = Precision::FP32; // INT output is not supported for non-INT inputs
|
||||
if (canBeExecutedInInt8(firstInPortPrec, secondInPortPrec)) {
|
||||
// INT8 mode supports a wide range of output precisions
|
||||
outPortPrec = postOpsPrec;
|
||||
} else if (postOpsPrec == Precision::FP32) {
|
||||
// all non-INT8 modes support fp32 output precision
|
||||
outPortPrec = postOpsPrec;
|
||||
} else {
|
||||
// otherwise we ignore postOpsPrec and stay with getOriginalOutputPrecisionAtPort(0)
|
||||
}
|
||||
|
||||
const auto& inputShape0 = getInputShapeAtPort(0);
|
||||
const auto& inputShape1 = getInputShapeAtPort(1);
|
||||
|
||||
@@ -479,11 +479,6 @@ std::ostream & operator<<(std::ostream & os, const PrintableModel& model) {
|
||||
os << std::endl;
|
||||
|
||||
// recursively output subgraphs
|
||||
if (auto subgraph = std::dynamic_pointer_cast<ngraph::snippets::op::Subgraph>(op)) {
|
||||
os << "\t\t snippets Subgraph: " << subgraph->get_friendly_name() << " is_quantized:" << subgraph->is_quantized() << std::endl;
|
||||
os << PrintableModel(subgraph->body(), tag, prefix + "\t\t");
|
||||
}
|
||||
|
||||
if (auto msubgraph = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(op)) {
|
||||
auto cnt = msubgraph->get_internal_subgraphs_size();
|
||||
for (int i = 0; i < cnt; i++) {
|
||||
|
||||
2
src/plugins/intel_cpu/thirdparty/onednn
vendored
2
src/plugins/intel_cpu/thirdparty/onednn
vendored
Submodule src/plugins/intel_cpu/thirdparty/onednn updated: bd3498162f...0285720996
Reference in New Issue
Block a user