[core] Migrate Range operator to new API (#21259)

* Migrate Range operator to new API
- remove legacy functions and duplicated shape inference
- minor changes to the Range reference implementation

* Move tensor access after validation
Authored by Pawel Raasz on 2023-11-30 08:27:08 +01:00; committed by GitHub
parent eec370a88b
commit 8ee8f4e112
3 changed files with 133 additions and 326 deletions
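
The heart of the migration is replacing the deprecated `HostTensorVector` evaluate entry point with the `ov::Tensor`-based one, for both Range versions, as visible in the header diff below:

// Removed (deprecated legacy API):
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
// Added (current API):
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;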

File: openvino/op/range.hpp (Range operator class declarations)

@@ -32,9 +32,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
void set_output_type(element::Type output_type) {
m_output_type = output_type;
@@ -75,9 +73,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v0
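
For the caller's side of the new API, a minimal sketch (hypothetical test-style code, not part of this commit) that builds a v4::Range node and drives the `Tensor`-based evaluate:

#include <memory>
#include <vector>

#include "openvino/op/constant.hpp"
#include "openvino/op/range.hpp"
#include "openvino/runtime/tensor.hpp"

int main() {
    using namespace ov;
    // Scalar start/stop/step inputs, as required by the Range spec.
    auto start = std::make_shared<op::v0::Constant>(element::f32, Shape{}, std::vector<float>{2.f});
    auto stop = std::make_shared<op::v0::Constant>(element::f32, Shape{}, std::vector<float>{10.f});
    auto step = std::make_shared<op::v0::Constant>(element::f32, Shape{}, std::vector<float>{3.f});
    auto node = std::make_shared<op::v4::Range>(start, stop, step, element::f32);

    float sv = 2.f, ev = 10.f, dv = 3.f;
    TensorVector inputs{Tensor(element::f32, Shape{}, &sv),
                        Tensor(element::f32, Shape{}, &ev),
                        Tensor(element::f32, Shape{}, &dv)};
    // evaluate() calls set_shape() on the output; an owned tensor that can be
    // resized (supported in recent OpenVINO releases) is assumed here.
    TensorVector outputs{Tensor(element::f32, Shape{0})};

    return node->evaluate(outputs, inputs) ? 0 : 1;  // outputs[0] holds {2, 5, 8}
}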

File: openvino/reference/range.hpp (Range reference implementation)

@@ -4,37 +4,49 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <type_traits>
#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/float16.hpp"
#include "openvino/reference/utils/coordinate_transform.hpp"
#include "openvino/reference/utils/type_util.hpp"
namespace ov {
namespace reference {
// Return type is `void`, only enabled if `T` is a built-in FP
// type, or OpenVINO's `bfloat16` or `float16` type.
/**
* @brief Reference implementation for the Range operator (floating-point types).
*
* @param start Start value.
* @param step Step value; the difference between consecutive output values.
* @param num_elem Number of elements to generate.
* @param out Pointer to the output data.
*/
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, bfloat16>::value ||
std::is_same<T, float16>::value>::type
range(const T* start, const T* step, const size_t& num_elem, T* out) {
for (size_t i = 0; i < num_elem; i++) {
out[i] = *start + (static_cast<T>(i) * (*step));
typename std::enable_if<ov::is_floating_point<T>()>::type range(const T start,
const T step,
const size_t num_elem,
T* out) {
for (size_t i = 0; i < num_elem; ++i) {
out[i] = start + (static_cast<T>(i) * (step));
}
}
// Return type is `void`; only enabled if `T` is an integral type.
/**
* @brief Reference implementation for the Range operator (integral types).
*
* @param start Start value.
* @param step Step value; the difference between consecutive output values.
* @param num_elem Number of elements to generate.
* @param out Pointer to the output data.
*/
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type range(const T* start,
const T* step,
const size_t& num_elem,
typename std::enable_if<std::is_integral<T>::value>::type range(const T start,
const T step,
const size_t num_elem,
T* out) {
T val = *start;
for (size_t i = 0; i < num_elem; i++) {
auto val = start;
for (size_t i = 0; i < num_elem; ++i, val += step) {
out[i] = val;
val += *step;
}
}
} // namespace reference
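
A quick usage sketch of the updated reference kernels (hypothetical driver code; note the new pass-by-value `start`/`step` parameters replacing the old pointers):

#include <cstdint>
#include <vector>

#include "openvino/reference/range.hpp"

int main() {
    // Floating-point overload: out[i] = start + i * step.
    std::vector<float> f_out(4);
    ov::reference::range(1.5f, 0.5f, f_out.size(), f_out.data());
    // f_out == {1.5, 2.0, 2.5, 3.0}

    // Integral overload: running accumulator; negative steps work too.
    std::vector<int32_t> i_out(3);
    ov::reference::range<int32_t>(10, -2, i_out.size(), i_out.data());
    // i_out == {10, 8, 6}
    return 0;
}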

File: Range operator source file (implementations of v0::Range and v4::Range)

@@ -2,64 +2,72 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/range.hpp"
#include <algorithm>
#include <ngraph/validation_util.hpp>
#include "openvino/op/range.hpp"
#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "openvino/reference/range.hpp"
#include "range_shape_inference.hpp"
using namespace std;
using namespace ngraph;
namespace ov {
namespace op {
namespace range {
//
// The code in the following three functions is a bit awkward, to work around some compiler
// warnings and the need to support our custom float16/bfloat16 type:
//
// (1) We can't use STL things like isnan, because our custom float16/bfloat16 types don't always
// support them.
// (2) We check whether (x - x) == (x - x) to check for "is_finite".
// (3) We have to break (x - x) out into a temporary because otherwise the compiler throws a
// warning about == on floats.
// (4) We check <0 || >0 to check for != 0, because otherwise the compiler throws a warning about
// == on floats.
//
template <typename T>
static typename std::enable_if<std::is_integral<T>::value, bool>::type check_value(T value) {
// Nothing to check for integral types.
return true;
#define RANGE_ET_LIST bf16, f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64
struct Evaluate : element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const double start, const double step, const size_t count, Tensor& out) {
reference::range(static_cast<T>(start), static_cast<T>(step), count, out.data<T>());
return true;
}
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const Tensor& start, const Tensor& step, const size_t count, Tensor& out) {
reference::range(*start.data<const T>(), *step.data<const T>(), count, out.data<T>());
return true;
}
};
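// Illustrative aside (not part of the commit): IfTypeOf<...>::apply<Action> is
// OpenVINO's element-type dispatcher. It calls Action::visit<ET>(args...) for
// the first listed element type matching the runtime type, and otherwise falls
// back to the NoAction<bool> default (false). A simplified standalone analogue
// (hypothetical, for illustration only; needs <utility> for std::forward):
//
//   enum class Et { f32, i32, u8 };
//
//   template <Et... List>                 // empty list: no match -> default
//   struct Dispatch {
//       template <class Action, class... Args>
//       static bool apply(Et, Args&&...) { return false; }
//   };
//
//   template <Et Head, Et... Tail>        // try Head, then recurse on Tail
//   struct Dispatch<Head, Tail...> {
//       template <class Action, class... Args>
//       static bool apply(Et et, Args&&... args) {
//           return et == Head
//                      ? Action::template visit<Head>(std::forward<Args>(args)...)
//                      : Dispatch<Tail...>::template apply<Action>(et, std::forward<Args>(args)...);
//       }
//   };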
namespace {
bool is_input_valid_et(const element::Type& et) {
switch (et) {
case element::bf16:
case element::f16:
case element::f32:
case element::f64:
case element::i8:
case element::i16:
case element::i32:
case element::i64:
case element::u8:
case element::u16:
case element::u32:
case element::u64:
return true;
default:
return false;
}
}
} // namespace
} // namespace range
template <typename T>
static typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
std::is_same<T, bfloat16>::value,
bool>::type
check_value(T value) {
T value_minus_value = value - value;
return value == value && value_minus_value == value_minus_value;
}
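// Illustrative aside (not part of the commit): concrete values for the
// (x - x) == (x - x) trick described above:
//   finite x:  x - x == 0.0, and 0.0 == 0.0 holds     -> check_value == true
//   x = +inf:  inf - inf is NaN, NaN == NaN is false  -> check_value == false
//   x = NaN :  x == x is already false                -> check_value == false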
op::v4::Range::Range(const Output<Node>& start,
const Output<Node>& stop,
const Output<Node>& step,
element::Type output_type)
namespace v4 {
Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step, element::Type output_type)
: Op({start, stop, step}),
m_output_type(output_type) {
constructor_validate_and_infer_types();
}
bool ngraph::op::v4::Range::visit_attributes(AttributeVisitor& visitor) {
bool Range::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v4_Range_visit_attributes);
visitor.on_attribute("output_type", m_output_type);
return true;
}
void op::v4::Range::validate_and_infer_types() {
void Range::validate_and_infer_types() {
OV_OP_SCOPE(v4_Range_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
m_output_type.is_integral_number() || m_output_type.is_real(),
@@ -87,258 +95,55 @@ void op::v4::Range::validate_and_infer_types() {
for (size_t i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
const auto result_shapes = op::v4::shape_infer(this, input_shapes);
const auto result_shapes = shape_infer(this, input_shapes);
set_output_type(0, m_output_type, result_shapes[0]);
}
shared_ptr<Node> op::v4::Range::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Range::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v4_Range_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v4::Range>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type);
return std::make_shared<Range>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type);
}
template <typename T>
bool get_casted_value(const HostTensorPtr& tensor, T* val) {
switch (tensor->get_element_type()) {
case element::Type_t::bf16:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::bf16>());
break;
case element::Type_t::f16:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::f16>());
break;
case element::Type_t::f32:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::f32>());
break;
case element::Type_t::i8:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::i8>());
break;
case element::Type_t::i32:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::i32>());
break;
case element::Type_t::i64:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::i64>());
break;
case element::Type_t::u8:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::u8>());
break;
case element::Type_t::u32:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::u32>());
break;
case element::Type_t::u64:
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::u64>());
break;
default:
return false;
}
return true;
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace rangeop {
namespace {
template <element::Type_t ET>
bool evaluate(const HostTensorPtr& out,
const HostTensorPtr& start,
const HostTensorPtr& stop,
const HostTensorPtr& step,
int version) {
using T = typename element_type_traits<ET>::value_type;
double start_val;
double stop_val;
double step_val;
if (version < 4) {
start_val = static_cast<double>(*start->get_data_ptr<ET>());
stop_val = static_cast<double>(*stop->get_data_ptr<ET>());
step_val = static_cast<double>(*step->get_data_ptr<ET>());
if (!(check_value(start_val) && check_value(stop_val) && check_value(step_val) &&
(step_val != static_cast<T>(0)))) {
return false;
}
} else {
if (!(get_casted_value<double>(start, &start_val) && get_casted_value<double>(stop, &stop_val) &&
get_casted_value<double>(step, &step_val))) {
return false;
}
}
int64_t out_size = 0;
if (ov::element::Type(ET).is_integral_number()) {
start_val = std::trunc(start_val);
stop_val = std::trunc(stop_val);
step_val = std::trunc(step_val);
}
int64_t steps = static_cast<int64_t>(std::ceil(double(stop_val - start_val) / step_val));
if (steps > 0) {
out_size = steps;
}
ov::Shape out_shape = ov::Shape({static_cast<size_t>(out_size)});
out->set_shape(out_shape);
T start_val_casted = static_cast<T>(start_val);
T step_val_casted = static_cast<T>(step_val);
ov::reference::range(&start_val_casted, &step_val_casted, shape_size(out_shape), out->get_data_ptr<ET>());
return true;
}
bool evaluate_range(const HostTensorPtr& out,
const HostTensorPtr& start,
const HostTensorPtr& stop,
const HostTensorPtr& step,
const element::Type& output_type,
int version) {
bool rc = true;
switch (output_type) {
OPENVINO_TYPE_CASE(evaluate_range, bf16, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, f16, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, f32, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, f64, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, i8, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, i16, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, i32, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, i64, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, u8, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, u16, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, u32, out, start, stop, step, version);
OPENVINO_TYPE_CASE(evaluate_range, u64, out, start, stop, step, version);
default:
rc = false;
break;
}
return rc;
}
} // namespace
} // namespace rangeop
bool op::v4::Range::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool Range::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v4_Range_evaluate);
HostTensorPtr out = outputs[0];
HostTensorPtr start = inputs[0];
HostTensorPtr stop = inputs[1];
HostTensorPtr step = inputs[2];
return rangeop::evaluate_range(out, start, stop, step, m_output_type, 4);
OPENVINO_ASSERT(outputs.size() == 1);
const auto out_shape =
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs))[0].to_shape();
auto& out = outputs[0];
out.set_shape(out_shape);
const auto start = get_tensor_data_as<double>(inputs[0])[0];
const auto step = get_tensor_data_as<double>(inputs[2])[0];
using namespace ov::element;
return IfTypeOf<RANGE_ET_LIST>::apply<range::Evaluate>(out.get_element_type(),
start,
step,
shape_size(out_shape),
out);
}
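// Illustrative aside (not part of the commit): make_tensor_accessor(inputs)
// hands the input tensors to the shared shape_infer(), which reads the actual
// start/stop/step values to produce a static output shape. This replaces the
// per-evaluate shape logic removed above.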
bool op::v4::Range::has_evaluate() const {
bool Range::has_evaluate() const {
OV_OP_SCOPE(v4_Range_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::bf16:
case ngraph::element::f16:
case ngraph::element::f32:
case ngraph::element::f64:
case ngraph::element::i8:
case ngraph::element::i16:
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u8:
case ngraph::element::u16:
case ngraph::element::u32:
case ngraph::element::u64:
return true;
default:
break;
}
return false;
return range::is_input_valid_et(get_input_element_type(0));
}
} // namespace v4
op::v0::Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step)
: Op({start, stop, step}) {
namespace v0 {
Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step) : Op({start, stop, step}) {
constructor_validate_and_infer_types();
}
template <typename T>
static void check_start(const op::v0::Range* node, T start) {
NODE_VALIDATION_CHECK(node, check_value(start), "'start' cannot be nan or infinite.");
}
template <typename T>
void check_stop(const op::v0::Range* node, T stop) {
NODE_VALIDATION_CHECK(node, check_value(stop), "'stop' cannot be nan or infinite.");
}
template <typename T>
void static check_step(const op::v0::Range* node, T step) {
NODE_VALIDATION_CHECK(node,
check_value(step) && ((step > static_cast<T>(0) || step < static_cast<T>(0))),
"'step' cannot be zero, nan, or infinite.");
}
template <typename T>
static typename std::enable_if<std::is_integral<T>::value, T>::type adjust_for_step_and_sign(T span, T step) {
return ceil_div(span < 0 ? -static_cast<typename std::make_signed<T>::type>(span) : span,
step < 0 ? -static_cast<typename std::make_signed<T>::type>(step) : step);
}
template <typename T>
static typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
std::is_same<T, bfloat16>::value,
T>::type
adjust_for_step_and_sign(T span, T step) {
return ceil(fabs(span) / fabs(step));
}
template <typename T>
static ov::PartialShape infer_output_shape(const op::v0::Range* node, const element::Type& /* et */) {
OPENVINO_SUPPRESS_DEPRECATED_START
auto const_start = get_constant_from_source(node->input_value(0));
auto const_stop = get_constant_from_source(node->input_value(1));
auto const_step = get_constant_from_source(node->input_value(2));
OPENVINO_SUPPRESS_DEPRECATED_END
T start = static_cast<T>(0);
T stop = static_cast<T>(0);
T step = static_cast<T>(0);
if (const_start != nullptr) {
std::vector<T> start_val = const_start->get_vector<T>();
NODE_VALIDATION_CHECK(node, start_val.size() == 1);
start = start_val[0];
check_start<T>(node, start);
}
if (const_stop != nullptr) {
std::vector<T> stop_val = const_stop->get_vector<T>();
NODE_VALIDATION_CHECK(node, stop_val.size() == 1);
stop = stop_val[0];
check_stop<T>(node, stop);
}
if (const_step != nullptr) {
std::vector<T> step_val = const_step->get_vector<T>();
NODE_VALIDATION_CHECK(node, step_val.size() == 1);
step = step_val[0];
check_step<T>(node, step);
}
ov::PartialShape result{ov::PartialShape::dynamic(1)};
if (const_start != nullptr && const_stop != nullptr && const_step != nullptr) {
T span;
if (step > static_cast<T>(0) && start >= stop) {
span = static_cast<T>(0);
} else if (step < static_cast<T>(0) && start <= stop) {
span = static_cast<T>(0);
} else {
span = stop - start;
}
T strided = adjust_for_step_and_sign<T>(span, step);
result = ov::PartialShape{Dimension(static_cast<int64_t>(strided))};
}
return result;
}
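// Illustrative aside (not part of the commit): a worked example of the legacy
// shape computation above, for start=10, stop=4, step=-2 (integral T):
//   step < 0 and start > stop, so span = stop - start = -6
//   adjust_for_step_and_sign(-6, -2) = ceil_div(6, 2) = 3
// giving 3 output elements: {10, 8, 6}. The shared shape_infer() from
// range_shape_inference.hpp now performs this computation in one place.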
bool ngraph::op::v0::Range::visit_attributes(AttributeVisitor& visitor) {
bool Range::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v0_Range_visit_attributes);
return true;
}
void op::v0::Range::validate_and_infer_types() {
void Range::validate_and_infer_types() {
OV_OP_SCOPE(v0_Range_validate_and_infer_types);
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
@@ -369,48 +174,42 @@ void op::v0::Range::validate_and_infer_types() {
for (size_t i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
const auto result_shapes = op::v0::shape_infer(this, input_shapes);
const auto result_shapes = shape_infer(this, input_shapes);
set_output_type(0, result_et, result_shapes[0]);
}
}
shared_ptr<Node> op::v0::Range::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Range::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v0_Range_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<Range>(new_args.at(0), new_args.at(1), new_args.at(2));
return std::make_shared<Range>(new_args.at(0), new_args.at(1), new_args.at(2));
}
template <element::Type_t ET, typename T>
void positive_range(T start_val, T stop_val, T step_val) {}
bool op::v0::Range::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool Range::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v0_Range_evaluate);
HostTensorPtr out = outputs[0];
HostTensorPtr start = inputs[0];
HostTensorPtr stop = inputs[1];
HostTensorPtr step = inputs[2];
return rangeop::evaluate_range(out, start, stop, step, start->get_element_type(), 0);
OPENVINO_ASSERT(outputs.size() == 1);
const auto out_shape =
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs))[0].to_shape();
const auto& start = inputs[0];
const auto& step = inputs[2];
auto& out = outputs[0];
out.set_shape(out_shape);
using namespace ov::element;
return IfTypeOf<RANGE_ET_LIST>::apply<range::Evaluate>(out.get_element_type(),
start,
step,
shape_size(out_shape),
out);
}
bool op::v0::Range::has_evaluate() const {
bool Range::has_evaluate() const {
OV_OP_SCOPE(v0_Range_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::bf16:
case ngraph::element::f16:
case ngraph::element::f32:
case ngraph::element::f64:
case ngraph::element::i8:
case ngraph::element::i16:
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u8:
case ngraph::element::u16:
case ngraph::element::u32:
case ngraph::element::u64:
return true;
default:
break;
}
return false;
return range::is_input_valid_et(get_input_element_type(0));
}
} // namespace v0
} // namespace op
} // namespace ov