[Transformations] Add threshold for const comparison in Gelu fusion pass to fuse with fp16 precision (#17042)
This commit is contained in:
parent
e8ae1e41ea
commit
faba5fb71e
@ -49,7 +49,7 @@ ov::pass::GeluFusionWithErfOne::GeluFusionWithErfOne() {
|
||||
}
|
||||
|
||||
bool valid_constant_values =
|
||||
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2)) &&
|
||||
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
|
||||
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
|
||||
op::util::has_constant_value<float>(mul_const_value, 0.5f);
|
||||
|
||||
@ -109,7 +109,7 @@ ov::pass::GeluFusionWithErfTwo::GeluFusionWithErfTwo() {
|
||||
}
|
||||
|
||||
bool valid_constant_values =
|
||||
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2)) &&
|
||||
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
|
||||
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
|
||||
op::util::has_constant_value<float>(mul_const_value, 0.5f);
|
||||
|
||||
@ -169,7 +169,7 @@ ov::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
|
||||
}
|
||||
|
||||
bool valid_constant_values =
|
||||
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2)) &&
|
||||
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
|
||||
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
|
||||
op::util::has_constant_value<float>(mul_const_value, 0.5f);
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <transformations/common_optimizations/gelu_fusion.hpp>
|
||||
#include <transformations/convert_precision.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
@ -50,6 +51,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternOne) {
|
||||
}
|
||||
}
|
||||
|
||||
// f16 variant of the "pattern one" Gelu fusion test.
// Builds the subgraph (data * 0.5) * (erf(data / sqrt(2)) + 1) with every
// tensor and constant in f16, registers GeluFusionWithErfOne, and provides a
// reference function consisting of a single Gelu node on an f16 parameter.
// NOTE(review): presumably the TransformationTestsF fixture runs `manager` on
// `function` and compares the result with `function_ref` — confirm in the fixture.
TEST_F(TransformationTestsF, GeluFusionPatternOneF16) {
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f16, Shape{2, 2});
|
||||
|
||||
// M_SQRT2 is stored as an f16 constant here, so the value held by the
// constant is only an approximation of sqrt(2).
auto div_const = opset7::Constant::create(element::f16, Shape{1}, {M_SQRT2});
|
||||
auto add_const = opset7::Constant::create(element::f16, Shape{1}, {1.0});
|
||||
auto mul_const = opset7::Constant::create(element::f16, Shape{1}, {0.5});
|
||||
|
||||
// erf(data / sqrt(2)) + 1
auto div = std::make_shared<opset7::Divide>(data, div_const);
|
||||
auto erf = std::make_shared<opset7::Erf>(div);
|
||||
auto add = std::make_shared<opset7::Add>(erf, add_const);
|
||||
// Pattern one: the 0.5 factor is applied to `data` first, then multiplied by the (erf + 1) branch.
auto mul_first = std::make_shared<opset7::Multiply>(data, mul_const);
|
||||
auto mul = std::make_shared<opset7::Multiply>(mul_first, add);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{mul}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::GeluFusionWithErfOne>();
|
||||
}
|
||||
|
||||
{
|
||||
// Reference graph: one Gelu node applied directly to an f16 parameter.
auto data = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
|
||||
auto gelu = std::make_shared<opset7::Gelu>(data);
|
||||
function_ref = std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionPatternTwo) {
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{2, 2});
|
||||
@ -76,6 +103,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternTwo) {
|
||||
}
|
||||
}
|
||||
|
||||
// f16 variant of the "pattern two" Gelu fusion test.
// Builds the subgraph (data * (erf(data / sqrt(2)) + 1)) * 0.5 with every
// tensor and constant in f16, registers GeluFusionWithErfTwo, and provides a
// reference function consisting of a single Gelu node on an f16 parameter.
// NOTE(review): presumably the TransformationTestsF fixture runs `manager` on
// `function` and compares the result with `function_ref` — confirm in the fixture.
TEST_F(TransformationTestsF, GeluFusionPatternTwoF16) {
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f16, Shape{2, 2});
|
||||
|
||||
// M_SQRT2 is stored as an f16 constant here, so the value held by the
// constant is only an approximation of sqrt(2).
auto div_const = opset7::Constant::create(element::f16, Shape{1}, {M_SQRT2});
|
||||
auto add_const = opset7::Constant::create(element::f16, Shape{1}, {1.0});
|
||||
auto mul_const = opset7::Constant::create(element::f16, Shape{1}, {0.5});
|
||||
|
||||
// erf(data / sqrt(2)) + 1
auto div = std::make_shared<opset7::Divide>(data, div_const);
|
||||
auto erf = std::make_shared<opset7::Erf>(div);
|
||||
auto add = std::make_shared<opset7::Add>(erf, add_const);
|
||||
// Pattern two: `data` is multiplied by the (erf + 1) branch first, the 0.5 factor is applied last.
auto mul_first = std::make_shared<opset7::Multiply>(data, add);
|
||||
auto mul = std::make_shared<opset7::Multiply>(mul_first, mul_const);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{mul}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::GeluFusionWithErfTwo>();
|
||||
}
|
||||
|
||||
{
|
||||
// Reference graph: one Gelu node applied directly to an f16 parameter.
auto data = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
|
||||
auto gelu = std::make_shared<opset7::Gelu>(data);
|
||||
function_ref = std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionPatternThree) {
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{2, 2});
|
||||
@ -102,6 +155,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternThree) {
|
||||
}
|
||||
}
|
||||
|
||||
// f16 variant of the "pattern three" Gelu fusion test.
// Builds the subgraph data * ((erf(data / sqrt(2)) + 1) * 0.5) with every
// tensor and constant in f16, registers GeluFusionWithErfThree, and provides a
// reference function consisting of a single Gelu node on an f16 parameter.
// NOTE(review): presumably the TransformationTestsF fixture runs `manager` on
// `function` and compares the result with `function_ref` — confirm in the fixture.
TEST_F(TransformationTestsF, GeluFusionPatternThreeF16) {
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f16, Shape{2, 2});
|
||||
|
||||
// M_SQRT2 is stored as an f16 constant here, so the value held by the
// constant is only an approximation of sqrt(2).
auto div_const = opset7::Constant::create(element::f16, Shape{1}, {M_SQRT2});
|
||||
auto add_const = opset7::Constant::create(element::f16, Shape{1}, {1.0});
|
||||
auto mul_const = opset7::Constant::create(element::f16, Shape{1}, {0.5});
|
||||
|
||||
// erf(data / sqrt(2)) + 1
auto div = std::make_shared<opset7::Divide>(data, div_const);
|
||||
auto erf = std::make_shared<opset7::Erf>(div);
|
||||
auto add = std::make_shared<opset7::Add>(erf, add_const);
|
||||
// Pattern three: the (erf + 1) branch is scaled by 0.5 first, then multiplied by `data`.
auto mul_first = std::make_shared<opset7::Multiply>(add, mul_const);
|
||||
auto mul = std::make_shared<opset7::Multiply>(data, mul_first);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{mul}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::GeluFusionWithErfThree>();
|
||||
}
|
||||
|
||||
{
|
||||
// Reference graph: one Gelu node applied directly to an f16 parameter.
auto data = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
|
||||
auto gelu = std::make_shared<opset7::Gelu>(data);
|
||||
function_ref = std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionPatternFour) {
|
||||
{
|
||||
auto data = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
|
||||
@ -128,6 +207,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternFour) {
|
||||
}
|
||||
}
|
||||
|
||||
// f16 variant of the "pattern four" Gelu fusion test.
// Builds the subgraph data * (0.5 + 0.5 * erf(data * (1 / sqrt(2)))) with every
// tensor and constant in f16, registers GeluFusionWithErfFour, and provides a
// reference function consisting of a single Gelu node on an f16 parameter.
// NOTE(review): presumably the TransformationTestsF fixture runs `manager` on
// `function` and compares the result with `function_ref` — confirm in the fixture.
TEST_F(TransformationTestsF, GeluFusionPatternFourF16) {
|
||||
{
|
||||
auto data = std::make_shared<opset9::Parameter>(element::f16, Shape{2, 2});
|
||||
|
||||
// 1 / sqrt(2) is stored as an f16 constant here, so the value held by the
// constant is only an approximation of the exact reciprocal.
auto mul1_const = opset9::Constant::create(element::f16, Shape{1}, {1.0f / M_SQRT2});
|
||||
auto add_const = opset9::Constant::create(element::f16, Shape{1}, {0.5f});
|
||||
auto mul2_const = opset9::Constant::create(element::f16, Shape{1}, {0.5f});
|
||||
|
||||
// erf(data * (1 / sqrt(2)))
auto mul1 = std::make_shared<opset9::Multiply>(data, mul1_const);
|
||||
auto erf = std::make_shared<opset9::Erf>(mul1);
|
||||
// 0.5 * erf(...) + 0.5
auto mul2 = std::make_shared<opset9::Multiply>(erf, mul2_const);
|
||||
auto add = std::make_shared<opset9::Add>(mul2, add_const);
|
||||
auto mul3 = std::make_shared<opset9::Multiply>(data, add);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{mul3}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::GeluFusionWithErfFour>();
|
||||
}
|
||||
|
||||
{
|
||||
// Reference graph: one Gelu node applied directly to an f16 parameter.
auto data = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
|
||||
auto gelu = std::make_shared<opset9::Gelu>(data);
|
||||
function_ref = std::make_shared<Function>(NodeVector{gelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GeluFusionPatternIncorrectDivConstValue) {
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{2, 2});
|
||||
|
Loading…
Reference in New Issue
Block a user