diff --git a/inference-engine/src/transformations/src/transformations/softplus_decomposition.cpp b/inference-engine/src/transformations/src/transformations/softplus_decomposition.cpp index bfac0afa1e4..010ae8b3f3d 100644 --- a/inference-engine/src/transformations/src/transformations/softplus_decomposition.cpp +++ b/inference-engine/src/transformations/src/transformations/softplus_decomposition.cpp @@ -27,7 +27,7 @@ ngraph::pass::SoftPlusDecomposition::SoftPlusDecomposition() { auto exp = std::make_shared(softplus_input); auto add = std::make_shared(exp, - opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0})); + opset4::Constant::create(softplus_input.get_element_type(), ngraph::Shape{1}, {1.0})); auto log = std::make_shared(add); log->set_friendly_name(softplus_node->get_friendly_name()); diff --git a/inference-engine/tests/functional/inference_engine/transformations/softplus_decomposition_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/softplus_decomposition_test.cpp index ad661918f92..952a50fb652 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/softplus_decomposition_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/softplus_decomposition_test.cpp @@ -18,8 +18,8 @@ using namespace testing; -TEST(TransformationTests, SoftPlusDecomposition) { - std::shared_ptr f(nullptr), f_ref(nullptr); +TEST(TransformationTests, SoftPlusDecompositionFP32) { + std::shared_ptr f, f_ref; { auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); auto softplus = std::make_shared(data); @@ -46,3 +46,32 @@ TEST(TransformationTests, SoftPlusDecomposition) { auto res = compare_functions(f, f_ref); ASSERT_TRUE(res.first) << res.second; } + +TEST(TransformationTests, SoftPlusDecompositionFP16) { + std::shared_ptr f, f_ref; + { + auto data = std::make_shared(ngraph::element::f16, ngraph::Shape{3, 1, 2}); + auto softplus = std::make_shared(data); + + f = std::make_shared(ngraph::NodeVector{softplus}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::Shape{3, 1, 2}); + auto exp = std::make_shared(input); + auto add = std::make_shared(exp, + ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{1}, {1.0})); + auto log = std::make_shared(add); + + f_ref = std::make_shared(ngraph::NodeVector{log}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} \ No newline at end of file