[core] Migrate Gelu operator to new API (#20833)
* Drop HostTensor * Remove useless overwrite method
This commit is contained in:
@@ -23,8 +23,6 @@ public:
|
||||
/// \param data Input tensor
|
||||
Gelu(const Output<Node>& data);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
@@ -56,9 +54,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
@@ -2,41 +2,72 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/gelu.hpp"
|
||||
#include "openvino/op/gelu.hpp"
|
||||
|
||||
#include <cmath>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/type.hpp"
|
||||
#include "openvino/reference/gelu.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v0 {
|
||||
Gelu::Gelu() : UnaryElementwiseArithmetic() {}
|
||||
|
||||
// ------------------------------ V0 ------------------------------
|
||||
op::v0::Gelu::Gelu() : UnaryElementwiseArithmetic() {}
|
||||
|
||||
op::v0::Gelu::Gelu(const Output<Node>& data) : UnaryElementwiseArithmetic(data) {
|
||||
Gelu::Gelu(const Output<Node>& data) : UnaryElementwiseArithmetic(data) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool op::v0::Gelu::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v0_Gelu_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v0::Gelu::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Gelu::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v0_Gelu_clone_with_new_inputs);
|
||||
if (new_args.size() != 1) {
|
||||
OPENVINO_THROW("Incorrect number of new arguments");
|
||||
}
|
||||
return make_shared<op::v0::Gelu>(new_args.at(0));
|
||||
return std::make_shared<Gelu>(new_args.at(0));
|
||||
}
|
||||
|
||||
void op::v0::Gelu::validate_and_infer_types() {
|
||||
void Gelu::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v0_Gelu_validate_and_infer_types);
|
||||
element::Type input_element_type = get_input_element_type(0);
|
||||
ov::PartialShape input_pshape = get_input_partial_shape(0);
|
||||
PartialShape input_pshape = get_input_partial_shape(0);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_element_type.is_dynamic() || input_element_type.is_real(),
|
||||
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
|
||||
input_element_type,
|
||||
").");
|
||||
|
||||
set_output_type(0, input_element_type, input_pshape);
|
||||
}
|
||||
} // namespace v0
|
||||
|
||||
namespace v7 {
|
||||
Gelu::Gelu(const Output<Node>& data, GeluApproximationMode mode)
|
||||
: UnaryElementwiseArithmetic(data),
|
||||
m_approximation_mode(mode) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool Gelu::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v7_Gelu_visit_attributes);
|
||||
visitor.on_attribute("approximation_mode", m_approximation_mode);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> Gelu::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v7_Gelu_clone_with_new_inputs);
|
||||
if (new_args.size() != 1) {
|
||||
OPENVINO_THROW("Incorrect number of new arguments");
|
||||
}
|
||||
return std::make_shared<Gelu>(new_args.at(0), m_approximation_mode);
|
||||
}
|
||||
|
||||
void Gelu::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v7_Gelu_validate_and_infer_types);
|
||||
element::Type input_element_type = get_input_element_type(0);
|
||||
PartialShape input_pshape = get_input_partial_shape(0);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_element_type.is_dynamic() || input_element_type.is_real(),
|
||||
@@ -47,14 +78,57 @@ void op::v0::Gelu::validate_and_infer_types() {
|
||||
set_output_type(0, input_element_type, input_pshape);
|
||||
}
|
||||
|
||||
// ------------------------------ V7 ------------------------------
|
||||
op::GeluApproximationMode Gelu::get_approximation_mode() const {
|
||||
return m_approximation_mode;
|
||||
}
|
||||
|
||||
namespace gelu {
|
||||
namespace {
|
||||
struct Evaluate : element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t ET, class T = fundamental_type_for<ET>>
|
||||
static result_type visit(const Tensor& in, Tensor& out, const op::GeluApproximationMode mode, const size_t count) {
|
||||
reference::gelu(in.data<const T>(), out.data<T>(), mode, count);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
} // namespace gelu
|
||||
|
||||
bool Gelu::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v7_Gelu_evaluate);
|
||||
OPENVINO_ASSERT(inputs.size() == 1 && outputs.size() == 1);
|
||||
|
||||
const auto& input_shape = inputs[0].get_shape();
|
||||
const auto count = shape_size(input_shape);
|
||||
outputs[0].set_shape(input_shape);
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<f16, f32>::apply<gelu::Evaluate>(inputs[0].get_element_type(),
|
||||
inputs[0],
|
||||
outputs[0],
|
||||
m_approximation_mode,
|
||||
count);
|
||||
}
|
||||
|
||||
bool Gelu::has_evaluate() const {
|
||||
OV_OP_SCOPE(v7_Gelu_has_evaluate);
|
||||
switch (get_input_element_type(0)) {
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace v7
|
||||
} // namespace op
|
||||
|
||||
namespace ov {
|
||||
template <>
|
||||
NGRAPH_API EnumNames<ngraph::op::GeluApproximationMode>& EnumNames<ngraph::op::GeluApproximationMode>::get() {
|
||||
static auto enum_names = EnumNames<ngraph::op::GeluApproximationMode>(
|
||||
OPENVINO_API EnumNames<op::GeluApproximationMode>& EnumNames<op::GeluApproximationMode>::get() {
|
||||
static auto enum_names = EnumNames<op::GeluApproximationMode>(
|
||||
"op::GeluApproximationMode",
|
||||
{{"TANH", ngraph::op::GeluApproximationMode::TANH}, {"ERF", ngraph::op::GeluApproximationMode::ERF}});
|
||||
{{"TANH", op::GeluApproximationMode::TANH}, {"ERF", op::GeluApproximationMode::ERF}});
|
||||
return enum_names;
|
||||
}
|
||||
|
||||
@@ -62,91 +136,3 @@ std::ostream& op::operator<<(std::ostream& s, const op::GeluApproximationMode& t
|
||||
return s << as_string(type);
|
||||
}
|
||||
} // namespace ov
|
||||
|
||||
op::v7::Gelu::Gelu(const Output<Node>& data, GeluApproximationMode mode)
|
||||
: UnaryElementwiseArithmetic(data),
|
||||
m_approximation_mode(mode) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool op::v7::Gelu::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v7_Gelu_visit_attributes);
|
||||
visitor.on_attribute("approximation_mode", m_approximation_mode);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v7::Gelu::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v7_Gelu_clone_with_new_inputs);
|
||||
if (new_args.size() != 1) {
|
||||
OPENVINO_THROW("Incorrect number of new arguments");
|
||||
}
|
||||
return make_shared<op::v7::Gelu>(new_args.at(0), m_approximation_mode);
|
||||
}
|
||||
|
||||
void op::v7::Gelu::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v7_Gelu_validate_and_infer_types);
|
||||
element::Type input_element_type = get_input_element_type(0);
|
||||
ov::PartialShape input_pshape = get_input_partial_shape(0);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_element_type.is_dynamic() || input_element_type.is_real(),
|
||||
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
|
||||
input_element_type,
|
||||
").");
|
||||
|
||||
set_output_type(0, input_element_type, input_pshape);
|
||||
}
|
||||
|
||||
op::GeluApproximationMode op::v7::Gelu::get_approximation_mode() const {
|
||||
return m_approximation_mode;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace gelu {
|
||||
namespace {
|
||||
template <element::Type_t ET>
|
||||
inline bool evaluate(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& out,
|
||||
op::GeluApproximationMode mode,
|
||||
const size_t count) {
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
ov::reference::gelu<T>(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), mode, count);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_gelu(const HostTensorPtr& arg0, const HostTensorPtr& out, op::GeluApproximationMode mode) {
|
||||
bool rc = true;
|
||||
size_t count = shape_size(arg0->get_shape());
|
||||
out->set_unary(arg0);
|
||||
|
||||
switch (arg0->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_gelu, f16, arg0, out, mode, count);
|
||||
OPENVINO_TYPE_CASE(evaluate_gelu, f32, arg0, out, mode, count);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace gelu
|
||||
|
||||
bool op::v7::Gelu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v7_Gelu_evaluate);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1) && validate_host_tensor_vector(inputs, 1));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
return gelu::evaluate_gelu(inputs[0], outputs[0], m_approximation_mode);
|
||||
}
|
||||
|
||||
bool op::v7::Gelu::has_evaluate() const {
|
||||
OV_OP_SCOPE(v7_Gelu_has_evaluate);
|
||||
switch (get_input_element_type(0)) {
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user