Fixed SoftPlusDecomposition transformation (#2011)
This commit is contained in:
parent
757b1f0d9e
commit
25856f4cdc
@ -27,7 +27,7 @@ ngraph::pass::SoftPlusDecomposition::SoftPlusDecomposition() {
|
||||
|
||||
auto exp = std::make_shared<ngraph::opset4::Exp>(softplus_input);
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(exp,
|
||||
opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0}));
|
||||
opset4::Constant::create(softplus_input.get_element_type(), ngraph::Shape{1}, {1.0}));
|
||||
auto log = std::make_shared<ngraph::opset4::Log>(add);
|
||||
|
||||
log->set_friendly_name(softplus_node->get_friendly_name());
|
||||
|
@ -18,8 +18,8 @@
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, SoftPlusDecomposition) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
TEST(TransformationTests, SoftPlusDecompositionFP32) {
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(data);
|
||||
@ -46,3 +46,32 @@ TEST(TransformationTests, SoftPlusDecomposition) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SoftPlusDecompositionFP16) {
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{3, 1, 2});
|
||||
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(data);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{softplus}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{3, 1, 2});
|
||||
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
|
||||
auto add = std::make_shared<ngraph::opset4::Add>(exp,
|
||||
ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{1}, {1.0}));
|
||||
auto log = std::make_shared<ngraph::opset4::Log>(add);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{log}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
Loading…
Reference in New Issue
Block a user