Fixed SoftPlusDecomposition transformation (#2011)

This commit is contained in:
Gleb Kazantaev 2020-09-01 16:20:42 +03:00 committed by GitHub
parent 757b1f0d9e
commit 25856f4cdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 3 deletions

View File

@ -27,7 +27,7 @@ ngraph::pass::SoftPlusDecomposition::SoftPlusDecomposition() {
auto exp = std::make_shared<ngraph::opset4::Exp>(softplus_input);
auto add = std::make_shared<ngraph::opset4::Add>(exp,
opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0}));
opset4::Constant::create(softplus_input.get_element_type(), ngraph::Shape{1}, {1.0}));
auto log = std::make_shared<ngraph::opset4::Log>(add);
log->set_friendly_name(softplus_node->get_friendly_name());

View File

@ -18,8 +18,8 @@
using namespace testing;
TEST(TransformationTests, SoftPlusDecomposition) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
TEST(TransformationTests, SoftPlusDecompositionFP32) {
std::shared_ptr<ngraph::Function> f, f_ref;
{
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(data);
@ -46,3 +46,32 @@ TEST(TransformationTests, SoftPlusDecomposition) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SoftPlusDecompositionFP16) {
std::shared_ptr<ngraph::Function> f, f_ref;
{
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{3, 1, 2});
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(data);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{softplus}, ngraph::ParameterVector{data});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{3, 1, 2});
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
auto add = std::make_shared<ngraph::opset4::Add>(exp,
ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{1}, {1.0}));
auto log = std::make_shared<ngraph::opset4::Log>(add);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{log}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}