Add Mish with SoftPlus transformation (#1815)

* Add Mish with SoftPlus transformation

* Refactoring accrding code review

* Add softplus to mish pass registration

* Add checks customer count for SoftPlus and Tanh ops
This commit is contained in:
iliya mironov 2020-09-04 11:07:37 +03:00 committed by GitHub
parent 8017ac03ea
commit 0c1b2f836b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 0 deletions

View File

@ -0,0 +1,32 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SoftPlusToMishFusion;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief SoftPlusToMishFusion transformation replaces group of
* operations: x * tanh(softplus(x)) to Mish op.
*/
class ngraph::pass::SoftPlusToMishFusion: public ngraph::pass::MatcherPass {
public:
SoftPlusToMishFusion();
};

View File

@ -17,6 +17,7 @@
#include "transformations/itt.hpp"
#include "transformations/mish_fusion.hpp"
#include "transformations/softplus_fusion.hpp"
#include "transformations/softplus_to_mish_fusion.hpp"
#include "transformations/swish_fusion.hpp"
#include "transformations/hswish_fusion.hpp"
#include "transformations/normalize_l2_fusion.hpp"
@ -45,6 +46,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.register_pass<ngraph::pass::MishFusion>();
manager.register_pass<ngraph::pass::SoftPlusFusion>();
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();

View File

@ -0,0 +1,36 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/softplus_to_mish_fusion.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
ngraph::pass::SoftPlusToMishFusion::SoftPlusToMishFusion() {
auto input = ngraph::pattern::any_input();
auto softplus = ngraph::pattern::wrap_type<ngraph::opset4::SoftPlus>({input}, pattern::consumers_count(1));
auto tanh = ngraph::pattern::wrap_type<ngraph::opset4::Tanh>({softplus}, pattern::consumers_count(1));
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, tanh);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(input);
auto mish = std::make_shared<ngraph::opset4::Mish>(exp_input);
mish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(tanh).get_node_shared_ptr(),
pattern_to_output.at(softplus).get_node_shared_ptr()}, mish);
ngraph::replace_node(m.get_match_root(), mish);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SoftPlusToMishFusion");
register_matcher(m, callback);
}

View File

@ -11,6 +11,7 @@
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/mish_fusion.hpp>
#include <transformations/softplus_to_mish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
@ -48,3 +49,32 @@ TEST(TransformationTests, MishFusing) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, MishWithSoftPlusFusing) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input0 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f64, ngraph::Shape{3, 1, 2});
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(input0);
auto tanh = std::make_shared<ngraph::opset4::Tanh>(softplus);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input0, tanh);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input0});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto mish = std::make_shared<ngraph::opset4::Mish>(data);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mish}, ngraph::ParameterVector{data});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}