[Transformations] Add threshold for const comparison in Gelu fusion pass to fuse with fp16 precision (#17042)

This commit is contained in:
Vladimir Paramuzov 2023-04-24 14:37:31 +04:00 committed by GitHub
parent e8ae1e41ea
commit faba5fb71e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 108 additions and 3 deletions

View File

@ -49,7 +49,7 @@ ov::pass::GeluFusionWithErfOne::GeluFusionWithErfOne() {
}
bool valid_constant_values =
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2)) &&
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
op::util::has_constant_value<float>(mul_const_value, 0.5f);
@ -109,7 +109,7 @@ ov::pass::GeluFusionWithErfTwo::GeluFusionWithErfTwo() {
}
bool valid_constant_values =
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2)) &&
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
op::util::has_constant_value<float>(mul_const_value, 0.5f);
@ -169,7 +169,7 @@ ov::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
}
bool valid_constant_values =
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2)) &&
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
op::util::has_constant_value<float>(mul_const_value, 0.5f);

View File

@ -16,6 +16,7 @@
#include <queue>
#include <string>
#include <transformations/common_optimizations/gelu_fusion.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
@ -50,6 +51,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternOne) {
}
}
// f16 variant of the first Erf-based Gelu pattern:
//     (x * 0.5) * (erf(x / sqrt(2)) + 1)
// The whole sub-graph is expected to be replaced by a single Gelu node.
TEST_F(TransformationTestsF, GeluFusionPatternOneF16) {
    // Graph that the pass is run on.
    {
        const auto input = std::make_shared<opset7::Parameter>(element::f16, Shape{2, 2});

        // sqrt(2) is stored as f16 here, so the constant is only an approximation of M_SQRT2.
        const auto sqrt2 = opset7::Constant::create(element::f16, Shape{1}, {M_SQRT2});
        const auto one = opset7::Constant::create(element::f16, Shape{1}, {1.0});
        const auto half = opset7::Constant::create(element::f16, Shape{1}, {0.5});

        const auto scaled_input = std::make_shared<opset7::Divide>(input, sqrt2);
        const auto erf_node = std::make_shared<opset7::Erf>(scaled_input);
        const auto erf_plus_one = std::make_shared<opset7::Add>(erf_node, one);
        const auto half_input = std::make_shared<opset7::Multiply>(input, half);
        const auto output = std::make_shared<opset7::Multiply>(half_input, erf_plus_one);

        function = std::make_shared<Function>(NodeVector{output}, ParameterVector{input});
        manager.register_pass<ov::pass::GeluFusionWithErfOne>();
    }

    // Reference graph: the same input feeding a single Gelu.
    {
        const auto input = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
        const auto fused = std::make_shared<opset7::Gelu>(input);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{input});
    }
}
TEST_F(TransformationTestsF, GeluFusionPatternTwo) {
{
auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{2, 2});
@ -76,6 +103,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternTwo) {
}
}
// f16 variant of the second Erf-based Gelu pattern:
//     (x * (erf(x / sqrt(2)) + 1)) * 0.5
// The whole sub-graph is expected to be replaced by a single Gelu node.
TEST_F(TransformationTestsF, GeluFusionPatternTwoF16) {
    // Graph that the pass is run on.
    {
        const auto input = std::make_shared<opset7::Parameter>(element::f16, Shape{2, 2});

        // sqrt(2) is stored as f16 here, so the constant is only an approximation of M_SQRT2.
        const auto sqrt2 = opset7::Constant::create(element::f16, Shape{1}, {M_SQRT2});
        const auto one = opset7::Constant::create(element::f16, Shape{1}, {1.0});
        const auto half = opset7::Constant::create(element::f16, Shape{1}, {0.5});

        const auto scaled_input = std::make_shared<opset7::Divide>(input, sqrt2);
        const auto erf_node = std::make_shared<opset7::Erf>(scaled_input);
        const auto erf_plus_one = std::make_shared<opset7::Add>(erf_node, one);
        const auto input_times_sum = std::make_shared<opset7::Multiply>(input, erf_plus_one);
        const auto output = std::make_shared<opset7::Multiply>(input_times_sum, half);

        function = std::make_shared<Function>(NodeVector{output}, ParameterVector{input});
        manager.register_pass<ov::pass::GeluFusionWithErfTwo>();
    }

    // Reference graph: the same input feeding a single Gelu.
    {
        const auto input = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
        const auto fused = std::make_shared<opset7::Gelu>(input);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{input});
    }
}
TEST_F(TransformationTestsF, GeluFusionPatternThree) {
{
auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{2, 2});
@ -102,6 +155,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternThree) {
}
}
// f16 variant of the third Erf-based Gelu pattern:
//     x * ((erf(x / sqrt(2)) + 1) * 0.5)
// The whole sub-graph is expected to be replaced by a single Gelu node.
TEST_F(TransformationTestsF, GeluFusionPatternThreeF16) {
    // Graph that the pass is run on.
    {
        const auto input = std::make_shared<opset7::Parameter>(element::f16, Shape{2, 2});

        // sqrt(2) is stored as f16 here, so the constant is only an approximation of M_SQRT2.
        const auto sqrt2 = opset7::Constant::create(element::f16, Shape{1}, {M_SQRT2});
        const auto one = opset7::Constant::create(element::f16, Shape{1}, {1.0});
        const auto half = opset7::Constant::create(element::f16, Shape{1}, {0.5});

        const auto scaled_input = std::make_shared<opset7::Divide>(input, sqrt2);
        const auto erf_node = std::make_shared<opset7::Erf>(scaled_input);
        const auto erf_plus_one = std::make_shared<opset7::Add>(erf_node, one);
        const auto halved_sum = std::make_shared<opset7::Multiply>(erf_plus_one, half);
        const auto output = std::make_shared<opset7::Multiply>(input, halved_sum);

        function = std::make_shared<Function>(NodeVector{output}, ParameterVector{input});
        manager.register_pass<ov::pass::GeluFusionWithErfThree>();
    }

    // Reference graph: the same input feeding a single Gelu.
    {
        const auto input = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
        const auto fused = std::make_shared<opset7::Gelu>(input);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{input});
    }
}
TEST_F(TransformationTestsF, GeluFusionPatternFour) {
{
auto data = std::make_shared<opset9::Parameter>(element::f32, Shape{2, 2});
@ -128,6 +207,32 @@ TEST_F(TransformationTestsF, GeluFusionPatternFour) {
}
}
// f16 variant of the fourth Erf-based Gelu pattern:
//     x * (0.5 + 0.5 * erf(x * (1 / sqrt(2))))
// The whole sub-graph is expected to be replaced by a single Gelu node.
TEST_F(TransformationTestsF, GeluFusionPatternFourF16) {
    // Graph that the pass is run on.
    {
        const auto input = std::make_shared<opset9::Parameter>(element::f16, Shape{2, 2});

        // 1 / sqrt(2) is stored as f16 here, so the constant is only an approximation.
        const auto inv_sqrt2 = opset9::Constant::create(element::f16, Shape{1}, {1.0f / M_SQRT2});
        const auto half_bias = opset9::Constant::create(element::f16, Shape{1}, {0.5f});
        const auto half_scale = opset9::Constant::create(element::f16, Shape{1}, {0.5f});

        const auto scaled_input = std::make_shared<opset9::Multiply>(input, inv_sqrt2);
        const auto erf_node = std::make_shared<opset9::Erf>(scaled_input);
        const auto half_erf = std::make_shared<opset9::Multiply>(erf_node, half_scale);
        const auto shifted = std::make_shared<opset9::Add>(half_erf, half_bias);
        const auto output = std::make_shared<opset9::Multiply>(input, shifted);

        function = std::make_shared<Function>(NodeVector{output}, ParameterVector{input});
        manager.register_pass<ov::pass::GeluFusionWithErfFour>();
    }

    // Reference graph: the same input feeding a single Gelu.
    {
        const auto input = std::make_shared<opset1::Parameter>(element::f16, Shape{2, 2});
        const auto fused = std::make_shared<opset9::Gelu>(input);
        function_ref = std::make_shared<Function>(NodeVector{fused}, ParameterVector{input});
    }
}
TEST_F(TransformationTestsF, GeluFusionPatternIncorrectDivConstValue) {
{
auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{2, 2});