[CPU] fixed conv + dw conv fusing (#6975)
This commit is contained in:
@@ -641,6 +641,9 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) {
|
||||
};
|
||||
|
||||
auto isSutableParentConvolution = [&](MKLDNNNodePtr node) {
|
||||
if (node->isDropped())
|
||||
return false;
|
||||
|
||||
const auto conv = std::dynamic_pointer_cast<MKLDNNConvolutionNode>(node);
|
||||
if (conv == nullptr)
|
||||
IE_THROW() << "Cannot cast to convolution node " << node->getName();
|
||||
@@ -649,17 +652,26 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) {
|
||||
return false;
|
||||
|
||||
const auto &strides = conv->getStride();
|
||||
const auto &paddings = conv->getPaddingL();
|
||||
const auto &inDims = node->getParentEdgeAt(0)->getDims();
|
||||
const auto &outDims = node->getChildEdgeAt(0)->getDims();
|
||||
bool isSupportedParams = conv->getGroupNum() == 1 &&
|
||||
inDims.ndims() == 4 &&
|
||||
inDims[inDims.ndims() - 1] == outDims[outDims.ndims() - 1] &&
|
||||
inDims[inDims.ndims() - 2] == outDims[outDims.ndims() - 2] &&
|
||||
is1x1Convolution(conv) && // TODO [oneDNN] : fusing is permitted only with 1x1 convolutions
|
||||
everyone_is(1, strides[strides.size() - 1], strides[strides.size() - 2]) &&
|
||||
!conv->canBeExecutedInInt8() &&
|
||||
node->getChildEdgeAt(0)->getDims().ndims() == 4;
|
||||
everyone_is(0, paddings[paddings.size() - 1], paddings[paddings.size() - 2]) &&
|
||||
!conv->canBeExecutedInInt8();
|
||||
if (!isSupportedParams) return false;
|
||||
|
||||
return node->getChildEdges().size() == 1 && isConvolutionNode(node->getChildEdgeAt(0)->getChild());
|
||||
};
|
||||
|
||||
auto isSutableChildConvolution = [&](const MKLDNNNodePtr &parentNode, const MKLDNNNodePtr &childNode) {
|
||||
if (parentNode->isDropped() || childNode->isDropped())
|
||||
return false;
|
||||
|
||||
const auto convChild = std::dynamic_pointer_cast<MKLDNNConvolutionNode>(childNode);
|
||||
if (convChild == nullptr)
|
||||
IE_THROW() << "Cannot cast to convolution node " << childNode->getName();
|
||||
|
||||
Reference in New Issue
Block a user