Add Mish with SoftPlus transformation (#1815)
* Add Mish with SoftPlus transformation * Refactoring accrding code review * Add softplus to mish pass registration * Add checks customer count for SoftPlus and Tanh ops
This commit is contained in:
parent
8017ac03ea
commit
0c1b2f836b
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API SoftPlusToMishFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SoftPlusToMishFusion transformation replaces group of
|
||||
* operations: x * tanh(softplus(x)) to Mish op.
|
||||
*/
|
||||
class ngraph::pass::SoftPlusToMishFusion: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
SoftPlusToMishFusion();
|
||||
};
|
@ -17,6 +17,7 @@
|
||||
#include "transformations/itt.hpp"
|
||||
#include "transformations/mish_fusion.hpp"
|
||||
#include "transformations/softplus_fusion.hpp"
|
||||
#include "transformations/softplus_to_mish_fusion.hpp"
|
||||
#include "transformations/swish_fusion.hpp"
|
||||
#include "transformations/hswish_fusion.hpp"
|
||||
#include "transformations/normalize_l2_fusion.hpp"
|
||||
@ -45,6 +46,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
|
||||
manager.register_pass<ngraph::pass::MishFusion>();
|
||||
manager.register_pass<ngraph::pass::SoftPlusFusion>();
|
||||
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
|
||||
manager.register_pass<ngraph::pass::SwishFusion>();
|
||||
manager.register_pass<ngraph::pass::HSwishFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
|
||||
|
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/softplus_to_mish_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
ngraph::pass::SoftPlusToMishFusion::SoftPlusToMishFusion() {
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto softplus = ngraph::pattern::wrap_type<ngraph::opset4::SoftPlus>({input}, pattern::consumers_count(1));
|
||||
auto tanh = ngraph::pattern::wrap_type<ngraph::opset4::Tanh>({softplus}, pattern::consumers_count(1));
|
||||
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, tanh);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto & pattern_to_output = m.get_pattern_value_map();
|
||||
auto exp_input = pattern_to_output.at(input);
|
||||
|
||||
auto mish = std::make_shared<ngraph::opset4::Mish>(exp_input);
|
||||
|
||||
mish->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(tanh).get_node_shared_ptr(),
|
||||
pattern_to_output.at(softplus).get_node_shared_ptr()}, mish);
|
||||
ngraph::replace_node(m.get_match_root(), mish);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SoftPlusToMishFusion");
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -11,6 +11,7 @@
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/mish_fusion.hpp>
|
||||
#include <transformations/softplus_to_mish_fusion.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
@ -48,3 +49,32 @@ TEST(TransformationTests, MishFusing) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
|
||||
TEST(TransformationTests, MishWithSoftPlusFusing) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input0 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f64, ngraph::Shape{3, 1, 2});
|
||||
auto softplus = std::make_shared<ngraph::opset4::SoftPlus>(input0);
|
||||
auto tanh = std::make_shared<ngraph::opset4::Tanh>(softplus);
|
||||
auto mul = std::make_shared<ngraph::opset4::Multiply>(input0, tanh);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input0});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto mish = std::make_shared<ngraph::opset4::Mish>(data);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mish}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user