[CPU] Added GELU-7 support (#4800)

This commit is contained in:
Alexandra Sidorova 2021-04-07 16:01:47 +03:00 committed by GitHub
parent 54b6c77202
commit 325b5564f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 178 additions and 164 deletions

View File

@ -781,7 +781,13 @@ MKLDNNEltwiseNode::initializers = {
alpha = 0.0f;
beta = 0.0f;
opType = Gelu;
algorithm = mkldnn::algorithm::eltwise_gelu;
std::string approximationMode = activationLayer->GetParamAsString("approximation_mode", "erf");
if (approximationMode == "erf")
algorithm = mkldnn::algorithm::eltwise_gelu_erf;
else if (approximationMode == "tanh")
algorithm = mkldnn::algorithm::eltwise_gelu_tanh;
else
IE_THROW() << "Gelu layer with name " << activationLayer->name << " doesn't support approximation mode " << approximationMode;
}},
{"elu", [](GenericLayer* activationLayer, EltwiseOpType& opType, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
alpha = activationLayer->GetParamAsFloat("alpha", 1.0f);
@ -1743,7 +1749,8 @@ void MKLDNNEltwiseNode::appendPostOps(mkldnn::post_ops& ops) {
case mkldnn::algorithm::eltwise_soft_relu:
case mkldnn::algorithm::eltwise_logistic:
case mkldnn::algorithm::eltwise_exp:
case mkldnn::algorithm::eltwise_gelu:
case mkldnn::algorithm::eltwise_gelu_erf:
case mkldnn::algorithm::eltwise_gelu_tanh:
case mkldnn::algorithm::eltwise_clip:
case mkldnn::algorithm::eltwise_swish:
case mkldnn::algorithm::eltwise_hswish:

View File

@ -53,7 +53,9 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
{HSigmoid, {}},
{RoundHalfToEven, {}},
{RoundHalfAwayFromZero, {}},
{Erf, {}}
{Erf, {}},
{GeluErf, {}},
{GeluTanh, {}}
};
const std::map<ActivationTypes, std::vector<std::vector<float>>> activationParamTypes = {

View File

@ -76,14 +76,15 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
{Sigmoid, {{}}},
{Tanh, {{}}},
{Relu, {{}}},
{Gelu, {{}}},
{Exp, {{}}},
{Clamp, {{-2.0f, 2.0f}}},
{Elu, {{0.1f}}},
{Swish, {{0.1f}}},
{HSwish, {{}}},
{Mish, {{}}},
{PReLu, {{-0.01f}}}
{PReLu, {{-0.01f}}},
{GeluErf, {{}}},
{GeluTanh, {{}}}
};
std::vector<CPUSpecificParams> cpuParams_4D = {

View File

@ -36,7 +36,6 @@ static std::map<ngraph::helpers::ActivationTypes, std::string> activationNames =
{ngraph::helpers::ActivationTypes::Log, "Log"},
{ngraph::helpers::ActivationTypes::Sign, "Sign"},
{ngraph::helpers::ActivationTypes::Abs, "Abs"},
{ngraph::helpers::ActivationTypes::Gelu, "Gelu"},
{ngraph::helpers::ActivationTypes::Clamp, "Clamp"},
{ngraph::helpers::ActivationTypes::Negative, "Negative"},
{ngraph::helpers::ActivationTypes::Acos, "Acos"},
@ -70,7 +69,9 @@ static std::map<ngraph::helpers::ActivationTypes, std::string> activationNames =
{ngraph::helpers::ActivationTypes::Swish, "Swish"},
{ngraph::helpers::ActivationTypes::HSigmoid, "HSigmoid"},
{ngraph::helpers::ActivationTypes::RoundHalfToEven, "RoundHalfToEven"},
{ngraph::helpers::ActivationTypes::RoundHalfAwayFromZero, "RoundHalfAwayFromZero"}
{ngraph::helpers::ActivationTypes::RoundHalfAwayFromZero, "RoundHalfAwayFromZero"},
{ngraph::helpers::ActivationTypes::GeluErf, "GeluErf"},
{ngraph::helpers::ActivationTypes::GeluTanh, "GeluTanh"},
};
typedef std::tuple<

View File

@ -120,7 +120,9 @@ enum ActivationTypes {
Swish,
HSigmoid,
RoundHalfToEven,
RoundHalfAwayFromZero
RoundHalfAwayFromZero,
GeluErf,
GeluTanh
};
enum EltwiseTypes {

View File

@ -39,7 +39,7 @@ std::shared_ptr<ngraph::Node> makeActivation(const ngraph::Output<Node> &in,
case ngraph::helpers::ActivationTypes::Abs:
return std::make_shared<ngraph::op::Abs>(in);
case ngraph::helpers::ActivationTypes::Gelu:
return std::make_shared<ngraph::op::Gelu>(in);
return std::make_shared<ngraph::op::v0::Gelu>(in);
case ngraph::helpers::ActivationTypes::Clamp:
return std::make_shared<ngraph::op::Clamp>(in, constantsValue[0], constantsValue[1]);
case ngraph::helpers::ActivationTypes::Negative:
@ -107,6 +107,10 @@ std::shared_ptr<ngraph::Node> makeActivation(const ngraph::Output<Node> &in,
return std::make_shared<ngraph::op::v5::Round>(in, ngraph::op::v5::Round::RoundMode::HALF_TO_EVEN);
case ngraph::helpers::ActivationTypes::RoundHalfAwayFromZero:
return std::make_shared<ngraph::op::v5::Round>(in, ngraph::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
case ngraph::helpers::ActivationTypes::GeluErf:
return std::make_shared<ngraph::op::v7::Gelu>(in, ngraph::op::GeluApproximationMode::ERF);
case ngraph::helpers::ActivationTypes::GeluTanh:
return std::make_shared<ngraph::op::v7::Gelu>(in, ngraph::op::GeluApproximationMode::TANH);
default:
throw std::runtime_error("Can't create layer for this activation type");
}

View File

@ -16,155 +16,155 @@ from ngraph.impl import Function
from ngraph.helpers import function_from_cnn
from ngraph.helpers import function_to_cnn
from ngraph.opset6 import absolute
from ngraph.opset6 import absolute as abs
from ngraph.opset6 import acos
from ngraph.opset6 import acosh
from ngraph.opset6 import add
from ngraph.opset6 import asin
from ngraph.opset6 import asinh
from ngraph.opset6 import assign
from ngraph.opset6 import atan
from ngraph.opset6 import atanh
from ngraph.opset6 import avg_pool
from ngraph.opset6 import batch_norm_inference
from ngraph.opset6 import batch_to_space
from ngraph.opset6 import binary_convolution
from ngraph.opset6 import broadcast
from ngraph.opset6 import bucketize
from ngraph.opset6 import ceiling
from ngraph.opset6 import ceiling as ceil
from ngraph.opset6 import clamp
from ngraph.opset6 import concat
from ngraph.opset6 import constant
from ngraph.opset6 import convert
from ngraph.opset6 import convert_like
from ngraph.opset6 import convolution
from ngraph.opset6 import convolution_backprop_data
from ngraph.opset6 import cos
from ngraph.opset6 import cosh
from ngraph.opset6 import ctc_greedy_decoder
from ngraph.opset6 import ctc_greedy_decoder_seq_len
from ngraph.opset6 import ctc_loss
from ngraph.opset6 import cum_sum
from ngraph.opset6 import cum_sum as cumsum
from ngraph.opset6 import deformable_convolution
from ngraph.opset6 import deformable_psroi_pooling
from ngraph.opset6 import depth_to_space
from ngraph.opset6 import detection_output
from ngraph.opset6 import divide
from ngraph.opset6 import elu
from ngraph.opset6 import embedding_bag_offsets_sum
from ngraph.opset6 import embedding_bag_packed_sum
from ngraph.opset6 import embedding_segments_sum
from ngraph.opset6 import extract_image_patches
from ngraph.opset6 import equal
from ngraph.opset6 import erf
from ngraph.opset6 import exp
from ngraph.opset6 import fake_quantize
from ngraph.opset6 import floor
from ngraph.opset6 import floor_mod
from ngraph.opset6 import gather
from ngraph.opset6 import gather_elements
from ngraph.opset6 import gather_nd
from ngraph.opset6 import gather_tree
from ngraph.opset6 import gelu
from ngraph.opset6 import greater
from ngraph.opset6 import greater_equal
from ngraph.opset6 import grn
from ngraph.opset6 import group_convolution
from ngraph.opset6 import group_convolution_backprop_data
from ngraph.opset6 import gru_cell
from ngraph.opset6 import gru_sequence
from ngraph.opset6 import hard_sigmoid
from ngraph.opset6 import hsigmoid
from ngraph.opset6 import hswish
from ngraph.opset6 import interpolate
from ngraph.opset6 import less
from ngraph.opset6 import less_equal
from ngraph.opset6 import log
from ngraph.opset6 import logical_and
from ngraph.opset6 import logical_not
from ngraph.opset6 import logical_or
from ngraph.opset6 import logical_xor
from ngraph.opset6 import log_softmax
from ngraph.opset6 import loop
from ngraph.opset6 import lrn
from ngraph.opset6 import lstm_cell
from ngraph.opset6 import lstm_sequence
from ngraph.opset6 import matmul
from ngraph.opset6 import max_pool
from ngraph.opset6 import maximum
from ngraph.opset6 import minimum
from ngraph.opset6 import mish
from ngraph.opset6 import mod
from ngraph.opset6 import multiply
from ngraph.opset6 import mvn
from ngraph.opset6 import negative
from ngraph.opset6 import non_max_suppression
from ngraph.opset6 import non_zero
from ngraph.opset6 import normalize_l2
from ngraph.opset6 import not_equal
from ngraph.opset6 import one_hot
from ngraph.opset6 import pad
from ngraph.opset6 import parameter
from ngraph.opset6 import power
from ngraph.opset6 import prelu
from ngraph.opset6 import prior_box
from ngraph.opset6 import prior_box_clustered
from ngraph.opset6 import psroi_pooling
from ngraph.opset6 import proposal
from ngraph.opset6 import range
from ngraph.opset6 import read_value
from ngraph.opset6 import reduce_l1
from ngraph.opset6 import reduce_l2
from ngraph.opset6 import reduce_logical_and
from ngraph.opset6 import reduce_logical_or
from ngraph.opset6 import reduce_max
from ngraph.opset6 import reduce_mean
from ngraph.opset6 import reduce_min
from ngraph.opset6 import reduce_prod
from ngraph.opset6 import reduce_sum
from ngraph.opset6 import region_yolo
from ngraph.opset6 import reorg_yolo
from ngraph.opset6 import relu
from ngraph.opset6 import reshape
from ngraph.opset6 import result
from ngraph.opset6 import reverse_sequence
from ngraph.opset6 import rnn_cell
from ngraph.opset6 import rnn_sequence
from ngraph.opset6 import roi_align
from ngraph.opset6 import roi_pooling
from ngraph.opset6 import round
from ngraph.opset6 import scatter_elements_update
from ngraph.opset6 import scatter_update
from ngraph.opset6 import select
from ngraph.opset6 import selu
from ngraph.opset6 import shape_of
from ngraph.opset6 import shuffle_channels
from ngraph.opset6 import sigmoid
from ngraph.opset6 import sign
from ngraph.opset6 import sin
from ngraph.opset6 import sinh
from ngraph.opset6 import softmax
from ngraph.opset6 import softplus
from ngraph.opset6 import space_to_batch
from ngraph.opset6 import space_to_depth
from ngraph.opset6 import split
from ngraph.opset6 import sqrt
from ngraph.opset6 import squared_difference
from ngraph.opset6 import squeeze
from ngraph.opset6 import strided_slice
from ngraph.opset6 import subtract
from ngraph.opset6 import swish
from ngraph.opset6 import tan
from ngraph.opset6 import tanh
from ngraph.opset6 import tensor_iterator
from ngraph.opset6 import tile
from ngraph.opset6 import topk
from ngraph.opset6 import transpose
from ngraph.opset6 import unsqueeze
from ngraph.opset6 import variadic_split
from ngraph.opset7 import absolute
from ngraph.opset7 import absolute as abs
from ngraph.opset7 import acos
from ngraph.opset7 import acosh
from ngraph.opset7 import add
from ngraph.opset7 import asin
from ngraph.opset7 import asinh
from ngraph.opset7 import assign
from ngraph.opset7 import atan
from ngraph.opset7 import atanh
from ngraph.opset7 import avg_pool
from ngraph.opset7 import batch_norm_inference
from ngraph.opset7 import batch_to_space
from ngraph.opset7 import binary_convolution
from ngraph.opset7 import broadcast
from ngraph.opset7 import bucketize
from ngraph.opset7 import ceiling
from ngraph.opset7 import ceiling as ceil
from ngraph.opset7 import clamp
from ngraph.opset7 import concat
from ngraph.opset7 import constant
from ngraph.opset7 import convert
from ngraph.opset7 import convert_like
from ngraph.opset7 import convolution
from ngraph.opset7 import convolution_backprop_data
from ngraph.opset7 import cos
from ngraph.opset7 import cosh
from ngraph.opset7 import ctc_greedy_decoder
from ngraph.opset7 import ctc_greedy_decoder_seq_len
from ngraph.opset7 import ctc_loss
from ngraph.opset7 import cum_sum
from ngraph.opset7 import cum_sum as cumsum
from ngraph.opset7 import deformable_convolution
from ngraph.opset7 import deformable_psroi_pooling
from ngraph.opset7 import depth_to_space
from ngraph.opset7 import detection_output
from ngraph.opset7 import divide
from ngraph.opset7 import elu
from ngraph.opset7 import embedding_bag_offsets_sum
from ngraph.opset7 import embedding_bag_packed_sum
from ngraph.opset7 import embedding_segments_sum
from ngraph.opset7 import extract_image_patches
from ngraph.opset7 import equal
from ngraph.opset7 import erf
from ngraph.opset7 import exp
from ngraph.opset7 import fake_quantize
from ngraph.opset7 import floor
from ngraph.opset7 import floor_mod
from ngraph.opset7 import gather
from ngraph.opset7 import gather_elements
from ngraph.opset7 import gather_nd
from ngraph.opset7 import gather_tree
from ngraph.opset7 import gelu
from ngraph.opset7 import greater
from ngraph.opset7 import greater_equal
from ngraph.opset7 import grn
from ngraph.opset7 import group_convolution
from ngraph.opset7 import group_convolution_backprop_data
from ngraph.opset7 import gru_cell
from ngraph.opset7 import gru_sequence
from ngraph.opset7 import hard_sigmoid
from ngraph.opset7 import hsigmoid
from ngraph.opset7 import hswish
from ngraph.opset7 import interpolate
from ngraph.opset7 import less
from ngraph.opset7 import less_equal
from ngraph.opset7 import log
from ngraph.opset7 import logical_and
from ngraph.opset7 import logical_not
from ngraph.opset7 import logical_or
from ngraph.opset7 import logical_xor
from ngraph.opset7 import log_softmax
from ngraph.opset7 import loop
from ngraph.opset7 import lrn
from ngraph.opset7 import lstm_cell
from ngraph.opset7 import lstm_sequence
from ngraph.opset7 import matmul
from ngraph.opset7 import max_pool
from ngraph.opset7 import maximum
from ngraph.opset7 import minimum
from ngraph.opset7 import mish
from ngraph.opset7 import mod
from ngraph.opset7 import multiply
from ngraph.opset7 import mvn
from ngraph.opset7 import negative
from ngraph.opset7 import non_max_suppression
from ngraph.opset7 import non_zero
from ngraph.opset7 import normalize_l2
from ngraph.opset7 import not_equal
from ngraph.opset7 import one_hot
from ngraph.opset7 import pad
from ngraph.opset7 import parameter
from ngraph.opset7 import power
from ngraph.opset7 import prelu
from ngraph.opset7 import prior_box
from ngraph.opset7 import prior_box_clustered
from ngraph.opset7 import psroi_pooling
from ngraph.opset7 import proposal
from ngraph.opset7 import range
from ngraph.opset7 import read_value
from ngraph.opset7 import reduce_l1
from ngraph.opset7 import reduce_l2
from ngraph.opset7 import reduce_logical_and
from ngraph.opset7 import reduce_logical_or
from ngraph.opset7 import reduce_max
from ngraph.opset7 import reduce_mean
from ngraph.opset7 import reduce_min
from ngraph.opset7 import reduce_prod
from ngraph.opset7 import reduce_sum
from ngraph.opset7 import region_yolo
from ngraph.opset7 import reorg_yolo
from ngraph.opset7 import relu
from ngraph.opset7 import reshape
from ngraph.opset7 import result
from ngraph.opset7 import reverse_sequence
from ngraph.opset7 import rnn_cell
from ngraph.opset7 import rnn_sequence
from ngraph.opset7 import roi_align
from ngraph.opset7 import roi_pooling
from ngraph.opset7 import round
from ngraph.opset7 import scatter_elements_update
from ngraph.opset7 import scatter_update
from ngraph.opset7 import select
from ngraph.opset7 import selu
from ngraph.opset7 import shape_of
from ngraph.opset7 import shuffle_channels
from ngraph.opset7 import sigmoid
from ngraph.opset7 import sign
from ngraph.opset7 import sin
from ngraph.opset7 import sinh
from ngraph.opset7 import softmax
from ngraph.opset7 import softplus
from ngraph.opset7 import space_to_batch
from ngraph.opset7 import space_to_depth
from ngraph.opset7 import split
from ngraph.opset7 import sqrt
from ngraph.opset7 import squared_difference
from ngraph.opset7 import squeeze
from ngraph.opset7 import strided_slice
from ngraph.opset7 import subtract
from ngraph.opset7 import swish
from ngraph.opset7 import tan
from ngraph.opset7 import tanh
from ngraph.opset7 import tensor_iterator
from ngraph.opset7 import tile
from ngraph.opset7 import topk
from ngraph.opset7 import transpose
from ngraph.opset7 import unsqueeze
from ngraph.opset7 import variadic_split
# Extend Node class to support binary operators

View File

@ -173,4 +173,3 @@ xfail_issue_49750 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v4
xfail_issue_49752 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::Pad")
xfail_issue_49753 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::StridedSlice")
xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE")
xfail_issue_49913 = xfail_test(reason="CPU supports Gelu with tanh mode only")

View File

@ -8,7 +8,7 @@ import ngraph as ng
from ngraph.impl import Shape, Type
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
from tests import xfail_issue_44970, xfail_issue_49913
from tests import xfail_issue_44970
@pytest.mark.parametrize(
@ -176,7 +176,6 @@ def test_hsigmoid():
assert node.get_output_element_type(0) == Type.f32
@xfail_issue_49913
def test_gelu_operator_with_parameters():
runtime = get_runtime()
@ -190,10 +189,9 @@ def test_gelu_operator_with_parameters():
result = computation(data_value)
expected = np.array([[-1.6391277e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]], dtype=np.float32)
assert np.allclose(result, expected)
assert np.allclose(result, expected, 1e-6, 1e-6)
@xfail_issue_49913
def test_gelu_operator_with_array():
runtime = get_runtime()
@ -204,7 +202,7 @@ def test_gelu_operator_with_array():
result = computation()
expected = np.array([[-1.6391277e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]], dtype=np.float32)
assert np.allclose(result, expected)
assert np.allclose(result, expected, 1e-6, 1e-6)
def test_gelu_tanh_operator_with_parameters():