[CPU] Added Mish activation (#1555)
This commit is contained in:
parent
1eac9e3932
commit
50e003cded
@ -704,7 +704,8 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndActivation(MKLDNNGraph &graph) {
|
||||
return activationNode &&
|
||||
(activationNode->getAlgorithm() == eltwise_relu ||
|
||||
(conv->getCnnLayer()->precision == Precision::FP32 &&
|
||||
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp, eltwise_swish})));
|
||||
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
|
||||
eltwise_swish, eltwise_mish})));
|
||||
};
|
||||
|
||||
for (int i = 0; i < graphNodes.size(); i++) {
|
||||
@ -1187,7 +1188,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndSimpleOperation(MKLDNNGraph &graph)
|
||||
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
|
||||
|
||||
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
|
||||
eltwise_clamp, eltwise_swish});
|
||||
eltwise_clamp, eltwise_swish, eltwise_mish});
|
||||
}
|
||||
|
||||
return false;
|
||||
@ -1431,7 +1432,8 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
|
||||
return activationNode &&
|
||||
(activationNode->getAlgorithm() == eltwise_relu ||
|
||||
(conv->getCnnLayer()->precision == Precision::FP32 &&
|
||||
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp, eltwise_swish})));
|
||||
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
|
||||
eltwise_swish, eltwise_mish})));
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
@ -1781,7 +1783,7 @@ void MKLDNNGraphOptimizer::FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph) {
|
||||
if (activationNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
|
||||
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_gelu, eltwise_elu, eltwise_logistic,
|
||||
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_linear, eltwise_abs,
|
||||
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_mish, eltwise_linear, eltwise_abs,
|
||||
eltwise_square, eltwise_sqrt});
|
||||
}
|
||||
return false;
|
||||
@ -1893,7 +1895,7 @@ void MKLDNNGraphOptimizer::FuseEltwiseAndSimple(MKLDNNGraph &graph) {
|
||||
if (activationNode == nullptr)
|
||||
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
|
||||
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
|
||||
eltwise_clamp, eltwise_swish});
|
||||
eltwise_clamp, eltwise_swish, eltwise_mish});
|
||||
}
|
||||
|
||||
return false;
|
||||
|
@ -75,6 +75,7 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
|
||||
{ "Activation", Activation },
|
||||
{ "Clamp", Activation },
|
||||
{ "Swish", Activation },
|
||||
{ "Mish", Activation },
|
||||
{ "ScaleShift", Depthwise },
|
||||
{ "PReLU", Depthwise },
|
||||
{ "Norm", Lrn },
|
||||
|
@ -96,6 +96,11 @@ caseless_map<std::string, std::function<void(GenericLayer*, mkldnn::algorithm&,
|
||||
beta = 0.0f;
|
||||
algorithm = eltwise_swish;
|
||||
}},
|
||||
{"mish", [](GenericLayer* activationLayer, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
|
||||
alpha = 0.0f;
|
||||
beta = 0.0f;
|
||||
algorithm = eltwise_mish;
|
||||
}},
|
||||
};
|
||||
|
||||
MKLDNNActivationNode::MKLDNNActivationNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng,
|
||||
|
@ -51,8 +51,6 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*ActivationLayerTest.*Ceiling.*)",
|
||||
// TODO: Issue: 32032
|
||||
R"(.*ActivationParamLayerTest.*)",
|
||||
// TODO: Issue: 32959
|
||||
R"(.*ActivationLayerTest.*Mish.*)",
|
||||
// TODO: Issue: 30999 (Implement Interpolate reference in NGraph)
|
||||
R"(.*InterpolateLayerTest.*)"
|
||||
};
|
||||
|
2
inference-engine/thirdparty/mkl-dnn
vendored
2
inference-engine/thirdparty/mkl-dnn
vendored
@ -1 +1 @@
|
||||
Subproject commit 36f650aac835b5ef8ab2459eda337ed881a1d3c4
|
||||
Subproject commit 4f511de56e21b417f7c49c3f50cf7217350412ab
|
Loading…
Reference in New Issue
Block a user