Migrate the Abs operator to new API (#19763)

This commit is contained in:
Pawel Raasz 2023-09-12 13:26:45 +02:00 committed by GitHub
parent f3d4665f7b
commit 693c6d7a11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 73 deletions

View File

@ -30,9 +30,7 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0

View File

@ -7,19 +7,32 @@
#include <cstddef>
#include <type_traits>
#include "openvino/reference/utils/type_util.hpp"
namespace ov {
namespace reference {
template <typename T, typename std::enable_if<std::is_unsigned<T>::value, bool>::type = true>
void abs(const T* arg, T* out, size_t count) {
std::copy(arg, arg + count, out);
namespace func {
template <class T, typename std::enable_if<std::is_unsigned<T>::value>::type* = nullptr>
constexpr T abs(const T num) {
return num;
}
template <typename T, typename std::enable_if<!std::is_unsigned<T>::value, bool>::type = true>
void abs(const T* arg, T* out, size_t count) {
for (size_t i = 0; i < count; i++) {
// TODO: generic "abs" doesn't work here for some reason.
out[i] = (arg[i] < T(0) ? T(-arg[i]) : arg[i]);
}
template <class T, typename std::enable_if<std::is_signed<T>::value || ov::is_floating_point<T>()>::type* = nullptr>
T abs(const T num) {
return std::abs(num);
}
} // namespace func
/**
* @brief Reference implementation of Abs operator.
*
* @param in Input pointer to data.
* @param out Output pointer to results.
* @param count Number of elements in input buffer.
*/
template <class T>
void abs(const T* in, T* out, const size_t count) {
std::transform(in, std::next(in, count), out, &func::abs<T>);
}
} // namespace reference
} // namespace ov

View File

@ -7,6 +7,7 @@
#include <cmath>
#include <numeric>
#include "openvino/reference/abs.hpp"
#include "openvino/reference/sum.hpp"
#include "openvino/reference/utils/type_util.hpp"
#include "shape_util.hpp"
@ -14,18 +15,6 @@
namespace ov {
namespace reference {
namespace func {
template <class T, typename std::enable_if<std::is_unsigned<T>::value>::type* = nullptr>
constexpr T abs(const T num) {
return num;
}
template <class T, typename std::enable_if<std::is_signed<T>::value || ov::is_floating_point<T>()>::type* = nullptr>
T abs(const T num) {
return std::abs(num);
}
} // namespace func
/**
* @brief Reference implementation of ReduceL1 operator.
*

View File

@ -2,73 +2,67 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/abs.hpp"
#include "openvino/op/abs.hpp"
#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/reference/abs.hpp"
ov::op::v0::Abs::Abs(const Output<Node>& arg) : UnaryElementwiseArithmetic(arg) {
namespace ov {
namespace op {
namespace abs {
struct Evaluate : ov::element::NoAction<bool> {
using ov::element::NoAction<bool>::visit;
template <element::Type_t ET>
static result_type visit(const Tensor& in, Tensor& out, const size_t count) {
using T = typename element_type_traits<ET>::value_type;
reference::abs(in.data<const T>(), out.data<T>(), count);
return true;
}
};
} // namespace abs
namespace v0 {
Abs::Abs(const Output<Node>& arg) : UnaryElementwiseArithmetic(arg) {
constructor_validate_and_infer_types();
}
std::shared_ptr<ov::Node> ov::op::v0::Abs::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Abs::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v0_Abs_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<Abs>(new_args.at(0));
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace absop {
namespace {
template <ov::element::Type_t ET>
inline bool evaluate(const ngraph::HostTensorPtr& arg0, const ngraph::HostTensorPtr& out, const size_t count) {
using T = typename ov::element_type_traits<ET>::value_type;
ov::reference::abs<T>((arg0->get_data_ptr<ET>()), (out->get_data_ptr<ET>()), count);
return true;
}
bool evaluate_abs(const ngraph::HostTensorPtr& arg0, const ngraph::HostTensorPtr& out, const size_t count) {
bool rc = true;
out->set_unary(arg0);
switch (arg0->get_element_type()) {
NGRAPH_TYPE_CASE(evaluate_abs, i32, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_abs, i64, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_abs, u32, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_abs, u64, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_abs, f16, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_abs, f32, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_abs, bf16, arg0, out, count);
default:
rc = false;
break;
}
return rc;
}
} // namespace
} // namespace absop
bool ov::op::v0::Abs::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool Abs::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v0_Abs_evaluate);
return absop::evaluate_abs(inputs[0], outputs[0], shape_size(inputs[0]->get_shape()));
OPENVINO_ASSERT(inputs.size() == 1);
OPENVINO_ASSERT(outputs.size() == 1);
outputs[0].set_shape(inputs[0].get_shape());
using namespace ov::element;
return IfTypeOf<bf16, f16, f32, i32, i64, u32, u64>::apply<abs::Evaluate>(inputs[0].get_element_type(),
inputs[0],
outputs[0],
shape_size(inputs[0].get_shape()));
}
bool ov::op::v0::Abs::has_evaluate() const {
bool Abs::has_evaluate() const {
OV_OP_SCOPE(v0_Abs_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u32:
case ngraph::element::u64:
case ngraph::element::f16:
case ngraph::element::f32:
case ngraph::element::bf16:
case element::bf16:
case element::f16:
case element::f32:
case element::i32:
case element::i64:
case element::u32:
case element::u64:
return true;
default:
break;
return false;
}
return false;
}
} // namespace v0
} // namespace op
} // namespace ov