[CPU] [BF16] Do not enforce BF16 for graph tail (#6114)
This commit is contained in:
parent
4a96d14adc
commit
a02eafb397
@ -1209,21 +1209,59 @@ bool MKLDNNGraph::InsertNode(MKLDNNNodePtr parent, MKLDNNNodePtr child, MKLDNNNo
|
||||
void MKLDNNGraph::EnforceBF16() {
|
||||
// Floating point parts of FP32 + INT8 or FP32 + BIN mixed precision models will be executed in BF16 precision
|
||||
// only if enforceBF16 flag was set manually because current performance is not good enough to enable it by default
|
||||
if (implication(isQuantized(), config.manualEnforceBF16)) {
|
||||
for (auto &node : graphNodes) {
|
||||
if (node->getType() != Input && node->getType() != Output) {
|
||||
for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
|
||||
auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
|
||||
if (!(parent->getType() == Input && parent->isConstant()) && // exclude nodes after Constant Inputs
|
||||
!(parent->getType() == Input && node->getType() == Eltwise) && // exclude Eltwise after Input since it supports conversion to BF16
|
||||
node->getOriginalInputPrecisionAtPort(i) == Precision::FP32)
|
||||
node->setOriginalInputPrecisionAtPort(i, Precision::BF16);
|
||||
}
|
||||
if (!implication(isQuantized(), config.manualEnforceBF16))
|
||||
return;
|
||||
/* list of node types that must be forced to be executed in BF16 precision
|
||||
* because of performance gains */
|
||||
static const std::unordered_set<Type, std::hash<int>> significantNodes { // std::hash<int> is necessary for old compilers (defect in the C++11 standard)
|
||||
Convolution, // conv nets
|
||||
FullyConnected, // conv / bert nets
|
||||
RNNCell, // recurrent nets
|
||||
RNNSeq, // recurrent nets
|
||||
MatMul, // bert nets
|
||||
ROIPooling, // object detection nets
|
||||
Interpolate, // super resolution nets
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < node->getOriginalOutputsNumber(); i++) {
|
||||
if (node->getOriginalOutputPrecisionAtPort(i) == Precision::FP32)
|
||||
node->setOriginalOutputPrecisionAtPort(i, Precision::BF16);
|
||||
}
|
||||
std::function<void(const MKLDNNNodePtr&, std::unordered_set<MKLDNNNodePtr>& skipNodes)> searchForNodesToSkip;
|
||||
searchForNodesToSkip = [&](const MKLDNNNodePtr& node, std::unordered_set<MKLDNNNodePtr>& skipNodes) -> void {
|
||||
for (size_t i = 0; i < node->getParentEdges().size(); i++) {
|
||||
const auto& parent = node->getParentEdgeAt(i)->getParent();
|
||||
if (significantNodes.count(parent->getType())) // stop at significant nodes
|
||||
continue;
|
||||
|
||||
const auto res = skipNodes.insert(parent);
|
||||
if (res.second) // node not visited yet
|
||||
searchForNodesToSkip(parent, skipNodes);
|
||||
}
|
||||
};
|
||||
|
||||
/* Skip BF16 enforcement for tail of the graph by forming set of nodes to skip.
|
||||
* Necessary to maintain accuracy.
|
||||
* Experiments show zero performance impact on average */
|
||||
std::unordered_set<MKLDNNNodePtr> nodesToSkip;
|
||||
// starting from output nodes
|
||||
for (const auto& entry : outputNodesMap) {
|
||||
const auto& node = entry.second;
|
||||
searchForNodesToSkip(node, nodesToSkip);
|
||||
}
|
||||
|
||||
for (const auto& node : graphNodes) {
|
||||
if (nodesToSkip.count(node) && !node->enforceBF16evenForGraphTail)
|
||||
continue;
|
||||
|
||||
if (node->getType() != Input && node->getType() != Output) {
|
||||
for (size_t i = 0; i < node->getOriginalInputsNumber(); i++) {
|
||||
const auto &parent = node->getParentEdgesAtPort(i)[0]->getParent();
|
||||
if (!(parent->getType() == Input && parent->isConstant()) && // exclude nodes after Constant Inputs
|
||||
!(parent->getType() == Input && node->getType() == Eltwise) && // exclude Eltwise after Input since it supports conversion to BF16
|
||||
node->getOriginalInputPrecisionAtPort(i) == Precision::FP32)
|
||||
node->setOriginalInputPrecisionAtPort(i, Precision::BF16);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < node->getOriginalOutputsNumber(); i++) {
|
||||
if (node->getOriginalOutputPrecisionAtPort(i) == Precision::FP32)
|
||||
node->setOriginalOutputPrecisionAtPort(i, Precision::BF16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -159,6 +159,12 @@ MKLDNNNode::MKLDNNNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::en
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto it = rtInfo.find("enforceBF16evenForGraphTail");
|
||||
if (it != rtInfo.end()) {
|
||||
if (const auto value = std::dynamic_pointer_cast<ngraph::VariantImpl<int64_t>>(it->second))
|
||||
enforceBF16evenForGraphTail = value->get();
|
||||
}
|
||||
}
|
||||
|
||||
MKLDNNNode::MKLDNNNode(const std::string& type, const std::string& name, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &w_cache)
|
||||
|
@ -593,6 +593,7 @@ protected:
|
||||
std::vector <impl_desc_type> implPriorities;
|
||||
std::vector <mkldnn::memory::format_tag> inputMemoryFormatsFilter;
|
||||
std::vector <mkldnn::memory::format_tag> outputMemoryFormatsFilter;
|
||||
bool enforceBF16evenForGraphTail = false;
|
||||
|
||||
std::string originalLayers; // contains names of the original layers separated by comma
|
||||
|
||||
|
@ -100,7 +100,7 @@ protected:
|
||||
// performance counters
|
||||
|
||||
expectedPrecisions["Matmul_0"] = "BF16";
|
||||
expectedPrecisions["Mul_1"] = "BF16";
|
||||
expectedPrecisions["Mul_1"] = netPrecision.name(); // tail kept in FP32 precision
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "cpu_test_utils.hpp"
|
||||
#include "utils/rt_info/memory_formats_attribute.hpp"
|
||||
#include <cstdint>
|
||||
|
||||
namespace CPUTestUtils {
|
||||
|
||||
@ -257,6 +258,8 @@ CPUTestsBase::makeCPUInfo(std::vector<cpu_memory_format_t> inFmts, std::vector<c
|
||||
cpuInfo.insert({"PrimitivesPriority", std::make_shared<ngraph::VariantWrapper<std::string>>(impls2str(priority))});
|
||||
}
|
||||
|
||||
cpuInfo.insert({"enforceBF16evenForGraphTail", ov::make_variant<int64_t>(true)});
|
||||
|
||||
return cpuInfo;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user