[CPU] fix load time for several models (#5958)
This commit is contained in:
parent
126d1a649c
commit
b1257a5528
@ -158,12 +158,11 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndBias(MKLDNNGraph &graph) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto isSutableChildNode = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) {
|
auto isSutableChildNode = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) {
|
||||||
if ((parentNode->isConstant() && !childNode->isConstant()) || childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() ||
|
if (childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() || childNode->getParentEdges().size() != 2)
|
||||||
childNode->getParentEdges().size() != 2)
|
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto biasNode = childNode->getParentEdgesAtPort(1)[0]->getParent();
|
auto biasNode = childNode->getParentEdgesAtPort(1)[0]->getParent();
|
||||||
if (biasNode->getChildEdges().size() != 1)
|
if (biasNode->getType() != Input || !biasNode->isConstant() || biasNode->getChildEdges().size() != 1)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto convOutDims = parentNode->getChildEdgesAtPort(0)[0]->getDims().ToSizeVector();
|
auto convOutDims = parentNode->getChildEdgesAtPort(0)[0]->getDims().ToSizeVector();
|
||||||
@ -302,6 +301,8 @@ void MKLDNNGraphOptimizer::FuseMultiplyAndAdd(MKLDNNGraph &graph) {
|
|||||||
auto& graphNodes = graph.GetNodes();
|
auto& graphNodes = graph.GetNodes();
|
||||||
|
|
||||||
auto isSutableSecondInput = [](MKLDNNNodePtr node, MKLDNNDims dataDims) {
|
auto isSutableSecondInput = [](MKLDNNNodePtr node, MKLDNNDims dataDims) {
|
||||||
|
if (node->getType() != Input || !node->isConstant())
|
||||||
|
return false;
|
||||||
auto secondInputDims = node->outDims[0];
|
auto secondInputDims = node->outDims[0];
|
||||||
if (secondInputDims.ndims() != dataDims.ndims() || secondInputDims.ndims() < 2)
|
if (secondInputDims.ndims() != dataDims.ndims() || secondInputDims.ndims() < 2)
|
||||||
return false;
|
return false;
|
||||||
@ -326,8 +327,7 @@ void MKLDNNGraphOptimizer::FuseMultiplyAndAdd(MKLDNNGraph &graph) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto isSutableChildNode = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) {
|
auto isSutableChildNode = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) {
|
||||||
if ((parentNode->isConstant() && !childNode->isConstant()) || childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() ||
|
if (childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() || childNode->getParentEdges().size() != 2)
|
||||||
childNode->getParentEdges().size() != 2)
|
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return isSutableSecondInput(childNode->getParentEdgesAtPort(1)[0]->getParent(), childNode->getParentEdgesAtPort(0)[0]->getDims());
|
return isSutableSecondInput(childNode->getParentEdgesAtPort(1)[0]->getParent(), childNode->getParentEdgesAtPort(0)[0]->getDims());
|
||||||
@ -1518,9 +1518,9 @@ void MKLDNNGraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(MKLDNNGraph
|
|||||||
auto& graphNodes = graph.GetNodes();
|
auto& graphNodes = graph.GetNodes();
|
||||||
|
|
||||||
auto getConstPort = [](const MKLDNNNodePtr node) -> int {
|
auto getConstPort = [](const MKLDNNNodePtr node) -> int {
|
||||||
if (node->getParentEdgeAt(0)->getParent()->isConstant() && node->getParentEdgeAt(0)->getParent()->getType() == Input) {
|
if (node->getParentEdgeAt(0)->getParent()->getType() == Input && node->getParentEdgeAt(0)->getParent()->isConstant()) {
|
||||||
return 0;
|
return 0;
|
||||||
} else if (node->getParentEdgeAt(1)->getParent()->isConstant() && node->getParentEdgeAt(1)->getParent()->getType() == Input) {
|
} else if (node->getParentEdgeAt(1)->getParent()->getType() == Input && node->getParentEdgeAt(1)->getParent()->isConstant()) {
|
||||||
return 1;
|
return 1;
|
||||||
} else {
|
} else {
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -1296,7 +1296,7 @@ bool MKLDNNNode::canBePerformedAsScaleShift(const MKLDNNNode *parentNode) const
|
|||||||
fusingPort = i;
|
fusingPort = i;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (!node->isConstant() || node->getType() != Input) {
|
if (node->getType() != Input || !node->isConstant()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user