From 03ca3d1ef712506e52c36218c51c65b61779c39c Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova
Date: Fri, 30 Apr 2021 13:34:33 +0300
Subject: [PATCH] [CPU] Fixed SoftPlus for large positive values (#4932)

---
 docs/ops/activation/SoftPlus_4.md                  | 16 +++++++++++++++-
 .../src/mkldnn_plugin/mkldnn_graph_optimizer.cpp   | 10 +++++-----
 .../src/mkldnn_plugin/mkldnn_node.cpp              |  1 +
 .../src/mkldnn_plugin/nodes/list_tbl.hpp           |  1 -
 .../mkldnn_plugin/nodes/mkldnn_bin_conv_node.cpp   |  2 +-
 .../mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp    |  5 +++--
 .../nodes/mkldnn_interpolate_node.cpp              |  2 +-
 .../cpu/single_layer_tests/convolution.cpp         |  1 +
 .../cpu/single_layer_tests/group_convolution.cpp   |  1 +
 .../plugin/cpu/test_utils/fusing_test_utils.hpp    |  4 ++++
 .../shared_tests_instances/skip_tests_config.cpp   |  2 ++
 .../shared_tests_instances/skip_tests_config.cpp   |  2 ++
 .../src/single_layer/activation.cpp                |  6 ++++++
 inference-engine/thirdparty/mkl-dnn                |  2 +-
 .../ngraph/runtime/reference/softplus.hpp          |  5 ++++-
 ngraph/python/tests/test_onnx/test_ops_unary.py    |  2 +-
 ngraph/test/onnx/onnx_import.in.cpp                |  4 ++--
 17 files changed, 50 insertions(+), 16 deletions(-)

diff --git a/docs/ops/activation/SoftPlus_4.md b/docs/ops/activation/SoftPlus_4.md
index 8afc94684ac..19714de749b 100644
--- a/docs/ops/activation/SoftPlus_4.md
+++ b/docs/ops/activation/SoftPlus_4.md
@@ -13,9 +13,23 @@
 *SoftPlus* performs element-wise activation function on a given input tensor, based on the following mathematical formula:
 
 \f[
-SoftPlus(x) = \ln(1+e^{x})
+SoftPlus(x) = \left\{\begin{array}{r}
+    x \qquad \mbox{if } x \geq threshold \\
+    log(e^{x} + 1.0) \qquad \mbox{if } x < threshold
+\end{array}\right.
 \f]
 
+**Note**: For numerical stability, the operation reverts to the linear function when `x >= threshold`, where `threshold` depends on *T* and
+is chosen so that the difference between the linear function and the exact calculation is no more than `1e-6`.
+The `threshold` can be calculated with the following formula, where `alpha` is the number of digits after the decimal point
+and `beta` is the maximum value of the *T* data type:
+
+\f[
+-log(e^{10^{-\alpha}} - 1.0) < threshold < log(\beta)
+\f]
+
+For example, if *T* is `fp32`, `threshold` should be `20`, and if *T* is `fp16`, `threshold` should be `12`.
+
 **Attributes**: *SoftPlus* operation has no attributes.
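To make the bound in the note concrete, here is a small standalone check (not part of the patch; the `main` harness and variable names are ours) that evaluates the valid `threshold` range for `fp32` with `alpha = 6` and the error of the linear branch at the suggested `threshold = 20`:

```cpp
// Sketch: verify the threshold bound from the SoftPlus_4.md note for fp32.
#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
    const double tolerance = 1e-6;  // 10^-alpha with alpha = 6

    // Lower bound -log(e^tolerance - 1): below it, the linear branch deviates
    // from the exact softplus by more than the tolerance.
    const double lower = -std::log(std::expm1(tolerance));
    // Upper bound log(beta) with beta = FLT_MAX: above it, exp(x) overflows fp32.
    const double upper = std::log(static_cast<double>(FLT_MAX));

    // Error of the linear branch at x: softplus(x) - x = log(1 + e^-x).
    const double threshold = 20.0;  // value the note suggests for fp32
    const double error = std::log1p(std::exp(-threshold));

    std::printf("valid range: (%.2f, %.2f), error at %.0f: %g\n",
                lower, upper, threshold, error);  // (13.82, 88.72), ~2.06e-09
    return 0;
}
```

The suggested `20` sits well inside that range, with an approximation error of about `2e-9`, comfortably under the `1e-6` budget.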
diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
index 10c808ba2b2..0976b5f1961 100644
--- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
+++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
@@ -615,7 +615,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndActivation(MKLDNNGraph &graph) {
             (eltwiseNode->getOpType() == Relu ||
             (conv->getCnnLayer()->precision == Precision::FP32 &&
              IsOneOf(eltwiseNode->getOpType(), {Elu, Logistic, BoundedRelu, Clamp, Swish, Hswish, Mish, Hsigmoid,
-                                                Round})));
+                                                Round, SoftRelu})));
     };
 
     for (int i = 0; i < graphNodes.size(); i++) {
@@ -694,7 +694,7 @@ void MKLDNNGraphOptimizer::FuseFullyConnectedAndSimpleOperation(MKLDNNGraph &graph) {
             IE_THROW() << "Cannot get Eltwise node " << childNode->getName();
 
         if (IsOneOf(eltwiseNode->getOpType(), {Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp, Swish, Hswish, Mish,
-                                               Hsigmoid, Round})) {
+                                               Hsigmoid, Round, SoftRelu})) {
             return true;
         } else if (IsOneOf(eltwiseNode->getOpType(), {MulAdd, Prelu})) {
             if (eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() != 2)
@@ -1053,7 +1053,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndSimpleOperation(MKLDNNGraph &graph) {
             return ((eltwiseNode->getOpType() == MulAdd && node->getCnnLayer()->blobs.size() == 2) ||
                     (eltwiseNode->getOpType() == Prelu) ||
                     IsOneOf(eltwiseNode->getOpType(), {Relu, Elu, Logistic, BoundedRelu, Clamp, Swish, Hswish, Mish,
-                                                       Hsigmoid, Round}));
+                                                       Hsigmoid, Round, SoftRelu}));
         }
 
         return false;
@@ -1269,7 +1269,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNGraph &graph) {
             (eltwiseNode->getOpType() == Relu ||
             (conv->getCnnLayer()->precision == Precision::FP32 &&
              IsOneOf(eltwiseNode->getOpType(), {Elu, Logistic, BoundedRelu, Clamp, Swish, Hswish, Mish, Hsigmoid,
-                                                Round})));
+                                                Round, SoftRelu})));
     };
 
     for (auto &graphNode : graphNodes) {
@@ -1568,7 +1568,7 @@ void MKLDNNGraphOptimizer::FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph) {
         if (eltwiseNode == nullptr)
             IE_THROW() << "Cannot get Eltwise node " << node->getName();
         return IsOneOf(eltwiseNode->getOpType(), {Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp, Tanh, Swish,
-                                                  Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt}) ||
+                                                  Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt, SoftRelu}) ||
                ((eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2) ||
                 (eltwiseNode->getOpType() == Prelu));
     }
diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
index 114b6d18c05..d3af44347ad 100644
--- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
@@ -80,6 +80,7 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
         { "Round", Eltwise },
         { "ScaleShift", Eltwise },
         { "PReLU", Eltwise },
+        { "SoftPlus", Eltwise },
         { "Norm", Lrn },
         { "LRN", Lrn },
         { "Pooling", Pooling },
diff --git a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp
index da6e70d7eeb..e66af69e08f 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp
@@ -32,7 +32,6 @@ MKLDNN_EXTENSION_NODE(MathImpl, Selu);
 MKLDNN_EXTENSION_NODE(MathImpl, Sign);
 MKLDNN_EXTENSION_NODE(MathImpl, Sin);
 MKLDNN_EXTENSION_NODE(MathImpl, Sinh);
-MKLDNN_EXTENSION_NODE(MathImpl, SoftPlus);
 MKLDNN_EXTENSION_NODE(MathImpl, Softsign);
 MKLDNN_EXTENSION_NODE(MathImpl, Tan);
 MKLDNN_EXTENSION_NODE(ExperimentalDetectronTopKROIsImpl, ExperimentalDetectronTopKROIs);
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_bin_conv_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_bin_conv_node.cpp
index 68c554ceef0..1738d1798a9 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_bin_conv_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_bin_conv_node.cpp
@@ -1114,7 +1114,7 @@ bool MKLDNNBinaryConvolutionNode::canFuse(const MKLDNNNodePtr& node) const {
         }
 
         return eltwiseNode->isSum() ||
-               isOneOf(eltwiseNode->getOpType(), {MulAdd, Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp,
+               isOneOf(eltwiseNode->getOpType(), {MulAdd, Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp, SoftRelu,
                                                   Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt});
     }
 
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp
index 3f7b02b9a4c..fca94bf51d9 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp
@@ -843,7 +843,7 @@ MKLDNNEltwiseNode::initializers = {
         opType = BoundedRelu;
         algorithm = mkldnn::algorithm::eltwise_bounded_relu;
     }},
-    {"soft_relu", [](GenericLayer* activationLayer, EltwiseOpType& opType, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
+    {"softplus", [](GenericLayer* activationLayer, EltwiseOpType& opType, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
         alpha = 0.0f;
         beta = 0.0f;
         opType = SoftRelu;
@@ -983,7 +983,8 @@ void MKLDNNEltwiseNode::init() {
             comparator(layerType, "hswish") ||
             comparator(layerType, "mish") ||
             comparator(layerType, "hsigmoid") ||
-            comparator(layerType, "round")) {
+            comparator(layerType, "round") ||
+            comparator(layerType, "softplus")) {
         initializers[layerType](getCnnLayer().get(), eltwiseOp, eltwiseAlgorithm, alpha, beta);
     } else if (comparator(layerType, "erf")) {
         eltwiseOp = Erf;
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp
index 8c496991393..b87c33b7320 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_interpolate_node.cpp
@@ -3197,7 +3197,7 @@ bool MKLDNNInterpolateNode::canFuse(const MKLDNNNodePtr& node) const {
         auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
         if (eltwiseNode == nullptr)
             IE_THROW() << "Cannot get eltwise node " << node->getName();
-        return isOneOf(eltwiseNode->getOpType(), {Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp,
+        return isOneOf(eltwiseNode->getOpType(), {Prelu, Relu, Gelu, Elu, Logistic, BoundedRelu, Clamp, SoftRelu,
                                                   Tanh, Swish, Hswish, Mish, Hsigmoid, Round, Linear, Abs, Square, Sqrt}) ||
                (eltwiseNode->getOpType() == MulAdd && eltwiseNode->getCnnLayer()->blobs.size() == 2);
     }
diff --git a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/convolution.cpp b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/convolution.cpp
index 7b596f40a9a..597b6d053b7 100755
--- a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/convolution.cpp
+++ b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/convolution.cpp
@@ -112,6 +112,7 @@ const std::vector<fusingSpecificParams> fusingParamsSet{
         fusingSwish,
         fusingHSwish,
         fusingMish,
+        fusingSoftPlus,
         // other patterns
         fusingReluScaleShift,
         fusingFakeQuantizePerTensorRelu,
diff --git a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/group_convolution.cpp b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/group_convolution.cpp
index 068f074c7ff..b3267e7e199 100644
--- a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/group_convolution.cpp
+++ b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/group_convolution.cpp
@@ -123,6 +123,7 @@ std::vector<fusingSpecificParams> fusingParamsSet {
         fusingSwish,
         fusingHSwish,
         fusingMish,
+        fusingSoftPlus,
         // other patterns
         fusingReluScaleShift,
         fusingFakeQuantizePerTensorRelu,
diff --git a/inference-engine/tests/functional/plugin/cpu/test_utils/fusing_test_utils.hpp b/inference-engine/tests/functional/plugin/cpu/test_utils/fusing_test_utils.hpp
index 9d2de2b2715..b084dacbd16 100644
--- a/inference-engine/tests/functional/plugin/cpu/test_utils/fusing_test_utils.hpp
+++ b/inference-engine/tests/functional/plugin/cpu/test_utils/fusing_test_utils.hpp
@@ -112,6 +112,10 @@ const auto fusingMish = fusingSpecificParams{std::make_shared<postNodesMgr>(std
             {[](std::shared_ptr<ngraph::Node> inpNode, const ngraph::element::Type& ngPrc, ngraph::ParameterVector& params){
                 return ngraph::builder::makeActivation(inpNode, ngPrc, ngraph::helpers::Mish, {}, {});
             }, "Mish"}}), {"Mish"}};
+const auto fusingSoftPlus = fusingSpecificParams{std::make_shared<postNodesMgr>(std::vector<postNodeBuilder>{
+            {[](std::shared_ptr<ngraph::Node> inpNode, const ngraph::element::Type& ngPrc, ngraph::ParameterVector& params){
+                return ngraph::builder::makeActivation(inpNode, ngPrc, ngraph::helpers::SoftPlus, {}, {});
+            }, "SoftPlus"}}), {"SoftPlus"}};
 const auto fusingTanh = fusingSpecificParams{std::make_shared<postNodesMgr>(std::vector<postNodeBuilder>{
             {[](std::shared_ptr<ngraph::Node> inpNode, const ngraph::element::Type& ngPrc, ngraph::ParameterVector& params){
                 return ngraph::builder::makeActivation(inpNode, ngPrc, ngraph::helpers::Tanh, {}, {});
diff --git a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/skip_tests_config.cpp b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/skip_tests_config.cpp
index 01ffef7fe45..631c6a88c4d 100644
--- a/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/skip_tests_config.cpp
+++ b/inference-engine/tests/functional/plugin/gpu/shared_tests_instances/skip_tests_config.cpp
@@ -59,5 +59,7 @@ std::vector<std::string> disabledTestPatterns() {
         R"(.*ConstantResultSubgraphTest.*inPrc=I16.*)",
         // TODO: Issue: 54436
         R"(.*LSTMSequence.*CompareWithRefs.*mode=PURE_SEQ_RAND_SEQ_LEN_PARAM.*direction=bidirectional_clip=0.7_netPRC=FP32.*)",
+        // TODO: Issue: 54194
+        R"(.*ActivationLayerTest.*SoftPlus.*)",
     };
 }
diff --git a/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/skip_tests_config.cpp b/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/skip_tests_config.cpp
index 8bf55a9e35d..5252cddbd95 100644
--- a/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/skip_tests_config.cpp
+++ b/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/skip_tests_config.cpp
@@ -37,5 +37,7 @@ std::vector<std::string> disabledTestPatterns() {
         R"(.*CTCGreedyDecoderSeqLen.*?\(1.1.1\).*)",
         // TODO: Issue 51804
         ".*PreprocessConversionTest.*oPRC=U8.*",
+        // TODO: Issue 54163
+        R"(.*ActivationLayerTest.*SoftPlus.*)",
     };
 }
diff --git a/inference-engine/tests/functional/shared_test_classes/src/single_layer/activation.cpp b/inference-engine/tests/functional/shared_test_classes/src/single_layer/activation.cpp
index 83edd50358b..58a671eae73 100644
--- a/inference-engine/tests/functional/shared_test_classes/src/single_layer/activation.cpp
+++ b/inference-engine/tests/functional/shared_test_classes/src/single_layer/activation.cpp
@@ -101,6 +101,12 @@ InferenceEngine::Blob::Ptr ActivationLayerTest::GenerateInput(const InferenceEngine::InputInfo &info) const {
             resolution = 32768;
             break;
         }
+        case ngraph::helpers::ActivationTypes::SoftPlus: {
+            data_start_from = -100;
+            data_range = 200;
+            resolution = 32768;
+            break;
+        }
         default: {
             data_start_from = -10;
             data_range = 20;
diff --git a/inference-engine/thirdparty/mkl-dnn b/inference-engine/thirdparty/mkl-dnn
index 0292c2a2a25..2dd78726213 160000
--- a/inference-engine/thirdparty/mkl-dnn
+++ b/inference-engine/thirdparty/mkl-dnn
@@ -1 +1 @@
-Subproject commit 0292c2a2a2525ff86590de3b499ceb61a5e2355f
+Subproject commit 2dd787262134c20f91f222bfa776225d2dddbc9a
diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/softplus.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/softplus.hpp
index d68ecb31c0d..a5c95e4c6f9 100644
--- a/ngraph/core/reference/include/ngraph/runtime/reference/softplus.hpp
+++ b/ngraph/core/reference/include/ngraph/runtime/reference/softplus.hpp
@@ -16,9 +16,12 @@ namespace ngraph
             template <typename T>
             void softplus(const T* arg, T* out, size_t count)
             {
+                const T threshold = static_cast<T>(-std::log(std::exp(std::pow(10, -6)) - 1));
+
                 for (size_t i = 0; i < count; i++)
                 {
-                    out[i] = std::log(std::exp(arg[i]) + 1.0);
+                    out[i] = (arg[i] < threshold) ? static_cast<T>(std::log(std::exp(arg[i]) + 1))
+                                                  : arg[i];
                 }
             }
         } // namespace reference
diff --git a/ngraph/python/tests/test_onnx/test_ops_unary.py b/ngraph/python/tests/test_onnx/test_ops_unary.py
index ca300167b8c..582749264a6 100644
--- a/ngraph/python/tests/test_onnx/test_ops_unary.py
+++ b/ngraph/python/tests/test_onnx/test_ops_unary.py
@@ -266,7 +266,7 @@ def test_logsoftmax():
 
 def test_softplus():
     def softplus(x):
-        return np.log(np.exp(x) + 1)
+        return np.where(x < 20, np.log(np.exp(x) + 1), x)
 
     np.random.seed(133391)
     data = np.random.randn(3, 4, 5).astype(np.float32)
diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp
index a8b0d7c90b4..a5de02359ee 100644
--- a/ngraph/test/onnx/onnx_import.in.cpp
+++ b/ngraph/test/onnx/onnx_import.in.cpp
@@ -2402,9 +2402,9 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softplus)
         0.6931471824645996094,
         1.313261628150939941,
         10.0000457763671875,
-        inf,
+        100.0,
         0.0,
-        inf,
+        1000.0,
         0.0,
         0.6931471824645996094,
         0.6931471824645996094,
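The change to the reference kernel explains the updated ONNX test expectations: `exp(100.f)` already overflows `float`, so the old formula returned `inf` for inputs `100.0` and `1000.0`, while the thresholded version passes them through. A minimal standalone mirror of the patched loop (our own names and harness, not OpenVINO API) demonstrates this:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Mirrors the thresholded loop added to softplus.hpp: values at or above
// ~13.82 (derived from the 1e-6 tolerance) are passed through unchanged.
template <typename T>
void softplus_ref(const T* arg, T* out, std::size_t count) {
    const T threshold = static_cast<T>(-std::log(std::exp(std::pow(10, -6)) - 1));
    for (std::size_t i = 0; i < count; i++) {
        out[i] = (arg[i] < threshold) ? static_cast<T>(std::log(std::exp(arg[i]) + 1))
                                      : arg[i];
    }
}

int main() {
    const float in[4] = {0.0f, 1.0f, 100.0f, 1000.0f};
    float out[4];
    softplus_ref(in, out, 4);
    assert(std::abs(out[0] - 0.69314718f) < 1e-6f);  // log(2), as in the test vector
    assert(out[2] == 100.0f);   // was inf before the fix
    assert(out[3] == 1000.0f);  // was inf before the fix
    return 0;
}
```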