[CPU] Added GELU-7 support (#4800)
This commit is contained in:
parent
54b6c77202
commit
325b5564f3
@ -781,7 +781,13 @@ MKLDNNEltwiseNode::initializers = {
|
||||
alpha = 0.0f;
|
||||
beta = 0.0f;
|
||||
opType = Gelu;
|
||||
algorithm = mkldnn::algorithm::eltwise_gelu;
|
||||
std::string approximationMode = activationLayer->GetParamAsString("approximation_mode", "erf");
|
||||
if (approximationMode == "erf")
|
||||
algorithm = mkldnn::algorithm::eltwise_gelu_erf;
|
||||
else if (approximationMode == "tanh")
|
||||
algorithm = mkldnn::algorithm::eltwise_gelu_tanh;
|
||||
else
|
||||
IE_THROW() << "Gelu layer with name " << activationLayer->name << " doesn't support approximation mode " << approximationMode;
|
||||
}},
|
||||
{"elu", [](GenericLayer* activationLayer, EltwiseOpType& opType, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
|
||||
alpha = activationLayer->GetParamAsFloat("alpha", 1.0f);
|
||||
@ -1743,7 +1749,8 @@ void MKLDNNEltwiseNode::appendPostOps(mkldnn::post_ops& ops) {
|
||||
case mkldnn::algorithm::eltwise_soft_relu:
|
||||
case mkldnn::algorithm::eltwise_logistic:
|
||||
case mkldnn::algorithm::eltwise_exp:
|
||||
case mkldnn::algorithm::eltwise_gelu:
|
||||
case mkldnn::algorithm::eltwise_gelu_erf:
|
||||
case mkldnn::algorithm::eltwise_gelu_tanh:
|
||||
case mkldnn::algorithm::eltwise_clip:
|
||||
case mkldnn::algorithm::eltwise_swish:
|
||||
case mkldnn::algorithm::eltwise_hswish:
|
||||
|
@ -53,7 +53,9 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
|
||||
{HSigmoid, {}},
|
||||
{RoundHalfToEven, {}},
|
||||
{RoundHalfAwayFromZero, {}},
|
||||
{Erf, {}}
|
||||
{Erf, {}},
|
||||
{GeluErf, {}},
|
||||
{GeluTanh, {}}
|
||||
};
|
||||
|
||||
const std::map<ActivationTypes, std::vector<std::vector<float>>> activationParamTypes = {
|
||||
|
@ -76,14 +76,15 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
|
||||
{Sigmoid, {{}}},
|
||||
{Tanh, {{}}},
|
||||
{Relu, {{}}},
|
||||
{Gelu, {{}}},
|
||||
{Exp, {{}}},
|
||||
{Clamp, {{-2.0f, 2.0f}}},
|
||||
{Elu, {{0.1f}}},
|
||||
{Swish, {{0.1f}}},
|
||||
{HSwish, {{}}},
|
||||
{Mish, {{}}},
|
||||
{PReLu, {{-0.01f}}}
|
||||
{PReLu, {{-0.01f}}},
|
||||
{GeluErf, {{}}},
|
||||
{GeluTanh, {{}}}
|
||||
};
|
||||
|
||||
std::vector<CPUSpecificParams> cpuParams_4D = {
|
||||
|
@ -36,7 +36,6 @@ static std::map<ngraph::helpers::ActivationTypes, std::string> activationNames =
|
||||
{ngraph::helpers::ActivationTypes::Log, "Log"},
|
||||
{ngraph::helpers::ActivationTypes::Sign, "Sign"},
|
||||
{ngraph::helpers::ActivationTypes::Abs, "Abs"},
|
||||
{ngraph::helpers::ActivationTypes::Gelu, "Gelu"},
|
||||
{ngraph::helpers::ActivationTypes::Clamp, "Clamp"},
|
||||
{ngraph::helpers::ActivationTypes::Negative, "Negative"},
|
||||
{ngraph::helpers::ActivationTypes::Acos, "Acos"},
|
||||
@ -70,7 +69,9 @@ static std::map<ngraph::helpers::ActivationTypes, std::string> activationNames =
|
||||
{ngraph::helpers::ActivationTypes::Swish, "Swish"},
|
||||
{ngraph::helpers::ActivationTypes::HSigmoid, "HSigmoid"},
|
||||
{ngraph::helpers::ActivationTypes::RoundHalfToEven, "RoundHalfToEven"},
|
||||
{ngraph::helpers::ActivationTypes::RoundHalfAwayFromZero, "RoundHalfAwayFromZero"}
|
||||
{ngraph::helpers::ActivationTypes::RoundHalfAwayFromZero, "RoundHalfAwayFromZero"},
|
||||
{ngraph::helpers::ActivationTypes::GeluErf, "GeluErf"},
|
||||
{ngraph::helpers::ActivationTypes::GeluTanh, "GeluTanh"},
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
|
@ -120,7 +120,9 @@ enum ActivationTypes {
|
||||
Swish,
|
||||
HSigmoid,
|
||||
RoundHalfToEven,
|
||||
RoundHalfAwayFromZero
|
||||
RoundHalfAwayFromZero,
|
||||
GeluErf,
|
||||
GeluTanh
|
||||
};
|
||||
|
||||
enum EltwiseTypes {
|
||||
|
@ -39,7 +39,7 @@ std::shared_ptr<ngraph::Node> makeActivation(const ngraph::Output<Node> &in,
|
||||
case ngraph::helpers::ActivationTypes::Abs:
|
||||
return std::make_shared<ngraph::op::Abs>(in);
|
||||
case ngraph::helpers::ActivationTypes::Gelu:
|
||||
return std::make_shared<ngraph::op::Gelu>(in);
|
||||
return std::make_shared<ngraph::op::v0::Gelu>(in);
|
||||
case ngraph::helpers::ActivationTypes::Clamp:
|
||||
return std::make_shared<ngraph::op::Clamp>(in, constantsValue[0], constantsValue[1]);
|
||||
case ngraph::helpers::ActivationTypes::Negative:
|
||||
@ -107,6 +107,10 @@ std::shared_ptr<ngraph::Node> makeActivation(const ngraph::Output<Node> &in,
|
||||
return std::make_shared<ngraph::op::v5::Round>(in, ngraph::op::v5::Round::RoundMode::HALF_TO_EVEN);
|
||||
case ngraph::helpers::ActivationTypes::RoundHalfAwayFromZero:
|
||||
return std::make_shared<ngraph::op::v5::Round>(in, ngraph::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO);
|
||||
case ngraph::helpers::ActivationTypes::GeluErf:
|
||||
return std::make_shared<ngraph::op::v7::Gelu>(in, ngraph::op::GeluApproximationMode::ERF);
|
||||
case ngraph::helpers::ActivationTypes::GeluTanh:
|
||||
return std::make_shared<ngraph::op::v7::Gelu>(in, ngraph::op::GeluApproximationMode::TANH);
|
||||
default:
|
||||
throw std::runtime_error("Can't create layer for this activation type");
|
||||
}
|
||||
|
@ -16,155 +16,155 @@ from ngraph.impl import Function
|
||||
from ngraph.helpers import function_from_cnn
|
||||
from ngraph.helpers import function_to_cnn
|
||||
|
||||
from ngraph.opset6 import absolute
|
||||
from ngraph.opset6 import absolute as abs
|
||||
from ngraph.opset6 import acos
|
||||
from ngraph.opset6 import acosh
|
||||
from ngraph.opset6 import add
|
||||
from ngraph.opset6 import asin
|
||||
from ngraph.opset6 import asinh
|
||||
from ngraph.opset6 import assign
|
||||
from ngraph.opset6 import atan
|
||||
from ngraph.opset6 import atanh
|
||||
from ngraph.opset6 import avg_pool
|
||||
from ngraph.opset6 import batch_norm_inference
|
||||
from ngraph.opset6 import batch_to_space
|
||||
from ngraph.opset6 import binary_convolution
|
||||
from ngraph.opset6 import broadcast
|
||||
from ngraph.opset6 import bucketize
|
||||
from ngraph.opset6 import ceiling
|
||||
from ngraph.opset6 import ceiling as ceil
|
||||
from ngraph.opset6 import clamp
|
||||
from ngraph.opset6 import concat
|
||||
from ngraph.opset6 import constant
|
||||
from ngraph.opset6 import convert
|
||||
from ngraph.opset6 import convert_like
|
||||
from ngraph.opset6 import convolution
|
||||
from ngraph.opset6 import convolution_backprop_data
|
||||
from ngraph.opset6 import cos
|
||||
from ngraph.opset6 import cosh
|
||||
from ngraph.opset6 import ctc_greedy_decoder
|
||||
from ngraph.opset6 import ctc_greedy_decoder_seq_len
|
||||
from ngraph.opset6 import ctc_loss
|
||||
from ngraph.opset6 import cum_sum
|
||||
from ngraph.opset6 import cum_sum as cumsum
|
||||
from ngraph.opset6 import deformable_convolution
|
||||
from ngraph.opset6 import deformable_psroi_pooling
|
||||
from ngraph.opset6 import depth_to_space
|
||||
from ngraph.opset6 import detection_output
|
||||
from ngraph.opset6 import divide
|
||||
from ngraph.opset6 import elu
|
||||
from ngraph.opset6 import embedding_bag_offsets_sum
|
||||
from ngraph.opset6 import embedding_bag_packed_sum
|
||||
from ngraph.opset6 import embedding_segments_sum
|
||||
from ngraph.opset6 import extract_image_patches
|
||||
from ngraph.opset6 import equal
|
||||
from ngraph.opset6 import erf
|
||||
from ngraph.opset6 import exp
|
||||
from ngraph.opset6 import fake_quantize
|
||||
from ngraph.opset6 import floor
|
||||
from ngraph.opset6 import floor_mod
|
||||
from ngraph.opset6 import gather
|
||||
from ngraph.opset6 import gather_elements
|
||||
from ngraph.opset6 import gather_nd
|
||||
from ngraph.opset6 import gather_tree
|
||||
from ngraph.opset6 import gelu
|
||||
from ngraph.opset6 import greater
|
||||
from ngraph.opset6 import greater_equal
|
||||
from ngraph.opset6 import grn
|
||||
from ngraph.opset6 import group_convolution
|
||||
from ngraph.opset6 import group_convolution_backprop_data
|
||||
from ngraph.opset6 import gru_cell
|
||||
from ngraph.opset6 import gru_sequence
|
||||
from ngraph.opset6 import hard_sigmoid
|
||||
from ngraph.opset6 import hsigmoid
|
||||
from ngraph.opset6 import hswish
|
||||
from ngraph.opset6 import interpolate
|
||||
from ngraph.opset6 import less
|
||||
from ngraph.opset6 import less_equal
|
||||
from ngraph.opset6 import log
|
||||
from ngraph.opset6 import logical_and
|
||||
from ngraph.opset6 import logical_not
|
||||
from ngraph.opset6 import logical_or
|
||||
from ngraph.opset6 import logical_xor
|
||||
from ngraph.opset6 import log_softmax
|
||||
from ngraph.opset6 import loop
|
||||
from ngraph.opset6 import lrn
|
||||
from ngraph.opset6 import lstm_cell
|
||||
from ngraph.opset6 import lstm_sequence
|
||||
from ngraph.opset6 import matmul
|
||||
from ngraph.opset6 import max_pool
|
||||
from ngraph.opset6 import maximum
|
||||
from ngraph.opset6 import minimum
|
||||
from ngraph.opset6 import mish
|
||||
from ngraph.opset6 import mod
|
||||
from ngraph.opset6 import multiply
|
||||
from ngraph.opset6 import mvn
|
||||
from ngraph.opset6 import negative
|
||||
from ngraph.opset6 import non_max_suppression
|
||||
from ngraph.opset6 import non_zero
|
||||
from ngraph.opset6 import normalize_l2
|
||||
from ngraph.opset6 import not_equal
|
||||
from ngraph.opset6 import one_hot
|
||||
from ngraph.opset6 import pad
|
||||
from ngraph.opset6 import parameter
|
||||
from ngraph.opset6 import power
|
||||
from ngraph.opset6 import prelu
|
||||
from ngraph.opset6 import prior_box
|
||||
from ngraph.opset6 import prior_box_clustered
|
||||
from ngraph.opset6 import psroi_pooling
|
||||
from ngraph.opset6 import proposal
|
||||
from ngraph.opset6 import range
|
||||
from ngraph.opset6 import read_value
|
||||
from ngraph.opset6 import reduce_l1
|
||||
from ngraph.opset6 import reduce_l2
|
||||
from ngraph.opset6 import reduce_logical_and
|
||||
from ngraph.opset6 import reduce_logical_or
|
||||
from ngraph.opset6 import reduce_max
|
||||
from ngraph.opset6 import reduce_mean
|
||||
from ngraph.opset6 import reduce_min
|
||||
from ngraph.opset6 import reduce_prod
|
||||
from ngraph.opset6 import reduce_sum
|
||||
from ngraph.opset6 import region_yolo
|
||||
from ngraph.opset6 import reorg_yolo
|
||||
from ngraph.opset6 import relu
|
||||
from ngraph.opset6 import reshape
|
||||
from ngraph.opset6 import result
|
||||
from ngraph.opset6 import reverse_sequence
|
||||
from ngraph.opset6 import rnn_cell
|
||||
from ngraph.opset6 import rnn_sequence
|
||||
from ngraph.opset6 import roi_align
|
||||
from ngraph.opset6 import roi_pooling
|
||||
from ngraph.opset6 import round
|
||||
from ngraph.opset6 import scatter_elements_update
|
||||
from ngraph.opset6 import scatter_update
|
||||
from ngraph.opset6 import select
|
||||
from ngraph.opset6 import selu
|
||||
from ngraph.opset6 import shape_of
|
||||
from ngraph.opset6 import shuffle_channels
|
||||
from ngraph.opset6 import sigmoid
|
||||
from ngraph.opset6 import sign
|
||||
from ngraph.opset6 import sin
|
||||
from ngraph.opset6 import sinh
|
||||
from ngraph.opset6 import softmax
|
||||
from ngraph.opset6 import softplus
|
||||
from ngraph.opset6 import space_to_batch
|
||||
from ngraph.opset6 import space_to_depth
|
||||
from ngraph.opset6 import split
|
||||
from ngraph.opset6 import sqrt
|
||||
from ngraph.opset6 import squared_difference
|
||||
from ngraph.opset6 import squeeze
|
||||
from ngraph.opset6 import strided_slice
|
||||
from ngraph.opset6 import subtract
|
||||
from ngraph.opset6 import swish
|
||||
from ngraph.opset6 import tan
|
||||
from ngraph.opset6 import tanh
|
||||
from ngraph.opset6 import tensor_iterator
|
||||
from ngraph.opset6 import tile
|
||||
from ngraph.opset6 import topk
|
||||
from ngraph.opset6 import transpose
|
||||
from ngraph.opset6 import unsqueeze
|
||||
from ngraph.opset6 import variadic_split
|
||||
from ngraph.opset7 import absolute
|
||||
from ngraph.opset7 import absolute as abs
|
||||
from ngraph.opset7 import acos
|
||||
from ngraph.opset7 import acosh
|
||||
from ngraph.opset7 import add
|
||||
from ngraph.opset7 import asin
|
||||
from ngraph.opset7 import asinh
|
||||
from ngraph.opset7 import assign
|
||||
from ngraph.opset7 import atan
|
||||
from ngraph.opset7 import atanh
|
||||
from ngraph.opset7 import avg_pool
|
||||
from ngraph.opset7 import batch_norm_inference
|
||||
from ngraph.opset7 import batch_to_space
|
||||
from ngraph.opset7 import binary_convolution
|
||||
from ngraph.opset7 import broadcast
|
||||
from ngraph.opset7 import bucketize
|
||||
from ngraph.opset7 import ceiling
|
||||
from ngraph.opset7 import ceiling as ceil
|
||||
from ngraph.opset7 import clamp
|
||||
from ngraph.opset7 import concat
|
||||
from ngraph.opset7 import constant
|
||||
from ngraph.opset7 import convert
|
||||
from ngraph.opset7 import convert_like
|
||||
from ngraph.opset7 import convolution
|
||||
from ngraph.opset7 import convolution_backprop_data
|
||||
from ngraph.opset7 import cos
|
||||
from ngraph.opset7 import cosh
|
||||
from ngraph.opset7 import ctc_greedy_decoder
|
||||
from ngraph.opset7 import ctc_greedy_decoder_seq_len
|
||||
from ngraph.opset7 import ctc_loss
|
||||
from ngraph.opset7 import cum_sum
|
||||
from ngraph.opset7 import cum_sum as cumsum
|
||||
from ngraph.opset7 import deformable_convolution
|
||||
from ngraph.opset7 import deformable_psroi_pooling
|
||||
from ngraph.opset7 import depth_to_space
|
||||
from ngraph.opset7 import detection_output
|
||||
from ngraph.opset7 import divide
|
||||
from ngraph.opset7 import elu
|
||||
from ngraph.opset7 import embedding_bag_offsets_sum
|
||||
from ngraph.opset7 import embedding_bag_packed_sum
|
||||
from ngraph.opset7 import embedding_segments_sum
|
||||
from ngraph.opset7 import extract_image_patches
|
||||
from ngraph.opset7 import equal
|
||||
from ngraph.opset7 import erf
|
||||
from ngraph.opset7 import exp
|
||||
from ngraph.opset7 import fake_quantize
|
||||
from ngraph.opset7 import floor
|
||||
from ngraph.opset7 import floor_mod
|
||||
from ngraph.opset7 import gather
|
||||
from ngraph.opset7 import gather_elements
|
||||
from ngraph.opset7 import gather_nd
|
||||
from ngraph.opset7 import gather_tree
|
||||
from ngraph.opset7 import gelu
|
||||
from ngraph.opset7 import greater
|
||||
from ngraph.opset7 import greater_equal
|
||||
from ngraph.opset7 import grn
|
||||
from ngraph.opset7 import group_convolution
|
||||
from ngraph.opset7 import group_convolution_backprop_data
|
||||
from ngraph.opset7 import gru_cell
|
||||
from ngraph.opset7 import gru_sequence
|
||||
from ngraph.opset7 import hard_sigmoid
|
||||
from ngraph.opset7 import hsigmoid
|
||||
from ngraph.opset7 import hswish
|
||||
from ngraph.opset7 import interpolate
|
||||
from ngraph.opset7 import less
|
||||
from ngraph.opset7 import less_equal
|
||||
from ngraph.opset7 import log
|
||||
from ngraph.opset7 import logical_and
|
||||
from ngraph.opset7 import logical_not
|
||||
from ngraph.opset7 import logical_or
|
||||
from ngraph.opset7 import logical_xor
|
||||
from ngraph.opset7 import log_softmax
|
||||
from ngraph.opset7 import loop
|
||||
from ngraph.opset7 import lrn
|
||||
from ngraph.opset7 import lstm_cell
|
||||
from ngraph.opset7 import lstm_sequence
|
||||
from ngraph.opset7 import matmul
|
||||
from ngraph.opset7 import max_pool
|
||||
from ngraph.opset7 import maximum
|
||||
from ngraph.opset7 import minimum
|
||||
from ngraph.opset7 import mish
|
||||
from ngraph.opset7 import mod
|
||||
from ngraph.opset7 import multiply
|
||||
from ngraph.opset7 import mvn
|
||||
from ngraph.opset7 import negative
|
||||
from ngraph.opset7 import non_max_suppression
|
||||
from ngraph.opset7 import non_zero
|
||||
from ngraph.opset7 import normalize_l2
|
||||
from ngraph.opset7 import not_equal
|
||||
from ngraph.opset7 import one_hot
|
||||
from ngraph.opset7 import pad
|
||||
from ngraph.opset7 import parameter
|
||||
from ngraph.opset7 import power
|
||||
from ngraph.opset7 import prelu
|
||||
from ngraph.opset7 import prior_box
|
||||
from ngraph.opset7 import prior_box_clustered
|
||||
from ngraph.opset7 import psroi_pooling
|
||||
from ngraph.opset7 import proposal
|
||||
from ngraph.opset7 import range
|
||||
from ngraph.opset7 import read_value
|
||||
from ngraph.opset7 import reduce_l1
|
||||
from ngraph.opset7 import reduce_l2
|
||||
from ngraph.opset7 import reduce_logical_and
|
||||
from ngraph.opset7 import reduce_logical_or
|
||||
from ngraph.opset7 import reduce_max
|
||||
from ngraph.opset7 import reduce_mean
|
||||
from ngraph.opset7 import reduce_min
|
||||
from ngraph.opset7 import reduce_prod
|
||||
from ngraph.opset7 import reduce_sum
|
||||
from ngraph.opset7 import region_yolo
|
||||
from ngraph.opset7 import reorg_yolo
|
||||
from ngraph.opset7 import relu
|
||||
from ngraph.opset7 import reshape
|
||||
from ngraph.opset7 import result
|
||||
from ngraph.opset7 import reverse_sequence
|
||||
from ngraph.opset7 import rnn_cell
|
||||
from ngraph.opset7 import rnn_sequence
|
||||
from ngraph.opset7 import roi_align
|
||||
from ngraph.opset7 import roi_pooling
|
||||
from ngraph.opset7 import round
|
||||
from ngraph.opset7 import scatter_elements_update
|
||||
from ngraph.opset7 import scatter_update
|
||||
from ngraph.opset7 import select
|
||||
from ngraph.opset7 import selu
|
||||
from ngraph.opset7 import shape_of
|
||||
from ngraph.opset7 import shuffle_channels
|
||||
from ngraph.opset7 import sigmoid
|
||||
from ngraph.opset7 import sign
|
||||
from ngraph.opset7 import sin
|
||||
from ngraph.opset7 import sinh
|
||||
from ngraph.opset7 import softmax
|
||||
from ngraph.opset7 import softplus
|
||||
from ngraph.opset7 import space_to_batch
|
||||
from ngraph.opset7 import space_to_depth
|
||||
from ngraph.opset7 import split
|
||||
from ngraph.opset7 import sqrt
|
||||
from ngraph.opset7 import squared_difference
|
||||
from ngraph.opset7 import squeeze
|
||||
from ngraph.opset7 import strided_slice
|
||||
from ngraph.opset7 import subtract
|
||||
from ngraph.opset7 import swish
|
||||
from ngraph.opset7 import tan
|
||||
from ngraph.opset7 import tanh
|
||||
from ngraph.opset7 import tensor_iterator
|
||||
from ngraph.opset7 import tile
|
||||
from ngraph.opset7 import topk
|
||||
from ngraph.opset7 import transpose
|
||||
from ngraph.opset7 import unsqueeze
|
||||
from ngraph.opset7 import variadic_split
|
||||
|
||||
|
||||
# Extend Node class to support binary operators
|
||||
|
@ -173,4 +173,3 @@ xfail_issue_49750 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v4
|
||||
xfail_issue_49752 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::Pad")
|
||||
xfail_issue_49753 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::StridedSlice")
|
||||
xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE")
|
||||
xfail_issue_49913 = xfail_test(reason="CPU supports Gelu with tanh mode only")
|
||||
|
@ -8,7 +8,7 @@ import ngraph as ng
|
||||
from ngraph.impl import Shape, Type
|
||||
from tests.runtime import get_runtime
|
||||
from tests.test_ngraph.util import run_op_node
|
||||
from tests import xfail_issue_44970, xfail_issue_49913
|
||||
from tests import xfail_issue_44970
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -176,7 +176,6 @@ def test_hsigmoid():
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
|
||||
|
||||
@xfail_issue_49913
|
||||
def test_gelu_operator_with_parameters():
|
||||
runtime = get_runtime()
|
||||
|
||||
@ -190,10 +189,9 @@ def test_gelu_operator_with_parameters():
|
||||
|
||||
result = computation(data_value)
|
||||
expected = np.array([[-1.6391277e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]], dtype=np.float32)
|
||||
assert np.allclose(result, expected)
|
||||
assert np.allclose(result, expected, 1e-6, 1e-6)
|
||||
|
||||
|
||||
@xfail_issue_49913
|
||||
def test_gelu_operator_with_array():
|
||||
runtime = get_runtime()
|
||||
|
||||
@ -204,7 +202,7 @@ def test_gelu_operator_with_array():
|
||||
|
||||
result = computation()
|
||||
expected = np.array([[-1.6391277e-06, 8.4134471e-01], [-4.5500278e-02, 2.9959502]], dtype=np.float32)
|
||||
assert np.allclose(result, expected)
|
||||
assert np.allclose(result, expected, 1e-6, 1e-6)
|
||||
|
||||
|
||||
def test_gelu_tanh_operator_with_parameters():
|
||||
|
Loading…
Reference in New Issue
Block a user