[CPU] fixed dw conv fusing (#5146)
This commit is contained in:
parent
634dc42808
commit
617a1024e1
@@ -846,7 +846,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) {
|
||||
bool isSupportedParams = layer->_group == 1 &&
|
||||
is1x1Convolution(layer) && // TODO [oneDNN] : fusing is permitted only with 1x1 convolutions
|
||||
everyone_is(1, layer->_stride[X_AXIS], layer->_stride[Y_AXIS]) &&
|
||||
one_of(layer->outData[0].get()->getPrecision(), Precision::FP32) &&
|
||||
everyone_is(Precision::FP32, layer->insData[0].lock()->getPrecision(), layer->outData[0].get()->getPrecision()) &&
|
||||
node->getChildEdgeAt(0)->getDims().ndims() == 4;
|
||||
if (!isSupportedParams) return false;
|
||||
|
||||
@@ -862,10 +862,11 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) {
|
||||
if (parentLayer == nullptr)
|
||||
IE_THROW() << "Cannot get convolution layer " << parentNode->getName();
|
||||
|
||||
if (parentLayer->outData[0].get()->getPrecision() != childLayer->outData[0].get()->getPrecision())
|
||||
if (!everyone_is(Precision::FP32, parentLayer->outData[0].get()->getPrecision(), childLayer->insData[0].lock()->getPrecision(),
|
||||
childLayer->outData[0].get()->getPrecision()))
|
||||
return false;
|
||||
|
||||
if (parentLayer->precision != childLayer->precision)
|
||||
if (!everyone_is(Precision::FP32, parentLayer->precision, childLayer->precision))
|
||||
return false;
|
||||
|
||||
auto parentOutputPrecision = !parentNode->fusedWith.empty()
|
||||
@@ -876,7 +877,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) {
|
||||
? childNode->fusedWith[childNode->fusedWith.size() - 1]->getCnnLayer()->outData[0].get()->getPrecision()
|
||||
: childNode->getCnnLayer()->outData[0].get()->getPrecision();
|
||||
|
||||
if (parentOutputPrecision != childOutputPrecision)
|
||||
if (!everyone_is(Precision::FP32, parentOutputPrecision, childOutputPrecision))
|
||||
return false;
|
||||
|
||||
auto* childConvolutionNode = dynamic_cast<MKLDNNConvolutionNode*>(childNode.get());
|
||||
|
Loading…
Reference in New Issue
Block a user