[CPU] Disabled input zero point fusing into fp32 Convolution (#4056)
This commit is contained in:
parent
a67a72084f
commit
a8b921791e
@ -216,7 +216,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto initializeInputZeroPoints = [](MKLDNNNodePtr node, MKLDNNNodePtr parent0) {
|
||||
auto initializeInputZeroPoints = [](MKLDNNNodePtr node, MKLDNNNodePtr parent0, MKLDNNNodePtr parent1) {
|
||||
auto* convNode = dynamic_cast<MKLDNNConvolutionNode*>(node.get());
|
||||
if (convNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get convolution node " << node->getName();
|
||||
@ -225,6 +225,14 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
|
||||
int OC = node->getChildEdgesAtPort(0)[0]->getDims()[1];
|
||||
|
||||
if (parent0->getType() == Eltwise) {
|
||||
// The plug-in doesn't support FP32 convolution with input/weights zero points.
|
||||
// In case weights are in FP32 (or we have zero points on weights which are not supported by INT8 convolution) we cannot use
|
||||
// INT8 implementation so we have to disable input zero points fusing as well.
|
||||
auto weightsLayer = parent1->getCnnLayer();
|
||||
if (!weightsLayer || weightsLayer->type != "Const" || weightsLayer->outData[0]->getPrecision() != Precision::I8) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(parent0.get());
|
||||
if (eltwiseNode->getOpType() != Subtract)
|
||||
return false;
|
||||
@ -395,7 +403,8 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
|
||||
if (!isSutableConvNode(conv)) continue;
|
||||
|
||||
auto dataEltwise = conv->getParentEdgesAtPort(0)[0]->getParent();
|
||||
if (initializeInputZeroPoints(conv, dataEltwise)) {
|
||||
auto weightsEltwise = conv->getParentEdgesAtPort(1)[0]->getParent();
|
||||
if (initializeInputZeroPoints(conv, dataEltwise, weightsEltwise)) {
|
||||
auto p_edge = dataEltwise->getParentEdgesAtPort(1)[0];
|
||||
removeEdge(graph, p_edge);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user