Fixed SoftPlusDecomposition transformation (#2011)
This commit is contained in:
parent
757b1f0d9e
commit
25856f4cdc
@ -27,7 +27,7 @@ ngraph::pass::SoftPlusDecomposition::SoftPlusDecomposition() {
|
|||||||
|
|
||||||
auto exp = std::make_shared<ngraph::opset4::Exp>(softplus_input);
|
auto exp = std::make_shared<ngraph::opset4::Exp>(softplus_input);
|
||||||
auto add = std::make_shared<ngraph::opset4::Add>(exp,
|
auto add = std::make_shared<ngraph::opset4::Add>(exp,
|
||||||
opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0}));
|
opset4::Constant::create(softplus_input.get_element_type(), ngraph::Shape{1}, {1.0}));
|
||||||
auto log = std::make_shared<ngraph::opset4::Log>(add);
|
auto log = std::make_shared<ngraph::opset4::Log>(add);
|
||||||
|
|
||||||
log->set_friendly_name(softplus_node->get_friendly_name());
|
log->set_friendly_name(softplus_node->get_friendly_name());
|
||||||
|
@ -18,8 +18,8 @@
|
|||||||
|
|
||||||
using namespace testing;
|
using namespace testing;
|
||||||
|
|
||||||
TEST(TransformationTests, SoftPlusDecomposition) {
|
TEST(TransformationTests, SoftPlusDecompositionFP32) {
|
||||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||||
{
|
{
|
||||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||||
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(data);
|
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(data);
|
||||||
@ -46,3 +46,32 @@ TEST(TransformationTests, SoftPlusDecomposition) {
|
|||||||
auto res = compare_functions(f, f_ref);
|
auto res = compare_functions(f, f_ref);
|
||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, SoftPlusDecompositionFP16) {
|
||||||
|
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{3, 1, 2});
|
||||||
|
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(data);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{softplus}, ngraph::ParameterVector{data});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::SoftPlusDecomposition>();
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{3, 1, 2});
|
||||||
|
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
|
||||||
|
auto add = std::make_shared<ngraph::opset4::Add>(exp,
|
||||||
|
ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{1}, {1.0}));
|
||||||
|
auto log = std::make_shared<ngraph::opset4::Log>(add);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{log}, ngraph::ParameterVector{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user