From e72afbec7ecebeb78d753e6e2c1757b9b72222e3 Mon Sep 17 00:00:00 2001 From: Pawel Raasz Date: Thu, 16 Nov 2023 09:18:06 +0100 Subject: [PATCH] Migrate PReLU operator to new API (#21098) --- src/core/include/openvino/op/prelu.hpp | 7 +- .../include/openvino/reference/prelu.hpp | 12 +- src/core/src/op/prelu.cpp | 108 +++++++++--------- 3 files changed, 61 insertions(+), 66 deletions(-) diff --git a/src/core/include/openvino/op/prelu.hpp b/src/core/include/openvino/op/prelu.hpp index 62c320d8c6a..bed626c4993 100644 --- a/src/core/include/openvino/op/prelu.hpp +++ b/src/core/include/openvino/op/prelu.hpp @@ -24,15 +24,10 @@ public: /// \param slope Multipliers for negative values PRelu(const Output& data, const Output& slope); - bool visit_attributes(AttributeVisitor& visitor) override; - std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; - - OPENVINO_SUPPRESS_DEPRECATED_START - bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override; - OPENVINO_SUPPRESS_DEPRECATED_END + bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override; bool has_evaluate() const override; }; } // namespace v0 diff --git a/src/core/reference/include/openvino/reference/prelu.hpp b/src/core/reference/include/openvino/reference/prelu.hpp index 7c3005e7e57..88ffcc1ffcf 100644 --- a/src/core/reference/include/openvino/reference/prelu.hpp +++ b/src/core/reference/include/openvino/reference/prelu.hpp @@ -13,6 +13,14 @@ namespace ov { namespace reference { +namespace func { +// Usage of custom function instead of lambda, gives smaller binary size. +template +T prelu(const T x, const T y) { + return x < T(0) ? x * y : x; +} +} // namespace func + template void prelu(const T* arg, const T* slope, T* out, const Shape& arg_shape, const Shape& slope_shape) { Shape slope_shape_tmp = slope_shape; @@ -22,9 +30,7 @@ void prelu(const T* arg, const T* slope, T* out, const Shape& arg_shape, const S channel_slope_shape[channel_dim_idx] = slope_shape[0]; std::swap(slope_shape_tmp, channel_slope_shape); } - autobroadcast_binop(arg, slope, out, arg_shape, slope_shape_tmp, op::AutoBroadcastType::NUMPY, [](T x, T y) -> T { - return x < T(0) ? T(x * y) : x; - }); + autobroadcast_binop(arg, slope, out, arg_shape, slope_shape_tmp, op::AutoBroadcastType::NUMPY, func::prelu); } } // namespace reference } // namespace ov diff --git a/src/core/src/op/prelu.cpp b/src/core/src/op/prelu.cpp index 9e1ccd3ec3c..ee417602cf0 100644 --- a/src/core/src/op/prelu.cpp +++ b/src/core/src/op/prelu.cpp @@ -2,85 +2,79 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/op/prelu.hpp" - -#include +#include "openvino/op/prelu.hpp" +#include "element_visitor.hpp" #include "itt.hpp" #include "openvino/reference/prelu.hpp" -using namespace std; +namespace ov { +namespace op { +namespace prelu { -ov::op::v0::PRelu::PRelu() : Op() {} +struct Evaluate : element::NoAction { + using element::NoAction::visit; -ov::op::v0::PRelu::PRelu(const Output& data, const Output& slope) : Op({data, slope}) { + template > + static result_type visit(const Tensor& arg, + const Tensor& slope, + Tensor& out, + const Shape& arg_shape, + const Shape& slope_shape) { + reference::prelu(arg.data(), slope.data(), out.data(), arg_shape, slope_shape); + return true; + } +}; +} // namespace prelu + +namespace v0 { + +PRelu::PRelu() : Op() {} + +PRelu::PRelu(const Output& data, const Output& slope) : Op({data, slope}) { constructor_validate_and_infer_types(); } -bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor) { - OV_OP_SCOPE(v0_PRelu_visit_attributes); - return true; -} - -void ngraph::op::v0::PRelu::validate_and_infer_types() { +void PRelu::validate_and_infer_types() { set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); } -shared_ptr ov::op::v0::PRelu::clone_with_new_inputs(const OutputVector& new_args) const { +std::shared_ptr PRelu::clone_with_new_inputs(const OutputVector& new_args) const { OV_OP_SCOPE(v0_PRelu_clone_with_new_inputs); OPENVINO_ASSERT(new_args.size() == 2, "Incorrect number of new arguments"); - return make_shared(new_args.at(0), new_args.at(1)); + return std::make_shared(new_args.at(0), new_args.at(1)); } -OPENVINO_SUPPRESS_DEPRECATED_START -namespace prelu { -namespace { -template -bool evaluate(const ngraph::HostTensorPtr& arg, const ngraph::HostTensorPtr& slope, const ngraph::HostTensorPtr& out) { - ov::reference::prelu(arg->get_data_ptr(), - slope->get_data_ptr(), - out->get_data_ptr(), - arg->get_shape(), - slope->get_shape()); - return true; -} - -bool evaluate_prelu(const ngraph::HostTensorPtr& arg, - const ngraph::HostTensorPtr& slope, - const ngraph::HostTensorPtr& out) { - bool rc = true; - switch (arg->get_element_type()) { - OPENVINO_TYPE_CASE(evaluate_prelu, i8, arg, slope, out); - OPENVINO_TYPE_CASE(evaluate_prelu, bf16, arg, slope, out); - OPENVINO_TYPE_CASE(evaluate_prelu, f16, arg, slope, out); - OPENVINO_TYPE_CASE(evaluate_prelu, f32, arg, slope, out); - default: - rc = false; - break; - } - return rc; -} -} // namespace -} // namespace prelu - -bool ov::op::v0::PRelu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const { +bool PRelu::evaluate(TensorVector& outputs, const TensorVector& inputs) const { OV_OP_SCOPE(v0_PRelu_evaluate); - OPENVINO_SUPPRESS_DEPRECATED_START - OPENVINO_ASSERT(ngraph::validate_host_tensor_vector(outputs, 1) && ngraph::validate_host_tensor_vector(inputs, 2)); - OPENVINO_SUPPRESS_DEPRECATED_END - return prelu::evaluate_prelu(inputs[0], inputs[1], outputs[0]); + OPENVINO_ASSERT(outputs.size() == 1); + OPENVINO_ASSERT(inputs.size() == 2); + + auto& out = outputs[0]; + const auto& arg_shape = inputs[0].get_shape(); + out.set_shape(arg_shape); + + using namespace ov::element; + return IfTypeOf::apply(inputs[0].get_element_type(), + inputs[0], + inputs[1], + out, + arg_shape, + inputs[1].get_shape()); } -bool ov::op::v0::PRelu::has_evaluate() const { +bool PRelu::has_evaluate() const { OV_OP_SCOPE(v0_PRelu_has_evaluate); switch (get_input_element_type(0)) { - case ngraph::element::i8: - case ngraph::element::bf16: - case ngraph::element::f16: - case ngraph::element::f32: + case element::bf16: + case element::f16: + case element::f32: + case element::i8: return true; default: - break; + return false; } - return false; } +} // namespace v0 +} // namespace op +} // namespace ov