[ONNX] Update ONNX importer to use SotfPlus-4 (#1959)
* Use SoftPlus-4 in ONNX importer * Tests update
This commit is contained in:
parent
bb729d0ee9
commit
410559d497
@ -18,7 +18,7 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "onnx_import/default_opset.hpp"
|
||||
#include "softplus.hpp"
|
||||
#include "onnx_import/op/softplus.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -31,36 +31,7 @@ namespace ngraph
|
||||
OutputVector softplus(const Node& node)
|
||||
{
|
||||
const auto data = node.get_ng_inputs().at(0);
|
||||
|
||||
const std::shared_ptr<ngraph::Node> zero_node =
|
||||
default_opset::Constant::create(data.get_element_type(), Shape{}, {0.f});
|
||||
const std::shared_ptr<ngraph::Node> one_node =
|
||||
default_opset::Constant::create(data.get_element_type(), Shape{}, {1.f});
|
||||
|
||||
// data + log(exp(-data) + 1)
|
||||
const std::shared_ptr<ngraph::Node> positive_val_node =
|
||||
std::make_shared<default_opset::Add>(
|
||||
data,
|
||||
std::make_shared<default_opset::Log>(
|
||||
std::make_shared<default_opset::Add>(
|
||||
std::make_shared<default_opset::Exp>(
|
||||
std::make_shared<default_opset::Negative>(data)),
|
||||
one_node)));
|
||||
|
||||
// log(exp(data) + 1)
|
||||
const std::shared_ptr<ngraph::Node> negative_val_node =
|
||||
std::make_shared<default_opset::Log>(std::make_shared<default_opset::Add>(
|
||||
std::make_shared<default_opset::Exp>(data), one_node));
|
||||
|
||||
const std::shared_ptr<ngraph::Node> condition_node =
|
||||
std::make_shared<default_opset::Greater>(data, zero_node);
|
||||
|
||||
// This equation represents:
|
||||
// x + log(exp(-x) + 1) - for x > 0; to manage exponent overflow,
|
||||
// log(exp(x) + 1) - elsewhere.
|
||||
//
|
||||
return {std::make_shared<default_opset::Select>(
|
||||
condition_node, positive_val_node, negative_val_node)};
|
||||
return {std::make_shared<default_opset::SoftPlus>(data)};
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
|
@ -1653,20 +1653,20 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softplus)
|
||||
FLT_MAX,
|
||||
-FLT_MAX}};
|
||||
|
||||
std::vector<float>& input = inputs.back();
|
||||
std::vector<float> output;
|
||||
auto softplus_impl = [](float x) -> float {
|
||||
if (x > 0)
|
||||
{
|
||||
return x + std::log(std::exp(-x) + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::log(std::exp(x) + 1);
|
||||
}
|
||||
};
|
||||
|
||||
std::transform(std::begin(input), std::end(input), std::back_inserter(output), softplus_impl);
|
||||
const auto inf = std::numeric_limits<float>::infinity();
|
||||
std::vector<float> output{0.3132616579532623291,
|
||||
0.6931471824645996094,
|
||||
1.313261628150939941,
|
||||
10.0000457763671875,
|
||||
inf,
|
||||
0.0,
|
||||
inf,
|
||||
0.0,
|
||||
0.6931471824645996094,
|
||||
0.6931471824645996094,
|
||||
0.6931471824645996094,
|
||||
inf,
|
||||
0.0};
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_multiple_inputs(inputs);
|
||||
|
Loading…
Reference in New Issue
Block a user