Removal of FusedOp inheritance leftovers (#7113)
* Remove FusedOp from v0::Gelu * Update v0::Gelu NGRAPH_RTTI_DECLARATION * Enable gelu type_prop tests * Remove FusedOp from v0::MVN * Remove FusedOp from HardSigmoid * Remove FusedOp from LSTMSequence * Remove supress deprecated * Add missed NGRAPH_OP_SCOPE to v0 Gelu and HardSigmoid
This commit is contained in:
parent
983bab8271
commit
5867eb3a40
@ -6,21 +6,17 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
namespace v0 {
|
||||
/// \brief Gaussian Error Linear Unit
|
||||
/// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) )
|
||||
class NGRAPH_API Gelu : public ngraph::op::util::FusedOp {
|
||||
class NGRAPH_API Gelu : public Op {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Gelu", 0};
|
||||
const NodeTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Gelu();
|
||||
/// \brief Constructs a Gelu operation.
|
||||
///
|
||||
@ -28,15 +24,13 @@ public:
|
||||
Gelu(const Output<Node>& data);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual OutputVector decompose_op() const override;
|
||||
|
||||
void pre_validate_and_infer_types() override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
} // namespace v0
|
||||
using v0::Gelu;
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
||||
/// \brief Specifies the approximation to calculate Gelu
|
||||
enum class GeluApproximationMode { TANH, ERF };
|
||||
|
@ -5,9 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
@ -15,7 +13,7 @@ namespace v0 {
|
||||
/// \brief Parameterized, bounded sigmoid-like, piecewise linear
|
||||
/// function. min(max(alpha*x + beta, 0), 1)
|
||||
///
|
||||
class NGRAPH_API HardSigmoid : public ngraph::op::util::FusedOp {
|
||||
class NGRAPH_API HardSigmoid : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
@ -30,13 +28,10 @@ public:
|
||||
HardSigmoid(const Output<Node>& data, const Output<Node>& alpha, const Output<Node>& beta);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual void pre_validate_and_infer_types() override;
|
||||
virtual OutputVector decompose_op() const override;
|
||||
virtual void validate_and_infer_types() override;
|
||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
} // namespace v0
|
||||
using v0::HardSigmoid;
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
@ -14,13 +14,11 @@
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/lstm_cell.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
#include "ngraph/op/util/rnn_cell_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace v0 {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
///
|
||||
/// \brief Class for lstm sequence node.
|
||||
@ -31,7 +29,7 @@ NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
/// \sa LSTMCell, RNNCell, GRUCell
|
||||
///
|
||||
///
|
||||
class NGRAPH_API LSTMSequence : public util::FusedOp {
|
||||
class NGRAPH_API LSTMSequence : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
LSTMSequence();
|
||||
@ -76,7 +74,6 @@ public:
|
||||
|
||||
virtual void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual OutputVector decompose_op() const override;
|
||||
|
||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
@ -138,8 +135,6 @@ private:
|
||||
bool m_input_forget;
|
||||
LSTMWeightsFormat m_weights_format;
|
||||
};
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
} // namespace v0
|
||||
|
||||
namespace v5 {
|
||||
|
@ -6,16 +6,14 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
namespace v0 {
|
||||
/// \brief Operator performing Mean Variance Normalization
|
||||
///
|
||||
class NGRAPH_API MVN : public ngraph::op::util::FusedOp {
|
||||
class NGRAPH_API MVN : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
@ -43,8 +41,6 @@ public:
|
||||
///
|
||||
MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance = true, double eps = 1e-9);
|
||||
|
||||
virtual OutputVector decompose_op() const override;
|
||||
|
||||
virtual void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
@ -76,8 +72,6 @@ private:
|
||||
} // namespace v0
|
||||
using v0::MVN;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
||||
/// \brief Specifies how eps is applied in MVN
|
||||
enum class MVNEpsMode {
|
||||
// Apply eps inside sqrt
|
||||
|
@ -8,26 +8,17 @@
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/builder/make_constant.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/erf.hpp"
|
||||
#include "ngraph/op/exp.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/negative.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/runtime/reference/gelu.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
// ------------------------------ V0 ------------------------------
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::Gelu, "Gelu", 0);
|
||||
|
||||
constexpr NodeTypeInfo op::Gelu::type_info;
|
||||
op::v0::Gelu::Gelu() : Op() {}
|
||||
|
||||
op::v0::Gelu::Gelu() : FusedOp() {}
|
||||
|
||||
op::v0::Gelu::Gelu(const Output<Node>& data) : FusedOp({data}) {
|
||||
op::v0::Gelu::Gelu(const Output<Node>& data) : Op({data}) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
@ -36,25 +27,6 @@ bool op::v0::Gelu::visit_attributes(AttributeVisitor& visitor) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// f(x) = 0.5 * x * (1.0 + erf( x / sqrt(2.0) )
|
||||
OutputVector op::Gelu::decompose_op() const {
|
||||
auto data = input_value(0);
|
||||
|
||||
shared_ptr<ngraph::Node> half = builder::make_constant(data.get_element_type(), data.get_shape(), 0.5);
|
||||
|
||||
shared_ptr<ngraph::Node> one = builder::make_constant(data.get_element_type(), data.get_shape(), 1.0);
|
||||
|
||||
shared_ptr<ngraph::Node> sqrt_two =
|
||||
builder::make_constant(data.get_element_type(), data.get_shape(), std::sqrt(2.0));
|
||||
|
||||
shared_ptr<ngraph::Node> add =
|
||||
std::make_shared<op::v1::Add>(one,
|
||||
make_shared<ngraph::op::Erf>(std::make_shared<op::v1::Divide>(data, sqrt_two)));
|
||||
shared_ptr<ngraph::Node> multiply = std::make_shared<op::v1::Multiply>(half, data);
|
||||
|
||||
return {std::make_shared<op::v1::Multiply>(multiply, add)};
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v0::Gelu::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
NGRAPH_OP_SCOPE(v0_Gelu_clone_with_new_inputs);
|
||||
if (new_args.size() != 1) {
|
||||
@ -63,7 +35,8 @@ shared_ptr<Node> op::v0::Gelu::clone_with_new_inputs(const OutputVector& new_arg
|
||||
return make_shared<op::v0::Gelu>(new_args.at(0));
|
||||
}
|
||||
|
||||
void op::v0::Gelu::pre_validate_and_infer_types() {
|
||||
void op::v0::Gelu::validate_and_infer_types() {
|
||||
NGRAPH_OP_SCOPE(v0_Gelu_validate_and_infer_types);
|
||||
element::Type input_element_type = get_input_element_type(0);
|
||||
PartialShape input_pshape = get_input_partial_shape(0);
|
||||
|
||||
|
@ -7,24 +7,17 @@
|
||||
#include <memory>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/maximum.hpp"
|
||||
#include "ngraph/op/minimum.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::HardSigmoid, "HardSigmoid", 0);
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::HardSigmoid, "HardSigmoid", 0, op::util::FusedOp);
|
||||
op::v0::HardSigmoid::HardSigmoid() : Op() {}
|
||||
|
||||
op::HardSigmoid::HardSigmoid() : FusedOp() {}
|
||||
|
||||
op::HardSigmoid::HardSigmoid(const Output<Node>& data, const Output<Node>& alpha, const Output<Node>& beta)
|
||||
: FusedOp({data, alpha, beta}) {
|
||||
op::v0::HardSigmoid::HardSigmoid(const Output<Node>& data, const Output<Node>& alpha, const Output<Node>& beta)
|
||||
: Op({data, alpha, beta}) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
@ -33,7 +26,8 @@ bool ngraph::op::v0::HardSigmoid::visit_attributes(AttributeVisitor& visitor) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::HardSigmoid::pre_validate_and_infer_types() {
|
||||
void op::v0::HardSigmoid::validate_and_infer_types() {
|
||||
NGRAPH_OP_SCOPE(v0_HardSigmoid_validate_and_infer_types);
|
||||
const auto& alpha_pshape = get_input_partial_shape(1);
|
||||
const auto& beta_pshape = get_input_partial_shape(2);
|
||||
|
||||
@ -64,28 +58,9 @@ void op::HardSigmoid::pre_validate_and_infer_types() {
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
}
|
||||
|
||||
OutputVector op::HardSigmoid::decompose_op() const {
|
||||
const auto data = input_value(0);
|
||||
|
||||
const auto one_node = ngraph::op::Constant::create<float>(data.get_element_type(), data.get_shape(), {1.0f});
|
||||
|
||||
const auto zero_node = ngraph::op::Constant::create<float>(data.get_element_type(), data.get_shape(), {0.0f});
|
||||
|
||||
const auto alpha_node = input_value(1).get_node_shared_ptr();
|
||||
const auto beta_node = input_value(2).get_node_shared_ptr();
|
||||
|
||||
std::shared_ptr<Node> alpha_x_plus_beta =
|
||||
std::make_shared<op::v1::Multiply>(alpha_node, data, AutoBroadcastType::NUMPY);
|
||||
|
||||
alpha_x_plus_beta = std::make_shared<op::v1::Add>(alpha_x_plus_beta, beta_node, AutoBroadcastType::NUMPY);
|
||||
|
||||
return {
|
||||
std::make_shared<op::v1::Minimum>(std::make_shared<op::v1::Maximum>(alpha_x_plus_beta, zero_node), one_node)};
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::HardSigmoid::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
shared_ptr<Node> op::v0::HardSigmoid::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
NGRAPH_OP_SCOPE(v0_HardSigmoid_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
|
||||
return make_shared<HardSigmoid>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
return std::make_shared<op::v0::HardSigmoid>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
@ -16,13 +16,11 @@
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::LSTMSequence, "LSTMSequence", 0);
|
||||
NGRAPH_RTTI_DEFINITION(op::v5::LSTMSequence, "LSTMSequence", 5);
|
||||
|
||||
op::v0::LSTMSequence::LSTMSequence()
|
||||
: FusedOp(),
|
||||
: Op(),
|
||||
m_activations_alpha(),
|
||||
m_activations_beta(),
|
||||
m_activations(),
|
||||
@ -48,7 +46,7 @@ op::v0::LSTMSequence::LSTMSequence(const Output<Node>& X,
|
||||
const std::vector<std::string> activations,
|
||||
const float clip_threshold,
|
||||
const bool input_forget)
|
||||
: FusedOp({X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B, P}),
|
||||
: Op({X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B, P}),
|
||||
m_activations_alpha(activations_alpha),
|
||||
m_activations_beta(activations_beta),
|
||||
m_activations(activations),
|
||||
@ -110,24 +108,6 @@ bool op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor) {
|
||||
return true;
|
||||
}
|
||||
|
||||
OutputVector op::v0::LSTMSequence::decompose_op() const {
|
||||
OutputVector results;
|
||||
if (m_direction == direction::FORWARD || m_direction == direction::REVERSE) {
|
||||
results = lstm_pass(m_direction == direction::REVERSE);
|
||||
}
|
||||
if (m_direction == direction::BIDIRECTIONAL) {
|
||||
OutputVector fwd_results{lstm_pass()};
|
||||
OutputVector rev_results{lstm_pass(true)};
|
||||
|
||||
// Stack together respective outputs from both forward and reverse passess.
|
||||
shared_ptr<Node> Y{make_shared<opset1::Concat>(OutputVector{fwd_results.at(0), rev_results.at(0)}, 1)};
|
||||
shared_ptr<Node> Y_h{make_shared<opset1::Concat>(OutputVector{fwd_results.at(1), rev_results.at(1)}, 1)};
|
||||
shared_ptr<Node> Y_c{make_shared<opset1::Concat>(OutputVector{fwd_results.at(2), rev_results.at(2)}, 1)};
|
||||
results = OutputVector{Y, Y_h, Y_c};
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v0::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
NGRAPH_OP_SCOPE(v0_LSTMSequence_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
|
@ -7,36 +7,26 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/builder/autobroadcast.hpp"
|
||||
#include "ngraph/builder/reduce_ops.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/sqrt.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// ------------------------------ V0 ------------------------------
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::MVN, "MVN", 0);
|
||||
|
||||
op::MVN::MVN() : FusedOp(), m_across_channels(), m_normalize_variance(), m_reduction_axes() {}
|
||||
op::v0::MVN::MVN() : Op(), m_across_channels(), m_normalize_variance(), m_reduction_axes() {}
|
||||
|
||||
op::MVN::MVN(const Output<Node>& data, bool across_channels, bool normalize_variance, double eps)
|
||||
: FusedOp({data}),
|
||||
op::v0::MVN::MVN(const Output<Node>& data, bool across_channels, bool normalize_variance, double eps)
|
||||
: Op({data}),
|
||||
m_eps{eps},
|
||||
m_across_channels{across_channels},
|
||||
m_normalize_variance{normalize_variance} {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance, double eps)
|
||||
: FusedOp({data}),
|
||||
op::v0::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance, double eps)
|
||||
: Op({data}),
|
||||
m_eps{eps},
|
||||
m_across_channels{false},
|
||||
m_normalize_variance{normalize_variance},
|
||||
@ -46,10 +36,7 @@ op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_va
|
||||
m_across_channels = (m_reduction_axes.count(chanelAxis) > 0);
|
||||
}
|
||||
|
||||
// decompose_op() relies on knowing the data type of input data which might
|
||||
// not be available at shape inference time. So do direct shape inference
|
||||
// instead of relying on op decomposition.
|
||||
void op::MVN::validate_and_infer_types() {
|
||||
void op::v0::MVN::validate_and_infer_types() {
|
||||
NGRAPH_OP_SCOPE(v0_MVN_validate_and_infer_types);
|
||||
// if m_across_channels is true we should calculate mean and variance per batch
|
||||
// else we calculate these per channel
|
||||
@ -65,40 +52,16 @@ void op::MVN::validate_and_infer_types() {
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
}
|
||||
|
||||
OutputVector op::MVN::decompose_op() const {
|
||||
auto data = input_value(0);
|
||||
auto data_shape = data.get_shape(); // assume that data has n and c channels.
|
||||
|
||||
// calculate mean normalization
|
||||
auto mean = builder::opset1::mean(data, m_reduction_axes);
|
||||
auto mean_normalization =
|
||||
std::make_shared<op::v1::Subtract>(data, builder::opset1::make_broadcast(mean, data_shape, m_reduction_axes));
|
||||
|
||||
if (!m_normalize_variance) {
|
||||
return {mean_normalization};
|
||||
} else {
|
||||
// calculate variance
|
||||
auto variance = builder::opset1::variance(data, m_reduction_axes);
|
||||
// add epsilon
|
||||
auto eps_node =
|
||||
op::Constant::create(data.get_element_type(), Output<Node>(variance).get_shape(), vector<double>{m_eps});
|
||||
variance = std::make_shared<op::Sqrt>(std::make_shared<op::v1::Add>(variance, eps_node));
|
||||
return OutputVector{
|
||||
std::make_shared<op::v1::Divide>(mean_normalization,
|
||||
builder::opset1::make_broadcast(variance, data_shape, m_reduction_axes))};
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::MVN::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
shared_ptr<Node> op::v0::MVN::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
NGRAPH_OP_SCOPE(v0_MVN_clone_with_new_inputs);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
new_args.size() == 1,
|
||||
"Expected 1 element in new_args for the MVN op but got ",
|
||||
new_args.size());
|
||||
return make_shared<MVN>(new_args.at(0), m_reduction_axes, m_normalize_variance, m_eps);
|
||||
return std::make_shared<op::v0::MVN>(new_args.at(0), m_reduction_axes, m_normalize_variance, m_eps);
|
||||
}
|
||||
|
||||
bool op::MVN::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool op::v0::MVN::visit_attributes(AttributeVisitor& visitor) {
|
||||
NGRAPH_OP_SCOPE(v0_MVN_visit_attributes);
|
||||
visitor.on_attribute("eps", m_eps);
|
||||
visitor.on_attribute("across_channels", m_across_channels);
|
||||
|
@ -144,6 +144,7 @@ set(SRC
|
||||
type_prop/gather_elements.cpp
|
||||
type_prop/gather_nd.cpp
|
||||
type_prop/gather_tree.cpp
|
||||
type_prop/gelu.cpp
|
||||
type_prop/grn.cpp
|
||||
type_prop/group_convolution.cpp
|
||||
type_prop/group_convolution_backprop_data.cpp
|
||||
|
@ -9,8 +9,17 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, gelu_default_mode_inference_f32)
|
||||
{
|
||||
// ------------------------------ V0 ------------------------------
|
||||
TEST(type_prop, gelu_v0) {
|
||||
const PartialShape param_shape{64, Dimension::dynamic(), 256, Dimension(4, 8)};
|
||||
const auto param = std::make_shared<op::Parameter>(element::f32, param_shape);
|
||||
const auto op = std::make_shared<op::v0::Gelu>(param);
|
||||
ASSERT_EQ(op->get_element_type(), element::f32);
|
||||
ASSERT_EQ(op->get_output_partial_shape(0), param_shape);
|
||||
}
|
||||
|
||||
// ------------------------------ V7 ------------------------------
|
||||
TEST(type_prop, gelu_default_mode_inference_f32) {
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{1, 32, 32});
|
||||
auto gelu = make_shared<op::v7::Gelu>(param);
|
||||
|
||||
@ -19,8 +28,7 @@ TEST(type_prop, gelu_default_mode_inference_f32)
|
||||
ASSERT_EQ(gelu->get_approximation_mode(), op::GeluApproximationMode::ERF);
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_default_mode_inference_f16)
|
||||
{
|
||||
TEST(type_prop, gelu_default_mode_inference_f16) {
|
||||
auto param = make_shared<op::Parameter>(element::f16, Shape{1, 32, 32});
|
||||
auto gelu = make_shared<op::v7::Gelu>(param);
|
||||
|
||||
@ -29,8 +37,7 @@ TEST(type_prop, gelu_default_mode_inference_f16)
|
||||
ASSERT_EQ(gelu->get_approximation_mode(), op::GeluApproximationMode::ERF);
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_tanh_mode_inference_f32)
|
||||
{
|
||||
TEST(type_prop, gelu_tanh_mode_inference_f32) {
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{1, 32, 32});
|
||||
auto gelu = make_shared<op::v7::Gelu>(param, op::GeluApproximationMode::TANH);
|
||||
|
||||
@ -39,58 +46,50 @@ TEST(type_prop, gelu_tanh_mode_inference_f32)
|
||||
ASSERT_EQ(gelu->get_approximation_mode(), op::GeluApproximationMode::TANH);
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_tanh_mode_inference_f16)
|
||||
{
|
||||
TEST(type_prop, gelu_tanh_mode_inference_f16) {
|
||||
auto param = make_shared<op::Parameter>(element::f16, Shape{1, 32, 32});
|
||||
auto gelu = make_shared<op::v7::Gelu>(param, op::GeluApproximationMode::TANH);
|
||||
|
||||
ASSERT_EQ(gelu->get_element_type(), element::f16);
|
||||
ASSERT_EQ(gelu->get_shape(), (Shape{1, 32, 32}));
|
||||
ASSERT_EQ(gelu->get_approximation_mode(), op::GeluApproximationMode::TANH);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_incompatible_input_type_boolean)
|
||||
{
|
||||
TEST(type_prop, gelu_incompatible_input_type_boolean) {
|
||||
auto param = make_shared<op::Parameter>(element::boolean, Shape{1, 32, 32});
|
||||
ASSERT_THROW(std::make_shared<op::v7::Gelu>(param), ngraph::NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_incompatible_input_type_u16)
|
||||
{
|
||||
TEST(type_prop, gelu_incompatible_input_type_u16) {
|
||||
auto param = make_shared<op::Parameter>(element::u16, Shape{1, 32, 32});
|
||||
ASSERT_THROW(std::make_shared<op::v7::Gelu>(param), ngraph::NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_incompatible_input_type_i32)
|
||||
{
|
||||
TEST(type_prop, gelu_incompatible_input_type_i32) {
|
||||
auto param = make_shared<op::Parameter>(element::i32, Shape{1, 32, 32});
|
||||
ASSERT_THROW(std::make_shared<op::v7::Gelu>(param), ngraph::NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_incompatible_input_type_i16)
|
||||
{
|
||||
TEST(type_prop, gelu_incompatible_input_type_i16) {
|
||||
auto param = make_shared<op::Parameter>(element::i16, Shape{1, 32, 32});
|
||||
ASSERT_THROW(std::make_shared<op::v7::Gelu>(param), ngraph::NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_dynamic_rank_input_shape_2D)
|
||||
{
|
||||
TEST(type_prop, gelu_dynamic_rank_input_shape_2D) {
|
||||
const PartialShape param_shape{Dimension::dynamic(), 10};
|
||||
const auto param = std::make_shared<op::Parameter>(element::f32, param_shape);
|
||||
const auto op = std::make_shared<op::v7::Gelu>(param);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape{Dimension(), 10}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_dynamic_rank_input_shape_3D)
|
||||
{
|
||||
TEST(type_prop, gelu_dynamic_rank_input_shape_3D) {
|
||||
const PartialShape param_shape{100, Dimension::dynamic(), 58};
|
||||
const auto param = std::make_shared<op::Parameter>(element::f32, param_shape);
|
||||
const auto op = std::make_shared<op::v7::Gelu>(param);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape{100, Dimension(), 58}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gelu_dynamic_rank_input_shape_full)
|
||||
{
|
||||
TEST(type_prop, gelu_dynamic_rank_input_shape_full) {
|
||||
const auto param = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto op = std::make_shared<op::v7::Gelu>(param);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
|
@ -7,9 +7,6 @@
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
// suppress FusedOp deprecation warnings
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user