[core]Migrate Sqrt operator to new API (#20632)
* Migrate Sqrt operator to new API * Remove 'visit_attributes' is same as base
This commit is contained in:
parent
f9b76024aa
commit
261e570a81
@ -35,11 +35,8 @@ public:
|
||||
Sqrt(const Output<Node>& arg);
|
||||
Sqrt() = default;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
};
|
||||
} // namespace v0
|
||||
|
@ -6,21 +6,33 @@
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
|
||||
#include "openvino/reference/utils/type_util.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace reference {
|
||||
template <typename T>
|
||||
typename std::enable_if<!std::is_integral<T>::value>::type sqrt(const T* arg, T* out, size_t count) {
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
out[i] = std::sqrt(arg[i]);
|
||||
namespace func {
|
||||
template <class T, typename std::enable_if<ov::is_floating_point<T>()>::type* = nullptr>
|
||||
T sqrt(const T in) {
|
||||
return std::sqrt(in);
|
||||
}
|
||||
|
||||
template <class T, typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
|
||||
T sqrt(const T in) {
|
||||
return static_cast<T>(std::round(std::sqrt(in)));
|
||||
}
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value>::type sqrt(const T* arg, T* out, size_t count) {
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
out[i] = static_cast<T>(std::round(std::sqrt(arg[i])));
|
||||
}
|
||||
} // namespace func
|
||||
|
||||
/**
|
||||
* @brief Reference implementation of Sqrt operator.
|
||||
*
|
||||
* @param arg Pointer to input data.
|
||||
* @param out Pointer to output data.
|
||||
* @param count Number of elements in input buffer.
|
||||
*/
|
||||
template <class T>
|
||||
void sqrt(const T* arg, T* out, const size_t count) {
|
||||
std::transform(arg, arg + count, out, func::sqrt<T>);
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace ov
|
||||
|
@ -2,80 +2,66 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/sqrt.hpp"
|
||||
#include "openvino/op/sqrt.hpp"
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "openvino/reference/sqrt.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace sqrt {
|
||||
struct Evaluate : element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
op::Sqrt::Sqrt(const Output<Node>& arg) : UnaryElementwiseArithmetic(arg) {
|
||||
template <element::Type_t ET, class T = fundamental_type_for<ET>>
|
||||
static result_type visit(const Tensor& arg0, Tensor& out, const size_t count) {
|
||||
reference::sqrt(arg0.data<const T>(), out.data<T>(), count);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace sqrt
|
||||
|
||||
namespace v0 {
|
||||
Sqrt::Sqrt(const Output<Node>& arg) : UnaryElementwiseArithmetic(arg) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool ngraph::op::v0::Sqrt::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v0_Sqrt_visit_attrinutes);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::Sqrt::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Sqrt::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v0_Sqrt_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<Sqrt>(new_args.at(0));
|
||||
return std::make_shared<Sqrt>(new_args.at(0));
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace sqrtop {
|
||||
namespace {
|
||||
template <element::Type_t ET>
|
||||
inline bool evaluate(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count) {
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
ov::reference::sqrt<T>(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_sqrt(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count) {
|
||||
bool rc = true;
|
||||
out->set_unary(arg0);
|
||||
switch (arg0->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_sqrt, i32, arg0, out, count);
|
||||
OPENVINO_TYPE_CASE(evaluate_sqrt, i64, arg0, out, count);
|
||||
OPENVINO_TYPE_CASE(evaluate_sqrt, u32, arg0, out, count);
|
||||
OPENVINO_TYPE_CASE(evaluate_sqrt, u64, arg0, out, count);
|
||||
OPENVINO_TYPE_CASE(evaluate_sqrt, f16, arg0, out, count);
|
||||
OPENVINO_TYPE_CASE(evaluate_sqrt, f32, arg0, out, count);
|
||||
OPENVINO_TYPE_CASE(evaluate_sqrt, f64, arg0, out, count);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace sqrtop
|
||||
|
||||
bool op::Sqrt::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool Sqrt::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v0_Sqrt_evaluate);
|
||||
return sqrtop::evaluate_sqrt(inputs[0], outputs[0], shape_size(inputs[0]->get_shape()));
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
OPENVINO_ASSERT(inputs.size() == 1);
|
||||
|
||||
const auto& in_shape = inputs[0].get_shape();
|
||||
outputs[0].set_shape(in_shape);
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<f16, f32, f64, i32, i64, u32, u64>::apply<sqrt::Evaluate>(inputs[0].get_element_type(),
|
||||
inputs[0],
|
||||
outputs[0],
|
||||
shape_size(in_shape));
|
||||
}
|
||||
|
||||
bool op::Sqrt::has_evaluate() const {
|
||||
bool Sqrt::has_evaluate() const {
|
||||
OV_OP_SCOPE(v0_Sqrt_has_evaluate);
|
||||
switch (get_input_element_type(0)) {
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
case ngraph::element::u32:
|
||||
case ngraph::element::u64:
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
case ngraph::element::f64:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
case element::f64:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace v0
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user