diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp index 477de31b769..73c353960a1 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp @@ -846,7 +846,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) { bool isSupportedParams = layer->_group == 1 && is1x1Convolution(layer) && // TODO [oneDNN] : fusing is permitted only with 1x1 convolutions everyone_is(1, layer->_stride[X_AXIS], layer->_stride[Y_AXIS]) && - one_of(layer->outData[0].get()->getPrecision(), Precision::FP32) && + everyone_is(Precision::FP32, layer->insData[0].lock()->getPrecision(), layer->outData[0].get()->getPrecision()) && node->getChildEdgeAt(0)->getDims().ndims() == 4; if (!isSupportedParams) return false; @@ -862,10 +862,11 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) { if (parentLayer == nullptr) IE_THROW() << "Cannot get convolution layer " << parentNode->getName(); - if (parentLayer->outData[0].get()->getPrecision() != childLayer->outData[0].get()->getPrecision()) + if (!everyone_is(Precision::FP32, parentLayer->outData[0].get()->getPrecision(), childLayer->insData[0].lock()->getPrecision(), + childLayer->outData[0].get()->getPrecision())) return false; - if (parentLayer->precision != childLayer->precision) + if (!everyone_is(Precision::FP32, parentLayer->precision, childLayer->precision)) return false; auto parentOutputPrecision = !parentNode->fusedWith.empty() @@ -876,7 +877,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) { ? childNode->fusedWith[childNode->fusedWith.size() - 1]->getCnnLayer()->outData[0].get()->getPrecision() : childNode->getCnnLayer()->outData[0].get()->getPrecision(); - if (parentOutputPrecision != childOutputPrecision) + if (!everyone_is(Precision::FP32, parentOutputPrecision, childOutputPrecision)) return false; auto* childConvolutionNode = dynamic_cast(childNode.get());