[Transformation] Add GeluFusion with Tanh (#11752)

This commit is contained in:
Katarzyna Mitrus
2022-06-03 08:36:12 +02:00
committed by GitHub
parent 1db4446e2a
commit 47155b43d0
3 changed files with 481 additions and 0 deletions

View File

@@ -16,6 +16,7 @@ class TRANSFORMATIONS_API GeluFusion;
class TRANSFORMATIONS_API GeluFusionWithErfOne;
class TRANSFORMATIONS_API GeluFusionWithErfTwo;
class TRANSFORMATIONS_API GeluFusionWithErfThree;
class TRANSFORMATIONS_API GeluFusionWithTanh;
} // namespace pass
} // namespace ngraph
@@ -53,6 +54,17 @@ public:
GeluFusionWithErfThree();
};
/**
* @ingroup ie_transformation_common_api
* @brief GeluFusion transformation replaces a sub-graph
* x * (0.5 * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3])) with a Gelu (Tanh) op.
*/
class ngraph::pass::GeluFusionWithTanh : public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("GeluFusionWithTanh", "0");
GeluFusionWithTanh();
};
/**
* @ingroup ie_transformation_common_api
* @brief GeluFusion transformation replaces various sub-graphs with a Gelu op.
@@ -64,5 +76,6 @@ public:
add_matcher<ngraph::pass::GeluFusionWithErfOne>();
add_matcher<ngraph::pass::GeluFusionWithErfTwo>();
add_matcher<ngraph::pass::GeluFusionWithErfThree>();
add_matcher<ngraph::pass::GeluFusionWithTanh>();
}
};

View File

@@ -10,6 +10,7 @@
#include <memory>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset9.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
@@ -192,3 +193,86 @@ ngraph::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
register_matcher(m, callback);
}
ngraph::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
MATCHER_SCOPE(GeluFusionWithTanh);
// Replaces a sub-graph with a Gelu (Tanh) op
// Gaussian Error Linear Unit, TanH based approximation:
// x * (0.5 * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3]))
auto input = ngraph::pattern::any_input();
auto pow_constant = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
auto pow = ngraph::pattern::wrap_type<ngraph::opset9::Power>({input, pow_constant});
auto mul_0_constant = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
auto mul_0 = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({pow, mul_0_constant});
auto add_0 = ngraph::pattern::wrap_type<ngraph::opset9::Add>({input, mul_0});
auto mul_1_constant = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
auto mul_1 = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({add_0, mul_1_constant});
auto tanh = ngraph::pattern::wrap_type<ngraph::opset9::Tanh>({mul_1});
auto add_1_constant = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
auto add_1 = ngraph::pattern::wrap_type<ngraph::opset9::Add>({tanh, add_1_constant});
auto mul_2_constant = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
auto mul_2 = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({add_1, mul_2_constant});
auto mul_3 = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({input, mul_2});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto pow_constant_value = std::dynamic_pointer_cast<ngraph::opset9::Constant>(
pattern_to_output.at(pow_constant).get_node_shared_ptr());
auto mul_0_constant_value = std::dynamic_pointer_cast<ngraph::opset9::Constant>(
pattern_to_output.at(mul_0_constant).get_node_shared_ptr());
auto mul_1_constant_value = std::dynamic_pointer_cast<ngraph::opset9::Constant>(
pattern_to_output.at(mul_1_constant).get_node_shared_ptr());
auto mul_2_constant_value = std::dynamic_pointer_cast<ngraph::opset9::Constant>(
pattern_to_output.at(mul_2_constant).get_node_shared_ptr());
auto add_1_constant_value = std::dynamic_pointer_cast<ngraph::opset9::Constant>(
pattern_to_output.at(add_1_constant).get_node_shared_ptr());
if (!pow_constant_value || !add_1_constant_value || !mul_0_constant_value || !mul_1_constant_value ||
!mul_2_constant_value) {
return false;
}
constexpr float pi = 3.141592653589793238462643383279502884f;
bool valid_constant_values =
op::util::has_constant_value<float>(pow_constant_value, 3.0f) &&
op::util::has_constant_value<float>(mul_0_constant_value, 0.044715f, 0.001f) &&
op::util::has_constant_value<float>(mul_1_constant_value, std::sqrt(2.0f / pi), 0.01f) &&
op::util::has_constant_value<float>(mul_2_constant_value, 0.5f) &&
op::util::has_constant_value<float>(add_1_constant_value, 1.0f);
if (!valid_constant_values) {
return false;
}
auto gelu = std::make_shared<ngraph::opset9::Gelu>(x_output, op::GeluApproximationMode::TANH);
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(
{
pattern_to_output.at(pow).get_node_shared_ptr(),
pattern_to_output.at(mul_0).get_node_shared_ptr(),
pattern_to_output.at(mul_1).get_node_shared_ptr(),
pattern_to_output.at(mul_2).get_node_shared_ptr(),
pattern_to_output.at(mul_3).get_node_shared_ptr(),
pattern_to_output.at(tanh).get_node_shared_ptr(),
pattern_to_output.at(add_0).get_node_shared_ptr(),
pattern_to_output.at(add_1).get_node_shared_ptr(),
},
gelu);
ngraph::replace_node(m.get_match_root(), gelu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_3, matcher_name);
register_matcher(m, callback);
}

View File

@@ -10,6 +10,7 @@
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset9.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/manager.hpp>
#include <queue>
@@ -170,3 +171,386 @@ TEST_F(TransformationTestsF, GeluFusionPatternTooShortDivConstValue) {
manager.register_pass<pass::GeluFusionWithErfTwo>();
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_equal_const_values) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
{
auto data =
std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto gelu = std::make_shared<opset9::Gelu>(data, op::GeluApproximationMode::TANH);
function_ref =
std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_params_no_conversion) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_param = std::make_shared<opset9::Parameter>(element::f32, Shape{1});
auto pow = std::make_shared<opset9::Power>(input, pow_param);
auto mul_0_param = std::make_shared<opset9::Parameter>(element::f32, Shape{1});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_param);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
auto mul_1_param = std::make_shared<opset9::Parameter>(element::f32, Shape{1});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_param);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_param = std::make_shared<opset9::Parameter>(element::f32, Shape{1});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_param);
auto mul_2_param = std::make_shared<opset9::Parameter>(element::f32, Shape{1});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_param);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3},
ParameterVector{input, pow_param, mul_0_param, mul_1_param, add_1_param, mul_2_param});
manager.register_pass<pass::GeluFusionWithTanh>();
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f + 1.0e-8f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
{
auto data =
std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto gelu = std::make_shared<opset9::Gelu>(data, op::GeluApproximationMode::TANH);
function_ref =
std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_pow_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{2.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_mul_0_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.04515f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
{
auto data =
std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto gelu = std::make_shared<opset9::Gelu>(data, op::GeluApproximationMode::TANH);
function_ref =
std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_mul_0_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.4715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_mul_1_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.7980868f});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
{
auto data =
std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto gelu = std::make_shared<opset9::Gelu>(data, op::GeluApproximationMode::TANH);
function_ref =
std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_mul_1_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(10.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_add_1_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f + 1.0e-8f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
{
auto data =
std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto gelu = std::make_shared<opset9::Gelu>(data, op::GeluApproximationMode::TANH);
function_ref =
std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_add_1_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{2.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_mul_2_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f + 1.0e-8f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
{
auto data =
std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto gelu = std::make_shared<opset9::Gelu>(data, op::GeluApproximationMode::TANH);
function_ref =
std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_mul_2_value) {
{
auto input = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto pow_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f});
auto pow = std::make_shared<opset9::Power>(input, pow_constant);
auto mul_0_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
auto mul_0 = std::make_shared<opset9::Multiply>(pow, mul_0_constant);
auto add_0 = std::make_shared<opset9::Add>(input, mul_0);
constexpr float pi = 3.141592653589793238462643383279502884f;
auto mul_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{std::sqrt(2.0f / pi)});
auto mul_1 = std::make_shared<opset9::Multiply>(add_0, mul_1_constant);
auto tanh = std::make_shared<opset9::Tanh>(mul_1);
auto add_1_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
auto add_1 = std::make_shared<opset9::Add>(tanh, add_1_constant);
auto mul_2_constant = std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{5.0f});
auto mul_2 = std::make_shared<opset9::Multiply>(add_1, mul_2_constant);
auto mul_3 = std::make_shared<opset9::Multiply>(input, mul_2);
function = std::make_shared<Function>(NodeVector{mul_3}, ParameterVector{input});
manager.register_pass<pass::GeluFusionWithTanh>();
}
}