Migrate PReLU operator to new API (#21098)
This commit is contained in:
parent f78d0950d6
commit e72afbec7e
@@ -24,15 +24,10 @@ public:
     /// \param slope Multipliers for negative values
     PRelu(const Output<Node>& data, const Output<Node>& slope);
 
-    bool visit_attributes(AttributeVisitor& visitor) override;
-
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 
     void validate_and_infer_types() override;
 
-    OPENVINO_SUPPRESS_DEPRECATED_START
-    bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
-    OPENVINO_SUPPRESS_DEPRECATED_END
+    bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
     bool has_evaluate() const override;
 };
 } // namespace v0
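The header change above is the core of the migration: PRelu::evaluate now takes ov::TensorVector instead of the deprecated HostTensorVector. A minimal sketch of driving the new entry point directly (the shapes and values below are made up for illustration, and building against the public OpenVINO C++ headers is assumed):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>

#include "openvino/op/parameter.hpp"
#include "openvino/op/prelu.hpp"
#include "openvino/runtime/tensor.hpp"

int main() {
    // A tiny PRelu node: data is {1, 3, 2}, slope holds one multiplier per channel.
    const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 3, 2});
    const auto slope = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
    const auto prelu = std::make_shared<ov::op::v0::PRelu>(data, slope);

    // Inputs and outputs are plain ov::Tensor objects now, not HostTensor wrappers.
    ov::Tensor data_t(ov::element::f32, ov::Shape{1, 3, 2});
    ov::Tensor slope_t(ov::element::f32, ov::Shape{3});
    const float data_vals[] = {-1.f, 2.f, -3.f, 4.f, -5.f, 6.f};
    const float slope_vals[] = {0.1f, 0.2f, 0.3f};
    std::copy(data_vals, data_vals + 6, data_t.data<float>());
    std::copy(slope_vals, slope_vals + 3, slope_t.data<float>());

    ov::TensorVector outputs{ov::Tensor(ov::element::f32, ov::Shape{1, 3, 2})};
    if (prelu->evaluate(outputs, {data_t, slope_t})) {
        for (std::size_t i = 0; i < outputs[0].get_size(); ++i) {
            std::cout << outputs[0].data<float>()[i] << ' ';  // -0.1 2 -0.6 4 -1.5 6
        }
        std::cout << '\n';
    }
}
```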
@@ -13,6 +13,14 @@
 
 namespace ov {
 namespace reference {
+namespace func {
+// Usage of custom function instead of lambda, gives smaller binary size.
+template <class T>
+T prelu(const T x, const T y) {
+    return x < T(0) ? x * y : x;
+}
+} // namespace func
+
 template <typename T>
 void prelu(const T* arg, const T* slope, T* out, const Shape& arg_shape, const Shape& slope_shape) {
     Shape slope_shape_tmp = slope_shape;
@@ -22,9 +30,7 @@ void prelu(const T* arg, const T* slope, T* out, const Shape& arg_shape, const S
         channel_slope_shape[channel_dim_idx] = slope_shape[0];
         std::swap(slope_shape_tmp, channel_slope_shape);
     }
-    autobroadcast_binop(arg, slope, out, arg_shape, slope_shape_tmp, op::AutoBroadcastType::NUMPY, [](T x, T y) -> T {
-        return x < T(0) ? T(x * y) : x;
-    });
+    autobroadcast_binop(arg, slope, out, arg_shape, slope_shape_tmp, op::AutoBroadcastType::NUMPY, func::prelu<T>);
 }
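For reference, the kernel's behaviour is unchanged: out[i] = arg[i] < 0 ? arg[i] * slope[c] : arg[i], with the slope broadcast NumPy-style (a 1-D slope whose length matches the channel dimension is first reshaped to {1, C, 1, ...}). The func::prelu free function only replaces the lambda so instantiations of autobroadcast_binop can share one function, which is what the binary-size comment refers to. A standalone sketch of the per-channel case, independent of the OpenVINO headers:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: PReLU over a {1, C, W}-shaped buffer with one slope value per channel,
// mirroring what reference::prelu computes after it reshapes a 1-D slope to {1, C, 1}.
std::vector<float> prelu_per_channel(const std::vector<float>& x, const std::vector<float>& slope, std::size_t channels) {
    const std::size_t per_channel = x.size() / channels;
    std::vector<float> out(x.size());
    for (std::size_t c = 0; c < channels; ++c) {
        for (std::size_t i = 0; i < per_channel; ++i) {
            const float v = x[c * per_channel + i];
            out[c * per_channel + i] = v < 0.f ? v * slope[c] : v;  // x < 0 ? x * slope : x
        }
    }
    return out;
}

int main() {
    // Two channels of two values each; slopes 0.5 and 0.25.
    const auto out = prelu_per_channel({-1.f, 2.f, -3.f, 4.f}, {0.5f, 0.25f}, 2);
    for (const float v : out) {
        std::cout << v << ' ';  // -0.5 2 -0.75 4
    }
    std::cout << '\n';
}
```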
@@ -2,85 +2,79 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include "ngraph/op/prelu.hpp"
-
-#include <ngraph/validation_util.hpp>
+#include "openvino/op/prelu.hpp"
 
+#include "element_visitor.hpp"
 #include "itt.hpp"
 #include "openvino/reference/prelu.hpp"
 
-using namespace std;
+namespace ov {
+namespace op {
+namespace prelu {
 
-ov::op::v0::PRelu::PRelu() : Op() {}
+struct Evaluate : element::NoAction<bool> {
+    using element::NoAction<bool>::visit;
 
-ov::op::v0::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope) : Op({data, slope}) {
+    template <element::Type_t ET, class T = fundamental_type_for<ET>>
+    static result_type visit(const Tensor& arg,
+                             const Tensor& slope,
+                             Tensor& out,
+                             const Shape& arg_shape,
+                             const Shape& slope_shape) {
+        reference::prelu(arg.data<const T>(), slope.data<const T>(), out.data<T>(), arg_shape, slope_shape);
+        return true;
+    }
+};
+} // namespace prelu
+
+namespace v0 {
+
+PRelu::PRelu() : Op() {}
+
+PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope) : Op({data, slope}) {
     constructor_validate_and_infer_types();
 }
 
-bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor) {
-    OV_OP_SCOPE(v0_PRelu_visit_attributes);
-    return true;
-}
-
-void ngraph::op::v0::PRelu::validate_and_infer_types() {
+void PRelu::validate_and_infer_types() {
     set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
 }
 
-shared_ptr<ov::Node> ov::op::v0::PRelu::clone_with_new_inputs(const OutputVector& new_args) const {
+std::shared_ptr<ov::Node> PRelu::clone_with_new_inputs(const OutputVector& new_args) const {
     OV_OP_SCOPE(v0_PRelu_clone_with_new_inputs);
     OPENVINO_ASSERT(new_args.size() == 2, "Incorrect number of new arguments");
-    return make_shared<PRelu>(new_args.at(0), new_args.at(1));
+    return std::make_shared<PRelu>(new_args.at(0), new_args.at(1));
 }
 
-OPENVINO_SUPPRESS_DEPRECATED_START
-namespace prelu {
-namespace {
-template <ov::element::Type_t ET>
-bool evaluate(const ngraph::HostTensorPtr& arg, const ngraph::HostTensorPtr& slope, const ngraph::HostTensorPtr& out) {
-    ov::reference::prelu(arg->get_data_ptr<ET>(),
-                         slope->get_data_ptr<ET>(),
-                         out->get_data_ptr<ET>(),
-                         arg->get_shape(),
-                         slope->get_shape());
-    return true;
-}
-
-bool evaluate_prelu(const ngraph::HostTensorPtr& arg,
-                    const ngraph::HostTensorPtr& slope,
-                    const ngraph::HostTensorPtr& out) {
-    bool rc = true;
-    switch (arg->get_element_type()) {
-        OPENVINO_TYPE_CASE(evaluate_prelu, i8, arg, slope, out);
-        OPENVINO_TYPE_CASE(evaluate_prelu, bf16, arg, slope, out);
-        OPENVINO_TYPE_CASE(evaluate_prelu, f16, arg, slope, out);
-        OPENVINO_TYPE_CASE(evaluate_prelu, f32, arg, slope, out);
-    default:
-        rc = false;
-        break;
-    }
-    return rc;
-}
-} // namespace
-} // namespace prelu
-
-bool ov::op::v0::PRelu::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
+bool PRelu::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
     OV_OP_SCOPE(v0_PRelu_evaluate);
-    OPENVINO_SUPPRESS_DEPRECATED_START
-    OPENVINO_ASSERT(ngraph::validate_host_tensor_vector(outputs, 1) && ngraph::validate_host_tensor_vector(inputs, 2));
-    OPENVINO_SUPPRESS_DEPRECATED_END
-    return prelu::evaluate_prelu(inputs[0], inputs[1], outputs[0]);
+    OPENVINO_ASSERT(outputs.size() == 1);
+    OPENVINO_ASSERT(inputs.size() == 2);
+
+    auto& out = outputs[0];
+    const auto& arg_shape = inputs[0].get_shape();
+    out.set_shape(arg_shape);
+
+    using namespace ov::element;
+    return IfTypeOf<bf16, f16, f32, i8>::apply<prelu::Evaluate>(inputs[0].get_element_type(),
+                                                                inputs[0],
+                                                                inputs[1],
+                                                                out,
+                                                                arg_shape,
+                                                                inputs[1].get_shape());
 }
 
-bool ov::op::v0::PRelu::has_evaluate() const {
+bool PRelu::has_evaluate() const {
     OV_OP_SCOPE(v0_PRelu_has_evaluate);
     switch (get_input_element_type(0)) {
-    case ngraph::element::i8:
-    case ngraph::element::bf16:
-    case ngraph::element::f16:
-    case ngraph::element::f32:
+    case element::bf16:
+    case element::f16:
+    case element::f32:
+    case element::i8:
         return true;
     default:
-        break;
+        return false;
     }
-    return false;
 }
+} // namespace v0
+} // namespace op
+} // namespace ov
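The rewritten evaluate leans on the element_visitor.hpp helpers shown above: prelu::Evaluate inherits the failing default from element::NoAction<bool>, and IfTypeOf<bf16, f16, f32, i8>::apply<prelu::Evaluate>(...) selects the visit<ET> overload matching the runtime element type, replacing the old OPENVINO_TYPE_CASE switch. A simplified, self-contained sketch of that dispatch idea (not the actual OpenVINO implementation; the enum and helper names here are invented for illustration):

```cpp
#include <iostream>
#include <utility>

// Invented stand-in for ov::element::Type_t, reduced to two types.
enum class Type { f32, i8 };

// Default action when no listed type matches the runtime type
// (plays the role of element::NoAction<bool>).
struct NoAction {
    static bool visit() { return false; }
};

// Recursive compile-time list: compare the runtime type against each entry and
// forward to Visitor::visit<T> on the first match (the IfTypeOf idea).
template <Type... Types>
struct IfTypeOf;

template <>
struct IfTypeOf<> {
    template <class Visitor, class... Args>
    static bool apply(Type, Args&&...) {
        return Visitor::visit();  // no match: fall back to the NoAction default
    }
};

template <Type Head, Type... Tail>
struct IfTypeOf<Head, Tail...> {
    template <class Visitor, class... Args>
    static bool apply(Type t, Args&&... args) {
        return t == Head ? Visitor::template visit<Head>(std::forward<Args>(args)...)
                         : IfTypeOf<Tail...>::template apply<Visitor>(t, std::forward<Args>(args)...);
    }
};

// A PRelu-like visitor: one templated visit per supported element type.
struct Evaluate : NoAction {
    using NoAction::visit;

    template <Type ET>
    static bool visit(float x, float slope) {
        std::cout << (x < 0.f ? x * slope : x) << '\n';  // prints -0.2 for the call below
        return true;
    }
};

int main() {
    // Dispatches to Evaluate::visit<Type::f32>; an unlisted type would return false.
    return IfTypeOf<Type::f32, Type::i8>::apply<Evaluate>(Type::f32, -2.0f, 0.1f) ? 0 : 1;
}
```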