[CPU] WA: Stop fusing per-OC eltwise into Matmul with input rank >4 (#16824)

This commit is contained in:
Tingqian Li
2023-04-19 15:11:04 +08:00
committed by GitHub
parent dbd20ec799
commit 1525f6cc16

View File

@@ -21,6 +21,7 @@
#include "memory_desc/cpu_memory_desc_utils.h"
#include <dnnl_extension_utils.h>
#include <common/primitive_hashing_utils.hpp>
#include <cpu/x64/cpu_isa_traits.hpp>
using namespace dnnl;
using namespace InferenceEngine;
@@ -204,6 +205,19 @@ MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
}

bool MatMul::canFuse(const NodePtr& node) const {
// WA for CVS-84056: oneDNN brgemm impl has problem with per-OC binary-postOps for MatMul with 6D inputs
if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core)) {
if (auto* eltwiseNode = dynamic_cast<Eltwise*>(node.get())) {
if (eltwiseNode->getBroadcastingPolicy() == Eltwise::BroadcastingPolicy::PerChannel) {
auto rank = getInputShapeAtPort(0).getRank();
if (rank > 4) {
DEBUG_LOG("skip fusing non-perTensor Eltwise:", eltwiseNode->getName(), " into 6D MatMul:", getName());
return false;
}
}
}
}
    // Consider the case when Matmul doesn't support execution in int8, but is getting fused with FQ with int8 output.
    // Then the Matmul will change its output precision to fp32. If fusing FQ into matmul, there would be reorder inserted
    // after matmul. In some bert model, this reorder causes great perf degradation.