[GPU] Additional pattern for GELU fusion (#12928)
This commit is contained in:
parent
af7384f8c9
commit
5500aed209
@ -21,6 +21,14 @@ class TRANSFORMATIONS_API GeluFusionWithTanh;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API GeluFusionWithErfFour;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GeluFusion transformation replaces a sub-graph
|
||||
@ -35,7 +43,7 @@ public:
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GeluFusion transformation replaces a sub-graph
|
||||
* 0.5 * (x * (1 + erf(x / sqrt(2))) with a Gelu op.
|
||||
* 0.5 * (x * (1 + erf(x / sqrt(2)))) with a Gelu op.
|
||||
*/
|
||||
class ngraph::pass::GeluFusionWithErfTwo : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
@ -46,7 +54,7 @@ public:
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GeluFusion transformation replaces a sub-graph
|
||||
* x * (0.5 * (1 + erf(x / sqrt(2))) with a Gelu op.
|
||||
* x * (0.5 * (1 + erf(x / sqrt(2)))) with a Gelu op.
|
||||
*/
|
||||
class ngraph::pass::GeluFusionWithErfThree : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
@ -57,7 +65,18 @@ public:
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GeluFusion transformation replaces a sub-graph
|
||||
* x * (0.5 * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715^3])) with a Gelu (Tanh) op.
|
||||
* x * (0.5 + 0.5 * erf(x * (1 / sqrt(2)))) with a Gelu op.
|
||||
*/
|
||||
class ov::pass::GeluFusionWithErfFour : public ov::pass::MatcherPass {
public:
    // Run-time type information for this pass: type name and version string "0".
    OPENVINO_RTTI("GeluFusionWithErfFour", "0");

    // Builds the sub-graph pattern and registers the matcher callback
    // (implemented out of line in the corresponding source file).
    GeluFusionWithErfFour();
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GeluFusion transformation replaces a sub-graph
|
||||
* x * (0.5 * (1 + tanh([sqrt(2 / pi)] * [x + 0.044715 * x^3]))) with a Gelu (Tanh) op.
|
||||
*/
|
||||
class ngraph::pass::GeluFusionWithTanh : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
@ -76,6 +95,7 @@ public:
|
||||
add_matcher<ngraph::pass::GeluFusionWithErfOne>();
|
||||
add_matcher<ngraph::pass::GeluFusionWithErfTwo>();
|
||||
add_matcher<ngraph::pass::GeluFusionWithErfThree>();
|
||||
add_matcher<ov::pass::GeluFusionWithErfFour>();
|
||||
add_matcher<ngraph::pass::GeluFusionWithTanh>();
|
||||
}
|
||||
};
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <ngraph/opsets/opset9.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
@ -194,6 +195,58 @@ ngraph::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
// Registers a matcher for the sub-graph
//     x * (0.5 + 0.5 * erf(x * (1 / sqrt(2))))
// and, when the three constants carry the expected values, replaces the
// matched root Multiply with a single Gelu op fed by `x`.
ov::pass::GeluFusionWithErfFour::GeluFusionWithErfFour() {
    MATCHER_SCOPE(GeluFusionWithErfFour);
    using namespace ov;
    using namespace ov::opset9;
    using namespace ov::pass::pattern;

    auto input = any_input();
    auto mul1_constant = wrap_type<Constant>();
    auto mul1 = wrap_type<Multiply>({input, mul1_constant});
    auto erf = wrap_type<Erf>({mul1});
    auto mul2_constant = wrap_type<Constant>();
    auto mul2 = wrap_type<Multiply>({erf, mul2_constant});
    auto add_constant = wrap_type<Constant>();
    auto add = wrap_type<Add>({add_constant, mul2});

    // x * (0.5 + 0.5 * erf(x * (1 / sqrt(2))))
    auto mul3 = wrap_type<Multiply>({input, add});

    matcher_pass_callback callback = [=](Matcher& m) {
        NodeRegistry rg;

        // Use the value map rather than the node map: it keeps the exact
        // output port that matched `input`. With the node map the port is
        // dropped, so an `x` produced by a non-default output of a
        // multi-output node would be reconnected to the wrong output.
        const auto& pattern_to_output = m.get_pattern_value_map();
        const auto x_output = pattern_to_output.at(input);

        const auto mul1_const_value =
            std::dynamic_pointer_cast<Constant>(pattern_to_output.at(mul1_constant).get_node_shared_ptr());
        const auto add_const_value =
            std::dynamic_pointer_cast<Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        const auto mul2_const_value =
            std::dynamic_pointer_cast<Constant>(pattern_to_output.at(mul2_constant).get_node_shared_ptr());

        if (!mul1_const_value || !add_const_value || !mul2_const_value) {
            return false;
        }

        // 1 / sqrt(2) is evaluated in double (M_SQRT2 is a double) and then
        // narrowed explicitly; it is compared with a 1e-3 tolerance, while
        // both 0.5 constants use the helper's default tolerance.
        const float inv_sqrt2 = static_cast<float>(1.0 / M_SQRT2);
        const bool valid_constant_values =
            ngraph::op::util::has_constant_value<float>(mul1_const_value, inv_sqrt2, 0.001f) &&
            ngraph::op::util::has_constant_value<float>(add_const_value, 0.5f) &&
            ngraph::op::util::has_constant_value<float>(mul2_const_value, 0.5f);

        if (!valid_constant_values) {
            return false;
        }

        auto gelu = rg.make<Gelu>(x_output);

        // The fused op takes over the name and runtime info of the sub-graph it replaces.
        gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
        copy_runtime_info(m.get_matched_nodes(), rg.get());
        replace_node(m.get_match_root(), gelu);
        return true;
    };

    auto m = std::make_shared<Matcher>(mul3, matcher_name);
    register_matcher(m, callback);
}
|
||||
|
||||
ngraph::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
|
||||
MATCHER_SCOPE(GeluFusionWithTanh);
|
||||
// Replaces a sub-graph with a Gelu (Tanh) op
|
||||
|
@ -120,6 +120,38 @@ TEST_F(TransformationTestsF, GeluFusionPatternThree) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that x * (0.5 + 0.5 * erf(x * (1 / sqrt(2)))) collapses into one Gelu op.
TEST_F(TransformationTestsF, GeluFusionPatternFour) {
    {
        // Graph under test: the decomposed GELU expression.
        const auto x = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});

        const auto inv_sqrt2 = opset9::Constant::create(element::f32, Shape{1}, {1.0f / M_SQRT2});
        const auto half_bias = opset9::Constant::create(element::f32, Shape{1}, {0.5f});
        const auto half_scale = opset9::Constant::create(element::f32, Shape{1}, {0.5f});

        const auto scaled_x = std::make_shared<opset9::Multiply>(x, inv_sqrt2);
        const auto erf_node = std::make_shared<opset9::Erf>(scaled_x);
        const auto half_erf = std::make_shared<opset9::Multiply>(erf_node, half_scale);
        const auto shifted = std::make_shared<opset9::Add>(half_erf, half_bias);
        const auto result = std::make_shared<opset9::Multiply>(x, shifted);

        function = std::make_shared<Function>(NodeVector{result}, ParameterVector{x});

        manager.register_pass<ov::pass::GeluFusionWithErfFour>();
    }

    {
        // Expected graph: the same parameter feeding a single Gelu.
        const auto x = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(x);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{x});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionPatternIncorrectDivConstValue) {
|
||||
{
|
||||
auto data =
|
||||
|
Loading…
Reference in New Issue
Block a user