[Transformation] Add GeluFusion with Tanh (#11752)
This commit is contained in:
@@ -16,6 +16,7 @@ class TRANSFORMATIONS_API GeluFusion;
|
||||
class TRANSFORMATIONS_API GeluFusionWithErfOne;
|
||||
class TRANSFORMATIONS_API GeluFusionWithErfTwo;
|
||||
class TRANSFORMATIONS_API GeluFusionWithErfThree;
|
||||
class TRANSFORMATIONS_API GeluFusionWithTanh;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@@ -53,6 +54,17 @@ public:
|
||||
GeluFusionWithErfThree();
|
||||
};
|
||||
|
||||
/**
|
||||
 * @ingroup ie_transformation_common_api
|
||||
 * @brief GeluFusionWithTanh transformation replaces the tanh-based Gelu approximation sub-graph
|
||||
 * x * (0.5 * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715 * x^3]))) with a Gelu (Tanh) op.
|
||||
 */
|
||||
class ngraph::pass::GeluFusionWithTanh : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GeluFusionWithTanh", "0");
|
||||
// Builds the sub-graph pattern and registers the matcher together with the replacement callback.
GeluFusionWithTanh();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GeluFusion transformation replaces various sub-graphs with a Gelu op.
|
||||
@@ -64,5 +76,6 @@ public:
|
||||
add_matcher<ngraph::pass::GeluFusionWithErfOne>();
|
||||
add_matcher<ngraph::pass::GeluFusionWithErfTwo>();
|
||||
add_matcher<ngraph::pass::GeluFusionWithErfThree>();
|
||||
add_matcher<ngraph::pass::GeluFusionWithTanh>();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset9.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
@@ -192,3 +193,86 @@ ngraph::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
    MATCHER_SCOPE(GeluFusionWithTanh);
    // Fuses the tanh-based approximation of the Gaussian Error Linear Unit
    //     x * (0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))))
    // into a single Gelu op created with GeluApproximationMode::TANH.
    //
    // The pattern below is assembled from the innermost expression outwards.
    // Constants are only matched by type here; their values are validated in the callback.

    auto x = ngraph::pattern::any_input();

    // x^3
    auto exponent = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto cube = ngraph::pattern::wrap_type<ngraph::opset9::Power>({x, exponent});

    // 0.044715 * x^3
    auto cube_scale = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto scaled_cube = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({cube, cube_scale});

    // x + 0.044715 * x^3
    auto inner_sum = ngraph::pattern::wrap_type<ngraph::opset9::Add>({x, scaled_cube});

    // sqrt(2 / pi) * (x + 0.044715 * x^3)
    auto tanh_scale = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto tanh_arg = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({inner_sum, tanh_scale});

    auto tanh_node = ngraph::pattern::wrap_type<ngraph::opset9::Tanh>({tanh_arg});

    // 1 + tanh(...)
    auto one = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto shifted = ngraph::pattern::wrap_type<ngraph::opset9::Add>({tanh_node, one});

    // 0.5 * (1 + tanh(...))
    auto half = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto halved = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({shifted, half});

    // x * (0.5 * (1 + tanh(...))) -- the root of the pattern
    auto root = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({x, halved});

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto& matched = m.get_pattern_value_map();
        auto gelu_input = matched.at(x);

        // Resolves a pattern node to the graph node it was matched against.
        const auto matched_node = [&matched](const std::shared_ptr<ngraph::Node>& pattern_node) {
            return matched.at(pattern_node).get_node_shared_ptr();
        };
        // Same as above, downcast to Constant; yields nullptr if the cast fails.
        const auto matched_constant = [&matched_node](const std::shared_ptr<ngraph::Node>& pattern_node) {
            return std::dynamic_pointer_cast<ngraph::opset9::Constant>(matched_node(pattern_node));
        };

        const auto exponent_value = matched_constant(exponent);
        const auto cube_scale_value = matched_constant(cube_scale);
        const auto tanh_scale_value = matched_constant(tanh_scale);
        const auto half_value = matched_constant(half);
        const auto one_value = matched_constant(one);

        const bool all_constants_resolved =
            exponent_value && one_value && cube_scale_value && tanh_scale_value && half_value;
        if (!all_constants_resolved) {
            return false;
        }

        // Reject the match as soon as one constant differs from the Gelu formula.
        // The third argument given for the two irrational-looking coefficients is
        // presumably the accepted deviation -- confirm against has_constant_value.
        constexpr float pi = 3.141592653589793238462643383279502884f;
        if (!op::util::has_constant_value<float>(exponent_value, 3.0f) ||
            !op::util::has_constant_value<float>(cube_scale_value, 0.044715f, 0.001f) ||
            !op::util::has_constant_value<float>(tanh_scale_value, std::sqrt(2.0f / pi), 0.01f) ||
            !op::util::has_constant_value<float>(half_value, 0.5f) ||
            !op::util::has_constant_value<float>(one_value, 1.0f)) {
            return false;
        }

        const auto match_root = m.get_match_root();
        auto gelu = std::make_shared<ngraph::opset9::Gelu>(gelu_input, op::GeluApproximationMode::TANH);

        // The fused op takes over the name of the node it replaces.
        gelu->set_friendly_name(match_root->get_friendly_name());
        ngraph::copy_runtime_info(
            {
                matched_node(cube),
                matched_node(scaled_cube),
                matched_node(tanh_arg),
                matched_node(halved),
                matched_node(root),
                matched_node(tanh_node),
                matched_node(inner_sum),
                matched_node(shifted),
            },
            gelu);
        ngraph::replace_node(match_root, gelu);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(root, matcher_name);
    register_matcher(m, callback);
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <memory>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset9.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <queue>
|
||||
@@ -170,3 +171,386 @@ TEST_F(TransformationTestsF, GeluFusionPatternTooShortDivConstValue) {
|
||||
manager.register_pass<pass::GeluFusionWithErfTwo>();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_equal_const_values) {
    // Every constant of the tanh-based Gelu sub-graph carries its exact formula value.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }

    {
        // Reference graph: a single Gelu in TANH approximation mode.
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(x, op::GeluApproximationMode::TANH);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{x});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_params_no_conversion) {
    // Same topology as the Gelu sub-graph, but every coefficient is a Parameter
    // instead of a Constant. No function_ref is set here; presumably the fixture
    // then expects the graph to stay untouched -- confirm against TransformationTestsF.
    const auto f32_param = [](const Shape& shape) {
        return std::make_shared<opset9::Parameter>(element::f32, shape);
    };

    {
        const auto x = f32_param(Shape{2, 2});
        const auto exponent = f32_param(Shape{1});
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_param(Shape{1});
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        const auto tanh_scale = f32_param(Shape{1});
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        const auto one = f32_param(Shape{1});
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_param(Shape{1});
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result},
                                              ParameterVector{x, exponent, cube_scale, tanh_scale, one, half});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value) {
    // The Power exponent is written as 3.0f + 1.0e-8f; all other constants are exact.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^(3 + eps)
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f + 1.0e-8f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }

    {
        // Reference graph: a single Gelu in TANH approximation mode.
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(x, op::GeluApproximationMode::TANH);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{x});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_pow_value) {
    // The Power exponent is 2 instead of 3. No function_ref is set here; presumably
    // the fixture then expects the graph to stay untouched -- confirm against TransformationTestsF.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^2  (deliberately not a cube)
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(2.0f);
        const auto powered = std::make_shared<opset9::Power>(x, exponent);
        const auto power_scale = f32_scalar(0.044715f);
        const auto scaled_power = std::make_shared<opset9::Multiply>(powered, power_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_power);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_mul_0_value) {
    // The cubic-term coefficient is 0.04515 rather than 0.044715; all other constants are exact.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.04515 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.04515f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }

    {
        // Reference graph: a single Gelu in TANH approximation mode.
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(x, op::GeluApproximationMode::TANH);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{x});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_mul_0_value) {
    // The cubic-term coefficient is 1.4715 instead of 0.044715. No function_ref is set here;
    // presumably the fixture then expects the graph to stay untouched -- confirm against TransformationTestsF.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 1.4715 * x^3  (deliberately wrong coefficient)
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(1.4715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_mul_1_value) {
    // The tanh-argument scale is the literal 0.7980868 rather than a computed sqrt(2 / pi);
    // all other constants are exact.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        // x + 0.044715 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(0.7980868 * (...))
        const auto tanh_scale = f32_scalar(0.7980868f);
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }

    {
        // Reference graph: a single Gelu in TANH approximation mode.
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(x, op::GeluApproximationMode::TANH);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{x});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_mul_1_value) {
    // The tanh-argument scale is sqrt(10 / pi) instead of sqrt(2 / pi). No function_ref is set here;
    // presumably the fixture then expects the graph to stay untouched -- confirm against TransformationTestsF.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(10 / pi) * (...))  (deliberately wrong scale)
        const auto tanh_scale = f32_scalar(std::sqrt(10.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_add_1_value) {
    // The addend after tanh is written as 1.0f + 1.0e-8f; all other constants are exact.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * ((1 + eps) + tanh(...)))
        const auto one = f32_scalar(1.0f + 1.0e-8f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }

    {
        // Reference graph: a single Gelu in TANH approximation mode.
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(x, op::GeluApproximationMode::TANH);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{x});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_add_1_value) {
    // The addend after tanh is 2 instead of 1. No function_ref is set here; presumably
    // the fixture then expects the graph to stay untouched -- confirm against TransformationTestsF.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (0.5 * (2 + tanh(...)))  (deliberately wrong addend)
        const auto addend = f32_scalar(2.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, addend);
        const auto half = f32_scalar(0.5f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_mul_2_value) {
    // The outer factor is written as 0.5f + 1.0e-8f; all other constants are exact.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * ((0.5 + eps) * (1 + tanh(...)))
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto half = f32_scalar(0.5f + 1.0e-8f);
        const auto halved = std::make_shared<opset9::Multiply>(shifted, half);
        const auto result = std::make_shared<opset9::Multiply>(x, halved);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }

    {
        // Reference graph: a single Gelu in TANH approximation mode.
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(x, op::GeluApproximationMode::TANH);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{x});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_mul_2_value) {
    // The outer factor is 5 instead of 0.5. No function_ref is set here; presumably
    // the fixture then expects the graph to stay untouched -- confirm against TransformationTestsF.
    const auto f32_scalar = [](float value) {
        return std::make_shared<opset9::Constant>(element::f32, Shape{1}, std::vector<float>{value});
    };

    {
        constexpr float pi = 3.141592653589793238462643383279502884f;

        // x + 0.044715 * x^3
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
        const auto exponent = f32_scalar(3.0f);
        const auto cube = std::make_shared<opset9::Power>(x, exponent);
        const auto cube_scale = f32_scalar(0.044715f);
        const auto scaled_cube = std::make_shared<opset9::Multiply>(cube, cube_scale);
        const auto inner_sum = std::make_shared<opset9::Add>(x, scaled_cube);

        // tanh(sqrt(2 / pi) * (...))
        const auto tanh_scale = f32_scalar(std::sqrt(2.0f / pi));
        const auto tanh_arg = std::make_shared<opset9::Multiply>(inner_sum, tanh_scale);
        const auto tanh_node = std::make_shared<opset9::Tanh>(tanh_arg);

        // x * (5 * (1 + tanh(...)))  (deliberately wrong outer factor)
        const auto one = f32_scalar(1.0f);
        const auto shifted = std::make_shared<opset9::Add>(tanh_node, one);
        const auto outer_factor = f32_scalar(5.0f);
        const auto scaled = std::make_shared<opset9::Multiply>(shifted, outer_factor);
        const auto result = std::make_shared<opset9::Multiply>(x, scaled);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});
        manager.register_pass<pass::GeluFusionWithTanh>();
    }
}
|
||||
|
||||
Reference in New Issue
Block a user