[CPU] Disabled input zero point fusing into fp32 Convolution (#4056)

This commit is contained in:
Gorokhov Dmitriy 2021-01-29 08:38:58 +03:00 committed by GitHub
parent a67a72084f
commit a8b921791e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -216,7 +216,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
return true;
};
auto initializeInputZeroPoints = [](MKLDNNNodePtr node, MKLDNNNodePtr parent0) {
auto initializeInputZeroPoints = [](MKLDNNNodePtr node, MKLDNNNodePtr parent0, MKLDNNNodePtr parent1) {
auto* convNode = dynamic_cast<MKLDNNConvolutionNode*>(node.get());
if (convNode == nullptr)
THROW_IE_EXCEPTION << "Cannot get convolution node " << node->getName();
@ -225,6 +225,14 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
int OC = node->getChildEdgesAtPort(0)[0]->getDims()[1];
if (parent0->getType() == Eltwise) {
// The plug-in doesn't support FP32 convolution with input/weights zero points.
// If the weights are in FP32 (or have zero points, which the INT8 convolution does not support), the INT8
// implementation cannot be used, so input zero-point fusing has to be disabled as well.
auto weightsLayer = parent1->getCnnLayer();
if (!weightsLayer || weightsLayer->type != "Const" || weightsLayer->outData[0]->getPrecision() != Precision::I8) {
return false;
}
auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(parent0.get());
if (eltwiseNode->getOpType() != Subtract)
return false;
@ -395,7 +403,8 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
if (!isSutableConvNode(conv)) continue;
auto dataEltwise = conv->getParentEdgesAtPort(0)[0]->getParent();
if (initializeInputZeroPoints(conv, dataEltwise)) {
auto weightsEltwise = conv->getParentEdgesAtPort(1)[0]->getParent();
if (initializeInputZeroPoints(conv, dataEltwise, weightsEltwise)) {
auto p_edge = dataEltwise->getParentEdgesAtPort(1)[0];
removeEdge(graph, p_edge);