From a02eafb397b23416078afdd1f3eabb84a0fb77a1 Mon Sep 17 00:00:00 2001
From: Egor Duplensky
Date: Tue, 26 Oct 2021 14:49:36 +0300
Subject: [PATCH] [CPU] [BF16] Do not enforce BF16 for graph tail (#6114)

---
 .../src/mkldnn_plugin/mkldnn_graph.cpp        | 66 +++++++++++++++----
 .../src/mkldnn_plugin/mkldnn_node.cpp         |  6 ++
 .../src/mkldnn_plugin/mkldnn_node.h           |  1 +
 .../plugin/cpu/bfloat16/gather_multiply.cpp   |  2 +-
 .../plugin/cpu/test_utils/cpu_test_utils.cpp  |  3 +
 5 files changed, 63 insertions(+), 15 deletions(-)

diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp
index e962d362293..61928e183f5 100644
--- a/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp
+++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp
@@ -1209,21 +1209,59 @@ bool MKLDNNGraph::InsertNode(MKLDNNNodePtr parent, MKLDNNNodePtr child, MKLDNNNo
 void MKLDNNGraph::EnforceBF16() {
     // Floating point parts of FP32 + INT8 or FP32 + BIN mixed precision models will be executed in BF16 precision
     // only if enforceBF16 flag was set manually because current performance is not good enough to enable it by default
-    if (implication(isQuantized(), config.manualEnforceBF16)) {
-        for (auto &node : graphNodes) {
-            if (node->getType() != Input && node->getType() != Output) {
-                for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
-                    auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
-                    if (!(parent->getType() == Input && parent->isConstant()) && // exclude nodes after Constant Inputs
-                        !(parent->getType() == Input && node->getType() == Eltwise) && // exclude Eltwise after Input since it supports conversion to BF16
-                        node->getOriginalInputPrecisionAtPort(i) == Precision::FP32)
-                        node->setOriginalInputPrecisionAtPort(i, Precision::BF16);
-                }
+    if (!implication(isQuantized(), config.manualEnforceBF16))
+        return;
+    /* list of node types that must be forced to be executed in BF16 precision
+     * because of performance gains */
+    static const std::unordered_set<Type, std::hash<int>> significantNodes { // std::hash<int> is necessary for old compilers (defect in the C++11 standard)
+        Convolution,    // conv nets
+        FullyConnected, // conv / bert nets
+        RNNCell,        // recurrent nets
+        RNNSeq,         // recurrent nets
+        MatMul,         // bert nets
+        ROIPooling,     // object detection nets
+        Interpolate,    // super resolution nets
+    };
 
-                for (size_t i = 0; i < node->getOriginalOutputsNumber(); i++) {
-                    if (node->getOriginalOutputPrecisionAtPort(i) == Precision::FP32)
-                        node->setOriginalOutputPrecisionAtPort(i, Precision::BF16);
-                }
+    std::function<void(const MKLDNNNodePtr&, std::unordered_set<MKLDNNNodePtr>& skipNodes)> searchForNodesToSkip;
+    searchForNodesToSkip = [&](const MKLDNNNodePtr& node, std::unordered_set<MKLDNNNodePtr>& skipNodes) -> void {
+        for (size_t i = 0; i < node->getParentEdges().size(); i++) {
+            const auto& parent = node->getParentEdgeAt(i)->getParent();
+            if (significantNodes.count(parent->getType())) // stop at significant nodes
+                continue;
+
+            const auto res = skipNodes.insert(parent);
+            if (res.second) // node not visited yet
+                searchForNodesToSkip(parent, skipNodes);
+        }
+    };
+
+    /* Skip BF16 enforcement for the tail of the graph by forming a set of nodes to skip.
+     * Necessary to maintain accuracy.
+     * Experiments show zero performance impact on average */
+    std::unordered_set<MKLDNNNodePtr> nodesToSkip;
+    // starting from output nodes
+    for (const auto& entry : outputNodesMap) {
+        const auto& node = entry.second;
+        searchForNodesToSkip(node, nodesToSkip);
+    }
+
+    for (const auto& node : graphNodes) {
+        if (nodesToSkip.count(node) && !node->enforceBF16evenForGraphTail)
+            continue;
+
+        if (node->getType() != Input && node->getType() != Output) {
+            for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
+                const auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
+                if (!(parent->getType() == Input && parent->isConstant()) && // exclude nodes after Constant Inputs
+                    !(parent->getType() == Input && node->getType() == Eltwise) && // exclude Eltwise after Input since it supports conversion to BF16
+                    node->getOriginalInputPrecisionAtPort(i) == Precision::FP32)
+                    node->setOriginalInputPrecisionAtPort(i, Precision::BF16);
+            }
+
+            for (size_t i = 0; i < node->getOriginalOutputsNumber(); i++) {
+                if (node->getOriginalOutputPrecisionAtPort(i) == Precision::FP32)
+                    node->setOriginalOutputPrecisionAtPort(i, Precision::BF16);
             }
         }
     }
 }
diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
index d095b02d1f2..31d36aece4a 100644
--- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
@@ -159,6 +159,12 @@ MKLDNNNode::MKLDNNNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::en
             }
         }
     }
+
+    const auto it = rtInfo.find("enforceBF16evenForGraphTail");
+    if (it != rtInfo.end()) {
+        if (const auto value = std::dynamic_pointer_cast<ngraph::VariantImpl<bool>>(it->second))
+            enforceBF16evenForGraphTail = value->get();
+    }
 }
 
 MKLDNNNode::MKLDNNNode(const std::string& type, const std::string& name, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &w_cache)
diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.h b/inference-engine/src/mkldnn_plugin/mkldnn_node.h
index b7a3622cb77..7e089789720 100644
--- a/inference-engine/src/mkldnn_plugin/mkldnn_node.h
+++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.h
@@ -593,6 +593,7 @@ protected:
     std::vector<impl_desc_type> implPriorities;
     std::vector<mkldnn::memory::format_tag> inputMemoryFormatsFilter;
     std::vector<mkldnn::memory::format_tag> outputMemoryFormatsFilter;
+    bool enforceBF16evenForGraphTail = false;
 
     std::string originalLayers;  // contains names of the original layers separated by comma
diff --git a/inference-engine/tests/functional/plugin/cpu/bfloat16/gather_multiply.cpp b/inference-engine/tests/functional/plugin/cpu/bfloat16/gather_multiply.cpp
index e4283a18931..84c9824f22b 100644
--- a/inference-engine/tests/functional/plugin/cpu/bfloat16/gather_multiply.cpp
+++ b/inference-engine/tests/functional/plugin/cpu/bfloat16/gather_multiply.cpp
@@ -100,7 +100,7 @@ protected:
 
         // performance counters
         expectedPrecisions["Matmul_0"] = "BF16";
-        expectedPrecisions["Mul_1"] = "BF16";
+        expectedPrecisions["Mul_1"] = netPrecision.name(); // tail kept in FP32 precision
     }
 };
 
diff --git a/inference-engine/tests/functional/plugin/cpu/test_utils/cpu_test_utils.cpp b/inference-engine/tests/functional/plugin/cpu/test_utils/cpu_test_utils.cpp
index f01bd40b96b..4b515dcd144 100644
--- a/inference-engine/tests/functional/plugin/cpu/test_utils/cpu_test_utils.cpp
+++ b/inference-engine/tests/functional/plugin/cpu/test_utils/cpu_test_utils.cpp
@@ -4,6 +4,7 @@
 
 #include "cpu_test_utils.hpp"
 #include "utils/rt_info/memory_formats_attribute.hpp"
+#include <ngraph/variant.hpp>
 
 namespace CPUTestUtils {
 
@@ -257,6 +258,8 @@ CPUTestsBase::makeCPUInfo(std::vector<cpu_memory_format_t> inFmts,
                                               std::vector<std::string>>(impls2str(priority))});
     }
 
+    cpuInfo.insert({"enforceBF16evenForGraphTail", ov::make_variant(true)});
+
     return cpuInfo;
 }
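
For illustration: the core of this change is the reverse walk in searchForNodesToSkip, which starts at every graph output and collects each node it can reach before hitting the first "significant" node; EnforceBF16 then leaves that FP32 tail untouched unless a node opts back in through enforceBF16evenForGraphTail. Below is a minimal standalone sketch of the same walk over a simplified graph structure (C++14). Node, NodeType, and findTailNodesToSkip are illustrative stand-ins, not types from the MKLDNN plugin.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

enum class NodeType { Input, Eltwise, Convolution, MatMul, Output };

struct Node {
    std::string name;
    NodeType type;
    std::vector<std::shared_ptr<Node>> parents; // edges toward the graph inputs
};
using NodePtr = std::shared_ptr<Node>;

// Counterpart of significantNodes: node types that benefit from BF16 execution.
// (std::hash works for enum classes out of the box since C++14; the patch passes
// std::hash<int> explicitly to work around the C++11 defect its comment mentions.)
static const std::unordered_set<NodeType> significant {NodeType::Convolution, NodeType::MatMul};

// Counterpart of searchForNodesToSkip: walk upward from the outputs and collect
// every node reached before the first significant node.
std::unordered_set<NodePtr> findTailNodesToSkip(const std::vector<NodePtr>& outputs) {
    std::unordered_set<NodePtr> skip;
    std::function<void(const NodePtr&)> visit = [&](const NodePtr& node) {
        for (const auto& parent : node->parents) {
            if (significant.count(parent->type)) // stop the walk at significant nodes
                continue;
            if (skip.insert(parent).second)      // recurse only into unvisited nodes
                visit(parent);
        }
    };
    for (const auto& out : outputs)
        visit(out);
    return skip;
}

int main() {
    // input -> conv -> eltwise -> output: only "eltwise" forms the FP32 tail,
    // because the walk from "output" stops as soon as it reaches "conv".
    auto input   = std::make_shared<Node>(Node{"input",   NodeType::Input,       {}});
    auto conv    = std::make_shared<Node>(Node{"conv",    NodeType::Convolution, {input}});
    auto eltwise = std::make_shared<Node>(Node{"eltwise", NodeType::Eltwise,     {conv}});
    auto output  = std::make_shared<Node>(Node{"output",  NodeType::Output,      {eltwise}});

    for (const auto& node : findTailNodesToSkip({output}))
        std::cout << node->name << " stays in FP32\n"; // prints: eltwise stays in FP32
}

This also mirrors the test change above: gather_multiply now expects Mul_1, the node between the MatMul and the output, to stay in the network precision instead of BF16, while Matmul_0 itself is still enforced to BF16.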