[CPU] WA: Stop fusing per-OC eltwise into Matmul with input rank >4 (#16824)
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
#include "memory_desc/cpu_memory_desc_utils.h"
|
||||
#include <dnnl_extension_utils.h>
|
||||
#include <common/primitive_hashing_utils.hpp>
|
||||
#include <cpu/x64/cpu_isa_traits.hpp>
|
||||
|
||||
using namespace dnnl;
|
||||
using namespace InferenceEngine;
|
||||
@@ -204,6 +205,19 @@ MatMul::MatMul(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
|
||||
}
|
||||
|
||||
bool MatMul::canFuse(const NodePtr& node) const {
|
||||
// WA for CVS-84056: the oneDNN brgemm impl has a problem with per-OC binary post-ops for MatMul with 6D inputs, so such fusing is skipped for any input rank > 4
|
||||
if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core)) {
|
||||
if (auto* eltwiseNode = dynamic_cast<Eltwise*>(node.get())) {
|
||||
if (eltwiseNode->getBroadcastingPolicy() == Eltwise::BroadcastingPolicy::PerChannel) {
|
||||
auto rank = getInputShapeAtPort(0).getRank();
|
||||
if (rank > 4) {
|
||||
DEBUG_LOG("skip fusing non-perTensor Eltwise:", eltwiseNode->getName(), " into 6D MatMul:", getName());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Consider the case when Matmul doesn't support execution in int8, but is getting fused with FQ with int8 output.
|
||||
// Then the Matmul will change its output precision to fp32. If FQ were fused into the Matmul, a reorder would be inserted
|
||||
// after the Matmul. In some BERT models, this reorder causes significant performance degradation.
|
||||
|
||||
Reference in New Issue
Block a user