[CPU] [BF16] Do not enforce BF16 for graph tail (#6114)

This commit is contained in:
Egor Duplensky 2021-10-26 14:49:36 +03:00 committed by GitHub
parent 4a96d14adc
commit a02eafb397
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 15 deletions

View File

@ -1209,12 +1209,51 @@ bool MKLDNNGraph::InsertNode(MKLDNNNodePtr parent, MKLDNNNodePtr child, MKLDNNNo
void MKLDNNGraph::EnforceBF16() {
// Floating point parts of FP32 + INT8 or FP32 + BIN mixed precision models will be executed in BF16 precision
// only if enforceBF16 flag was set manually because current performance is not good enough to enable it by default
if (implication(isQuantized(), config.manualEnforceBF16)) {
for (auto &node : graphNodes) {
if (!implication(isQuantized(), config.manualEnforceBF16))
return;
/* list of node types that must be forced to be executed in BF16 precision
* because of performance gains */
static const std::unordered_set<Type, std::hash<int>> significantNodes { // std::hash<int> is necessary for old compilers (defect in the C++11 standard)
Convolution, // conv nets
FullyConnected, // conv / bert nets
RNNCell, // recurrent nets
RNNSeq, // recurrent nets
MatMul, // bert nets
ROIPooling, // object detection nets
Interpolate, // super resolution nets
};
std::function<void(const MKLDNNNodePtr&, std::unordered_set<MKLDNNNodePtr>& skipNodes)> searchForNodesToSkip;
searchForNodesToSkip = [&](const MKLDNNNodePtr& node, std::unordered_set<MKLDNNNodePtr>& skipNodes) -> void {
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
const auto& parent = node->getParentEdgeAt(i)->getParent();
if (significantNodes.count(parent->getType())) // stop at significant nodes
continue;
const auto res = skipNodes.insert(parent);
if (res.second) // node not visited yet
searchForNodesToSkip(parent, skipNodes);
}
};
/* Skip BF16 enforcement for tail of the graph by forming set of nodes to skip.
* Necessary to maintain accuracy.
* Experiments show zero performance impact on average */
std::unordered_set<MKLDNNNodePtr> nodesToSkip;
// starting from output nodes
for (const auto& entry : outputNodesMap) {
const auto& node = entry.second;
searchForNodesToSkip(node, nodesToSkip);
}
for (const auto& node : graphNodes) {
if (nodesToSkip.count(node) && !node->enforceBF16evenForGraphTail)
continue;
if (node->getType() != Input && node->getType() != Output) {
for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
if (!(parent->getType() == Input && parent->isConstant()) && // exclude nodes after Constant Inputs
const auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
if (!(parent->getType() == Input && parent->isConstant()) && // exclude skipNodes after Constant Inputs
!(parent->getType() == Input && node->getType() == Eltwise) && // exclude Eltwise after Input since it supports conversion to BF16
node->getOriginalInputPrecisionAtPort(i) == Precision::FP32)
node->setOriginalInputPrecisionAtPort(i, Precision::BF16);
@ -1226,7 +1265,6 @@ void MKLDNNGraph::EnforceBF16() {
}
}
}
}
}
std::shared_ptr<ngraph::Function> MKLDNNGraph::dump() const {

View File

@ -159,6 +159,12 @@ MKLDNNNode::MKLDNNNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::en
}
}
}
const auto it = rtInfo.find("enforceBF16evenForGraphTail");
if (it != rtInfo.end()) {
if (const auto value = std::dynamic_pointer_cast<ngraph::VariantImpl<int64_t>>(it->second))
enforceBF16evenForGraphTail = value->get();
}
}
MKLDNNNode::MKLDNNNode(const std::string& type, const std::string& name, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &w_cache)

View File

@ -593,6 +593,7 @@ protected:
std::vector <impl_desc_type> implPriorities;
std::vector <mkldnn::memory::format_tag> inputMemoryFormatsFilter;
std::vector <mkldnn::memory::format_tag> outputMemoryFormatsFilter;
bool enforceBF16evenForGraphTail = false;
std::string originalLayers; // contains names of the original layers separated by comma

View File

@ -100,7 +100,7 @@ protected:
// performance counters
expectedPrecisions["Matmul_0"] = "BF16";
expectedPrecisions["Mul_1"] = "BF16";
expectedPrecisions["Mul_1"] = netPrecision.name(); // tail kept in FP32 precision
}
};

View File

@ -4,6 +4,7 @@
#include "cpu_test_utils.hpp"
#include "utils/rt_info/memory_formats_attribute.hpp"
#include <cstdint>
namespace CPUTestUtils {
@ -257,6 +258,8 @@ CPUTestsBase::makeCPUInfo(std::vector<cpu_memory_format_t> inFmts, std::vector<c
cpuInfo.insert({"PrimitivesPriority", std::make_shared<ngraph::VariantWrapper<std::string>>(impls2str(priority))});
}
cpuInfo.insert({"enforceBF16evenForGraphTail", ov::make_variant<int64_t>(true)});
return cpuInfo;
}