Migrate PReLU operator to new API (#21098)

This commit is contained in:
Pawel Raasz 2023-11-16 09:18:06 +01:00 committed by GitHub
parent f78d0950d6
commit e72afbec7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 66 deletions

View File

@ -24,15 +24,10 @@ public:
/// \param slope Multipliers for negative values
PRelu(const Output<Node>& data, const Output<Node>& slope);
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0

View File

@ -13,6 +13,14 @@
namespace ov {
namespace reference {
namespace func {
// Usage of custom function instead of lambda, gives smaller binary size.
template <class T>
T prelu(const T x, const T y) {
return x < T(0) ? x * y : x;
}
} // namespace func
template <typename T>
void prelu(const T* arg, const T* slope, T* out, const Shape& arg_shape, const Shape& slope_shape) {
Shape slope_shape_tmp = slope_shape;
@ -22,9 +30,7 @@ void prelu(const T* arg, const T* slope, T* out, const Shape& arg_shape, const S
channel_slope_shape[channel_dim_idx] = slope_shape[0];
std::swap(slope_shape_tmp, channel_slope_shape);
}
autobroadcast_binop(arg, slope, out, arg_shape, slope_shape_tmp, op::AutoBroadcastType::NUMPY, [](T x, T y) -> T {
return x < T(0) ? T(x * y) : x;
});
autobroadcast_binop(arg, slope, out, arg_shape, slope_shape_tmp, op::AutoBroadcastType::NUMPY, func::prelu<T>);
}
} // namespace reference
} // namespace ov

View File

@ -2,85 +2,79 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/prelu.hpp"
#include <ngraph/validation_util.hpp>
#include "openvino/op/prelu.hpp"
#include "element_visitor.hpp"
#include "itt.hpp"
#include "openvino/reference/prelu.hpp"
using namespace std;
namespace ov {
namespace op {
namespace prelu {
ov::op::v0::PRelu::PRelu() : Op() {}
struct Evaluate : element::NoAction<bool> {
using element::NoAction<bool>::visit;
ov::op::v0::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope) : Op({data, slope}) {
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const Tensor& arg,
const Tensor& slope,
Tensor& out,
const Shape& arg_shape,
const Shape& slope_shape) {
reference::prelu(arg.data<const T>(), slope.data<const T>(), out.data<T>(), arg_shape, slope_shape);
return true;
}
};
} // namespace prelu
namespace v0 {
PRelu::PRelu() : Op() {}
PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope) : Op({data, slope}) {
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v0_PRelu_visit_attributes);
return true;
}
void ngraph::op::v0::PRelu::validate_and_infer_types() {
void PRelu::validate_and_infer_types() {
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<ov::Node> ov::op::v0::PRelu::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<ov::Node> PRelu::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v0_PRelu_clone_with_new_inputs);
OPENVINO_ASSERT(new_args.size() == 2, "Incorrect number of new arguments");
return make_shared<PRelu>(new_args.at(0), new_args.at(1));
return std::make_shared<PRelu>(new_args.at(0), new_args.at(1));
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace prelu {
namespace {
template <ov::element::Type_t ET>
bool evaluate(const ngraph::HostTensorPtr& arg, const ngraph::HostTensorPtr& slope, const ngraph::HostTensorPtr& out) {
ov::reference::prelu(arg->get_data_ptr<ET>(),
slope->get_data_ptr<ET>(),
out->get_data_ptr<ET>(),
arg->get_shape(),
slope->get_shape());
return true;
}
bool evaluate_prelu(const ngraph::HostTensorPtr& arg,
const ngraph::HostTensorPtr& slope,
const ngraph::HostTensorPtr& out) {
bool rc = true;
switch (arg->get_element_type()) {
OPENVINO_TYPE_CASE(evaluate_prelu, i8, arg, slope, out);
OPENVINO_TYPE_CASE(evaluate_prelu, bf16, arg, slope, out);
OPENVINO_TYPE_CASE(evaluate_prelu, f16, arg, slope, out);
OPENVINO_TYPE_CASE(evaluate_prelu, f32, arg, slope, out);
default:
rc = false;
break;
}
return rc;
}
} // namespace
} // namespace prelu
bool ov::op::v0::PRelu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool PRelu::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v0_PRelu_evaluate);
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(ngraph::validate_host_tensor_vector(outputs, 1) && ngraph::validate_host_tensor_vector(inputs, 2));
OPENVINO_SUPPRESS_DEPRECATED_END
return prelu::evaluate_prelu(inputs[0], inputs[1], outputs[0]);
OPENVINO_ASSERT(outputs.size() == 1);
OPENVINO_ASSERT(inputs.size() == 2);
auto& out = outputs[0];
const auto& arg_shape = inputs[0].get_shape();
out.set_shape(arg_shape);
using namespace ov::element;
return IfTypeOf<bf16, f16, f32, i8>::apply<prelu::Evaluate>(inputs[0].get_element_type(),
inputs[0],
inputs[1],
out,
arg_shape,
inputs[1].get_shape());
}
bool ov::op::v0::PRelu::has_evaluate() const {
bool PRelu::has_evaluate() const {
OV_OP_SCOPE(v0_PRelu_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::i8:
case ngraph::element::bf16:
case ngraph::element::f16:
case ngraph::element::f32:
case element::bf16:
case element::f16:
case element::f32:
case element::i8:
return true;
default:
break;
return false;
}
return false;
}
} // namespace v0
} // namespace op
} // namespace ov