[GPU] Additional pattern for GELU fusion (#12928)

This commit is contained in:
Vladimir Paramuzov 2022-09-23 09:06:55 +04:00 committed by GitHub
parent af7384f8c9
commit 5500aed209
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 3 deletions

View File

@ -21,6 +21,14 @@ class TRANSFORMATIONS_API GeluFusionWithTanh;
} // namespace pass
} // namespace ngraph
namespace ov {
namespace pass {
class TRANSFORMATIONS_API GeluFusionWithErfFour;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief GeluFusion transformation replaces a sub-graph
@ -35,7 +43,7 @@ public:
/**
* @ingroup ie_transformation_common_api
* @brief GeluFusion transformation replaces a sub-graph
* 0.5 * (x * (1 + erf(x / sqrt(2))) with a Gelu op.
* 0.5 * (x * (1 + erf(x / sqrt(2)))) with a Gelu op.
*/
class ngraph::pass::GeluFusionWithErfTwo : public ngraph::pass::MatcherPass {
public:
@ -46,7 +54,7 @@ public:
/**
* @ingroup ie_transformation_common_api
* @brief GeluFusion transformation replaces a sub-graph
* x * (0.5 * (1 + erf(x / sqrt(2))) with a Gelu op.
* x * (0.5 * (1 + erf(x / sqrt(2)))) with a Gelu op.
*/
class ngraph::pass::GeluFusionWithErfThree : public ngraph::pass::MatcherPass {
public:
@ -57,7 +65,18 @@ public:
/**
* @ingroup ie_transformation_common_api
* @brief GeluFusion transformation replaces a sub-graph
* x * (0.5 * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3])) with a Gelu (Tanh) op.
* x * (0.5 + 0.5 * erf(x * (1 / sqrt(2)))) with a Gelu op.
*/
class ov::pass::GeluFusionWithErfFour : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GeluFusionWithErfFour", "0");
GeluFusionWithErfFour();
};
/**
* @ingroup ie_transformation_common_api
* @brief GeluFusion transformation replaces a sub-graph
* x * (0.5 * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3]))) with a Gelu (Tanh) op.
*/
class ngraph::pass::GeluFusionWithTanh : public ngraph::pass::MatcherPass {
public:
@ -76,6 +95,7 @@ public:
add_matcher<ngraph::pass::GeluFusionWithErfOne>();
add_matcher<ngraph::pass::GeluFusionWithErfTwo>();
add_matcher<ngraph::pass::GeluFusionWithErfThree>();
add_matcher<ov::pass::GeluFusionWithErfFour>();
add_matcher<ngraph::pass::GeluFusionWithTanh>();
}
};

View File

@ -13,6 +13,7 @@
#include <ngraph/opsets/opset9.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset9.hpp>
#include "itt.hpp"
#include "transformations/utils/utils.hpp"
@ -194,6 +195,58 @@ ngraph::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
register_matcher(m, callback);
}
ov::pass::GeluFusionWithErfFour::GeluFusionWithErfFour() {
MATCHER_SCOPE(GeluFusionWithErfFour);
using namespace ov;
using namespace ov::opset9;
using namespace ov::pass::pattern;
auto input = any_input();
auto mul1_constant = wrap_type<Constant>();
auto mul1 = wrap_type<Multiply>({input, mul1_constant});
auto erf = wrap_type<Erf>({mul1});
auto mul2_constant = wrap_type<Constant>();
auto mul2 = wrap_type<Multiply>({erf, mul2_constant});
auto add_constant = wrap_type<Constant>();
auto add = wrap_type<Add>({add_constant, mul2});
// x * (0.5 + 0.5 * erf(x * (1 / sqrt(2))))
auto mul3 = wrap_type<Multiply>({input, add});
matcher_pass_callback callback = [=](Matcher& m) {
NodeRegistry rg;
auto pattern_to_output = m.get_pattern_map();
auto x_output = pattern_to_output.at(input);
auto mul1_const_value = std::dynamic_pointer_cast<Constant>(pattern_to_output.at(mul1_constant));
auto add_const_value = std::dynamic_pointer_cast<Constant>(pattern_to_output.at(add_constant));
auto mul2_const_value = std::dynamic_pointer_cast<Constant>(pattern_to_output.at(mul2_constant));
if (!mul1_const_value || !add_const_value || !mul2_const_value) {
return false;
}
bool valid_constant_values =
ngraph::op::util::has_constant_value<float>(mul1_const_value, 1.0f / M_SQRT2, 0.001f) &&
ngraph::op::util::has_constant_value<float>(add_const_value, 0.5f) &&
ngraph::op::util::has_constant_value<float>(mul2_const_value, 0.5f);
if (!valid_constant_values) {
return false;
}
auto gelu = rg.make<Gelu>(x_output);
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
copy_runtime_info(m.get_matched_nodes(), rg.get());
replace_node(m.get_match_root(), gelu);
return true;
};
auto m = std::make_shared<Matcher>(mul3, matcher_name);
register_matcher(m, callback);
}
ngraph::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
MATCHER_SCOPE(GeluFusionWithTanh);
// Replaces a sub-graph with a Gelu (Tanh) op

View File

@ -120,6 +120,38 @@ TEST_F(TransformationTestsF, GeluFusionPatternThree) {
}
}
TEST_F(TransformationTestsF, GeluFusionPatternFour) {
{
auto data =
std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
auto mul1_const =
opset9::Constant::create(element::f32, Shape{1}, {1.0f / M_SQRT2});
auto add_const =
opset9::Constant::create(element::f32, Shape{1}, {0.5f});
auto mul2_const =
opset9::Constant::create(element::f32, Shape{1}, {0.5f});
auto mul1 = std::make_shared<opset9::Multiply>(data, mul1_const);
auto erf = std::make_shared<opset9::Erf>(mul1);
auto mul2 = std::make_shared<opset9::Multiply>(erf, mul2_const);
auto add = std::make_shared<opset9::Add>(mul2, add_const);
auto mul3 = std::make_shared<opset9::Multiply>(data, add);
function = std::make_shared<Function>(NodeVector{mul3}, ParameterVector{data});
manager.register_pass<ov::pass::GeluFusionWithErfFour>();
}
{
auto data =
std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
auto gelu = std::make_shared<opset9::Gelu>(data);
function_ref =
std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GeluFusionPatternIncorrectDivConstValue) {
{
auto data =