[ONNX] Update ONNX importer to use SoftPlus-4 (#1959)

* Use SoftPlus-4 in ONNX importer

* Tests update
This commit is contained in:
Katarzyna Mitrus 2020-08-27 14:55:04 +02:00 committed by GitHub
parent bb729d0ee9
commit 410559d497
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 45 deletions

View File

@ -18,7 +18,7 @@
#include "ngraph/node.hpp"
#include "onnx_import/default_opset.hpp"
#include "softplus.hpp"
#include "onnx_import/op/softplus.hpp"
namespace ngraph
{
@ -31,36 +31,7 @@ namespace ngraph
OutputVector softplus(const Node& node)
{
    /// \brief Converts the ONNX Softplus operation (y = ln(exp(x) + 1))
    ///        into a single nGraph SoftPlus-4 node.
    ///
    /// \param node  ONNX node carrying exactly one input tensor.
    /// \return      OutputVector with one output: SoftPlus applied to the input.
    ///
    /// The previous manual decomposition (Select between
    /// x + log(exp(-x) + 1) for x > 0 and log(exp(x) + 1) otherwise,
    /// used to manage exponent overflow) is superseded by the dedicated
    /// SoftPlus op, which is expected to handle numerical stability
    /// internally. The dead decomposition code left after the diff
    /// (unreachable second return) has been removed.
    const auto data = node.get_ng_inputs().at(0);
    return {std::make_shared<default_opset::SoftPlus>(data)};
}
} // namespace set_1

View File

@ -1653,20 +1653,20 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softplus)
FLT_MAX,
-FLT_MAX}};
std::vector<float>& input = inputs.back();
std::vector<float> output;
auto softplus_impl = [](float x) -> float {
if (x > 0)
{
return x + std::log(std::exp(-x) + 1);
}
else
{
return std::log(std::exp(x) + 1);
}
};
std::transform(std::begin(input), std::end(input), std::back_inserter(output), softplus_impl);
const auto inf = std::numeric_limits<float>::infinity();
std::vector<float> output{0.3132616579532623291,
0.6931471824645996094,
1.313261628150939941,
10.0000457763671875,
inf,
0.0,
inf,
0.0,
0.6931471824645996094,
0.6931471824645996094,
0.6931471824645996094,
inf,
0.0};
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_multiple_inputs(inputs);