[core]Migrate Range operator to new API (#21259)
* Migrate Range operator to new API - remove legacy function an duplicated shape inference - Minor change range reference implementation * Move accessing tensors after validation
This commit is contained in:
parent
eec370a88b
commit
8ee8f4e112
@ -32,9 +32,7 @@ public:
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
void set_output_type(element::Type output_type) {
|
||||
m_output_type = output_type;
|
||||
@ -75,9 +73,7 @@ public:
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
};
|
||||
} // namespace v0
|
||||
|
@ -4,37 +4,49 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <type_traits>
|
||||
|
||||
#include "openvino/core/type/bfloat16.hpp"
|
||||
#include "openvino/core/type/float16.hpp"
|
||||
#include "openvino/reference/utils/coordinate_transform.hpp"
|
||||
#include "openvino/reference/utils/type_util.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace reference {
|
||||
// Return type is `void`, only enabled if `T` is a built-in FP
|
||||
// type, or OpenVINO's `bfloat16` or `float16` type.
|
||||
|
||||
/**
|
||||
* @brief Reference implementation for Range operator (floating-point types).
|
||||
*
|
||||
* @param start Start value.
|
||||
* @param step Step is difference value for consecutive values.
|
||||
* @param num_elem Number of elements to generate
|
||||
* @param out Pointer to output data.
|
||||
*/
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, bfloat16>::value ||
|
||||
std::is_same<T, float16>::value>::type
|
||||
range(const T* start, const T* step, const size_t& num_elem, T* out) {
|
||||
for (size_t i = 0; i < num_elem; i++) {
|
||||
out[i] = *start + (static_cast<T>(i) * (*step));
|
||||
typename std::enable_if<ov::is_floating_point<T>()>::type range(const T start,
|
||||
const T step,
|
||||
const size_t num_elem,
|
||||
T* out) {
|
||||
for (size_t i = 0; i < num_elem; ++i) {
|
||||
out[i] = start + (static_cast<T>(i) * (step));
|
||||
}
|
||||
}
|
||||
|
||||
// Return type is `void`, only enabled if `T` is `is_integral`.
|
||||
/**
|
||||
* @brief Reference implementation for Range operator (integral types).
|
||||
*
|
||||
* @param start Start value.
|
||||
* @param step Step is difference value for consecutive values.
|
||||
* @param num_elem Number of elements to generate
|
||||
* @param out Pointer to output data.
|
||||
*/
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value>::type range(const T* start,
|
||||
const T* step,
|
||||
const size_t& num_elem,
|
||||
typename std::enable_if<std::is_integral<T>::value>::type range(const T start,
|
||||
const T step,
|
||||
const size_t num_elem,
|
||||
T* out) {
|
||||
T val = *start;
|
||||
|
||||
for (size_t i = 0; i < num_elem; i++) {
|
||||
auto val = start;
|
||||
for (size_t i = 0; i < num_elem; ++i, val += step) {
|
||||
out[i] = val;
|
||||
val += *step;
|
||||
}
|
||||
}
|
||||
} // namespace reference
|
||||
|
@ -2,64 +2,72 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/range.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include "openvino/op/range.hpp"
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/type/element_type_traits.hpp"
|
||||
#include "openvino/reference/range.hpp"
|
||||
#include "range_shape_inference.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace range {
|
||||
|
||||
//
|
||||
// The code in the following three functions is a bit awkward, to work around some compiler
|
||||
// warnings and the need to support our custom float16/bfloat16 type:
|
||||
//
|
||||
// (1) We can't use STL things like isnan, because our custom float16/bfloat16 types don't always
|
||||
// support them.
|
||||
// (2) We check whether (x - x) == (x - x) to check for "is_finite".
|
||||
// (3) We have to break (x - x) out into a temporary because otherwise the compiler throws a
|
||||
// warning about == on floats.
|
||||
// (4) We check <0 || >0 to check for != 0, because otherwise the compiler throws a warning about
|
||||
// == on floats.
|
||||
//
|
||||
template <typename T>
|
||||
static typename std::enable_if<std::is_integral<T>::value, bool>::type check_value(T value) {
|
||||
// Nothing to check for integral types.
|
||||
return true;
|
||||
#define RANGE_ET_LIST bf16, f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64
|
||||
|
||||
struct Evaluate : element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t ET, class T = fundamental_type_for<ET>>
|
||||
static result_type visit(const double start, const double step, const size_t count, Tensor& out) {
|
||||
reference::range(static_cast<T>(start), static_cast<T>(step), count, out.data<T>());
|
||||
return true;
|
||||
}
|
||||
|
||||
template <element::Type_t ET, class T = fundamental_type_for<ET>>
|
||||
static result_type visit(const Tensor& start, const Tensor& step, const size_t count, Tensor& out) {
|
||||
reference::range(*start.data<const T>(), *step.data<const T>(), count, out.data<T>());
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
bool is_input_valid_et(const element::Type& et) {
|
||||
switch (et) {
|
||||
case element::bf16:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
case element::f64:
|
||||
case element::i8:
|
||||
case element::i16:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u8:
|
||||
case element::u16:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace range
|
||||
|
||||
template <typename T>
|
||||
static typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
|
||||
std::is_same<T, bfloat16>::value,
|
||||
bool>::type
|
||||
check_value(T value) {
|
||||
T value_minus_value = value - value;
|
||||
return value == value && value_minus_value == value_minus_value;
|
||||
}
|
||||
|
||||
op::v4::Range::Range(const Output<Node>& start,
|
||||
const Output<Node>& stop,
|
||||
const Output<Node>& step,
|
||||
element::Type output_type)
|
||||
namespace v4 {
|
||||
Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step, element::Type output_type)
|
||||
: Op({start, stop, step}),
|
||||
m_output_type(output_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool ngraph::op::v4::Range::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool Range::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v4_Range_visit_attributes);
|
||||
visitor.on_attribute("output_type", m_output_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v4::Range::validate_and_infer_types() {
|
||||
void Range::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v4_Range_validate_and_infer_types);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
m_output_type.is_integral_number() || m_output_type.is_real(),
|
||||
@ -87,258 +95,55 @@ void op::v4::Range::validate_and_infer_types() {
|
||||
for (size_t i = 0; i < get_input_size(); i++)
|
||||
input_shapes.push_back(get_input_partial_shape(i));
|
||||
|
||||
const auto result_shapes = op::v4::shape_infer(this, input_shapes);
|
||||
const auto result_shapes = shape_infer(this, input_shapes);
|
||||
|
||||
set_output_type(0, m_output_type, result_shapes[0]);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v4::Range::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Range::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v4_Range_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v4::Range>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type);
|
||||
return std::make_shared<Range>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool get_casted_value(const HostTensorPtr& tensor, T* val) {
|
||||
switch (tensor->get_element_type()) {
|
||||
case element::Type_t::bf16:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::bf16>());
|
||||
break;
|
||||
case element::Type_t::f16:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::f16>());
|
||||
break;
|
||||
case element::Type_t::f32:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::f32>());
|
||||
break;
|
||||
case element::Type_t::i8:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::i8>());
|
||||
break;
|
||||
case element::Type_t::i32:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::i32>());
|
||||
break;
|
||||
case element::Type_t::i64:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::i64>());
|
||||
break;
|
||||
case element::Type_t::u8:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::u8>());
|
||||
break;
|
||||
case element::Type_t::u32:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::u32>());
|
||||
break;
|
||||
case element::Type_t::u64:
|
||||
*val = static_cast<T>(*tensor->get_data_ptr<element::Type_t::u64>());
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace rangeop {
|
||||
namespace {
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const HostTensorPtr& out,
|
||||
const HostTensorPtr& start,
|
||||
const HostTensorPtr& stop,
|
||||
const HostTensorPtr& step,
|
||||
int version) {
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
double start_val;
|
||||
double stop_val;
|
||||
double step_val;
|
||||
if (version < 4) {
|
||||
start_val = static_cast<double>(*start->get_data_ptr<ET>());
|
||||
stop_val = static_cast<double>(*stop->get_data_ptr<ET>());
|
||||
step_val = static_cast<double>(*step->get_data_ptr<ET>());
|
||||
if (!(check_value(start_val) && check_value(stop_val) && check_value(step_val) &&
|
||||
(step_val != static_cast<T>(0)))) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!(get_casted_value<double>(start, &start_val) && get_casted_value<double>(stop, &stop_val) &&
|
||||
get_casted_value<double>(step, &step_val))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t out_size = 0;
|
||||
|
||||
if (ov::element::Type(ET).is_integral_number()) {
|
||||
start_val = std::trunc(start_val);
|
||||
stop_val = std::trunc(stop_val);
|
||||
step_val = std::trunc(step_val);
|
||||
}
|
||||
|
||||
int64_t steps = static_cast<int64_t>(std::ceil(double(stop_val - start_val) / step_val));
|
||||
if (steps > 0) {
|
||||
out_size = steps;
|
||||
}
|
||||
ov::Shape out_shape = ov::Shape({static_cast<size_t>(out_size)});
|
||||
out->set_shape(out_shape);
|
||||
|
||||
T start_val_casted = static_cast<T>(start_val);
|
||||
T step_val_casted = static_cast<T>(step_val);
|
||||
ov::reference::range(&start_val_casted, &step_val_casted, shape_size(out_shape), out->get_data_ptr<ET>());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_range(const HostTensorPtr& out,
|
||||
const HostTensorPtr& start,
|
||||
const HostTensorPtr& stop,
|
||||
const HostTensorPtr& step,
|
||||
const element::Type& output_type,
|
||||
int version) {
|
||||
bool rc = true;
|
||||
switch (output_type) {
|
||||
OPENVINO_TYPE_CASE(evaluate_range, bf16, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, f16, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, f32, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, f64, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, i8, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, i16, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, i32, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, i64, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, u8, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, u16, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, u32, out, start, stop, step, version);
|
||||
OPENVINO_TYPE_CASE(evaluate_range, u64, out, start, stop, step, version);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace rangeop
|
||||
|
||||
bool op::v4::Range::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool Range::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v4_Range_evaluate);
|
||||
HostTensorPtr out = outputs[0];
|
||||
HostTensorPtr start = inputs[0];
|
||||
HostTensorPtr stop = inputs[1];
|
||||
HostTensorPtr step = inputs[2];
|
||||
return rangeop::evaluate_range(out, start, stop, step, m_output_type, 4);
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
|
||||
const auto out_shape =
|
||||
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs))[0].to_shape();
|
||||
auto& out = outputs[0];
|
||||
out.set_shape(out_shape);
|
||||
|
||||
const auto start = get_tensor_data_as<double>(inputs[0])[0];
|
||||
const auto step = get_tensor_data_as<double>(inputs[2])[0];
|
||||
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<RANGE_ET_LIST>::apply<range::Evaluate>(out.get_element_type(),
|
||||
start,
|
||||
step,
|
||||
shape_size(out_shape),
|
||||
out);
|
||||
}
|
||||
|
||||
bool op::v4::Range::has_evaluate() const {
|
||||
bool Range::has_evaluate() const {
|
||||
OV_OP_SCOPE(v4_Range_has_evaluate);
|
||||
switch (get_input_element_type(0)) {
|
||||
case ngraph::element::bf16:
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
case ngraph::element::f64:
|
||||
case ngraph::element::i8:
|
||||
case ngraph::element::i16:
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
case ngraph::element::u8:
|
||||
case ngraph::element::u16:
|
||||
case ngraph::element::u32:
|
||||
case ngraph::element::u64:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
return range::is_input_valid_et(get_input_element_type(0));
|
||||
}
|
||||
} // namespace v4
|
||||
|
||||
op::v0::Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step)
|
||||
: Op({start, stop, step}) {
|
||||
namespace v0 {
|
||||
|
||||
Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step) : Op({start, stop, step}) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void check_start(const op::v0::Range* node, T start) {
|
||||
NODE_VALIDATION_CHECK(node, check_value(start), "'start' cannot be nan or infinite.");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void check_stop(const op::v0::Range* node, T stop) {
|
||||
NODE_VALIDATION_CHECK(node, check_value(stop), "'stop' cannot be nan or infinite.");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void static check_step(const op::v0::Range* node, T step) {
|
||||
NODE_VALIDATION_CHECK(node,
|
||||
check_value(step) && ((step > static_cast<T>(0) || step < static_cast<T>(0))),
|
||||
"'step' cannot be zero, nan, or infinite.");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static typename std::enable_if<std::is_integral<T>::value, T>::type adjust_for_step_and_sign(T span, T step) {
|
||||
return ceil_div(span < 0 ? -static_cast<typename std::make_signed<T>::type>(span) : span,
|
||||
step < 0 ? -static_cast<typename std::make_signed<T>::type>(step) : step);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static typename std::enable_if<std::is_floating_point<T>::value || std::is_same<T, float16>::value ||
|
||||
std::is_same<T, bfloat16>::value,
|
||||
T>::type
|
||||
adjust_for_step_and_sign(T span, T step) {
|
||||
return ceil(fabs(span) / fabs(step));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static ov::PartialShape infer_output_shape(const op::v0::Range* node, const element::Type& /* et */) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
auto const_start = get_constant_from_source(node->input_value(0));
|
||||
auto const_stop = get_constant_from_source(node->input_value(1));
|
||||
auto const_step = get_constant_from_source(node->input_value(2));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
T start = static_cast<T>(0);
|
||||
T stop = static_cast<T>(0);
|
||||
T step = static_cast<T>(0);
|
||||
|
||||
if (const_start != nullptr) {
|
||||
std::vector<T> start_val = const_start->get_vector<T>();
|
||||
NODE_VALIDATION_CHECK(node, start_val.size() == 1);
|
||||
start = start_val[0];
|
||||
check_start<T>(node, start);
|
||||
}
|
||||
|
||||
if (const_stop != nullptr) {
|
||||
std::vector<T> stop_val = const_stop->get_vector<T>();
|
||||
NODE_VALIDATION_CHECK(node, stop_val.size() == 1);
|
||||
stop = stop_val[0];
|
||||
check_stop<T>(node, stop);
|
||||
}
|
||||
|
||||
if (const_step != nullptr) {
|
||||
std::vector<T> step_val = const_step->get_vector<T>();
|
||||
NODE_VALIDATION_CHECK(node, step_val.size() == 1);
|
||||
step = step_val[0];
|
||||
check_step<T>(node, step);
|
||||
}
|
||||
|
||||
ov::PartialShape result{ov::PartialShape::dynamic(1)};
|
||||
|
||||
if (const_start != nullptr && const_stop != nullptr && const_step != nullptr) {
|
||||
T span;
|
||||
|
||||
if (step > static_cast<T>(0) && start >= stop) {
|
||||
span = static_cast<T>(0);
|
||||
} else if (step < static_cast<T>(0) && start <= stop) {
|
||||
span = static_cast<T>(0);
|
||||
} else {
|
||||
span = stop - start;
|
||||
}
|
||||
|
||||
T strided = adjust_for_step_and_sign<T>(span, step);
|
||||
|
||||
result = ov::PartialShape{Dimension(static_cast<int64_t>(strided))};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool ngraph::op::v0::Range::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool Range::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v0_Range_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v0::Range::validate_and_infer_types() {
|
||||
void Range::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v0_Range_validate_and_infer_types);
|
||||
set_input_is_relevant_to_shape(0);
|
||||
set_input_is_relevant_to_shape(1);
|
||||
@ -369,48 +174,42 @@ void op::v0::Range::validate_and_infer_types() {
|
||||
for (size_t i = 0; i < get_input_size(); i++)
|
||||
input_shapes.push_back(get_input_partial_shape(i));
|
||||
|
||||
const auto result_shapes = op::v0::shape_infer(this, input_shapes);
|
||||
const auto result_shapes = shape_infer(this, input_shapes);
|
||||
|
||||
set_output_type(0, result_et, result_shapes[0]);
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v0::Range::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Range::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v0_Range_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<Range>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
return std::make_shared<Range>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
||||
template <element::Type_t ET, typename T>
|
||||
void positive_range(T start_val, T stop_val, T step_val) {}
|
||||
|
||||
bool op::v0::Range::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool Range::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v0_Range_evaluate);
|
||||
HostTensorPtr out = outputs[0];
|
||||
HostTensorPtr start = inputs[0];
|
||||
HostTensorPtr stop = inputs[1];
|
||||
HostTensorPtr step = inputs[2];
|
||||
return rangeop::evaluate_range(out, start, stop, step, start->get_element_type(), 0);
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
|
||||
const auto out_shape =
|
||||
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs))[0].to_shape();
|
||||
const auto& start = inputs[0];
|
||||
const auto& step = inputs[2];
|
||||
|
||||
auto& out = outputs[0];
|
||||
out.set_shape(out_shape);
|
||||
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<RANGE_ET_LIST>::apply<range::Evaluate>(out.get_element_type(),
|
||||
start,
|
||||
step,
|
||||
shape_size(out_shape),
|
||||
out);
|
||||
}
|
||||
|
||||
bool op::v0::Range::has_evaluate() const {
|
||||
bool Range::has_evaluate() const {
|
||||
OV_OP_SCOPE(v0_Range_has_evaluate);
|
||||
switch (get_input_element_type(0)) {
|
||||
case ngraph::element::bf16:
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
case ngraph::element::f64:
|
||||
case ngraph::element::i8:
|
||||
case ngraph::element::i16:
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
case ngraph::element::u8:
|
||||
case ngraph::element::u16:
|
||||
case ngraph::element::u32:
|
||||
case ngraph::element::u64:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
return range::is_input_valid_et(get_input_element_type(0));
|
||||
}
|
||||
} // namespace v0
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user