Repalce HostTensor with ov::Tensor

This commit is contained in:
Tomasz Jankowski 2023-11-23 10:25:19 +01:00
parent afbf8461ed
commit 19c2da736e
2 changed files with 54 additions and 69 deletions

View File

@ -60,9 +60,7 @@ public:
const AutoBroadcastSpec& get_autob() const override {
return m_auto_broadcast;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_upper(TensorVector& outputs) const override;
bool evaluate_lower(TensorVector& outputs) const override;
bool has_evaluate() const override;

View File

@ -7,16 +7,43 @@
#include <memory>
#include "bound_evaluate.hpp"
#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/validation_util.hpp" // tbr
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/reference/select.hpp"
#include "select_shape_inference.hpp"
using namespace ngraph;
namespace ov {
namespace op {
namespace select {
struct Evaluate : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t DATA_ET,
class DT = fundamental_type_for<DATA_ET>,
class BT = fundamental_type_for<element::Type_t::boolean>>
static result_type visit(const Tensor& cond_input,
const Tensor& then_input,
const Tensor& else_input,
Tensor& output,
const Shape& cond_shape,
const Shape& then_shape,
const Shape& else_shape,
const AutoBroadcastSpec& auto_broadcast) {
using namespace ov::element;
reference::select(cond_input.data<const BT>(),
then_input.data<const DT>(),
else_input.data<const DT>(),
output.data<DT>(),
cond_shape,
then_shape,
else_shape,
auto_broadcast);
return true;
}
};
} // namespace select
namespace v1 {
Select::Select(const Output<Node>& arg0,
const Output<Node>& arg1,
@ -61,70 +88,30 @@ bool Select::visit_attributes(AttributeVisitor& visitor) {
return true;
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace detail {
namespace {
template <element::Type_t ET>
bool evaluate(const HostTensorVector& output_values,
const HostTensorVector& input_values,
const AutoBroadcastSpec& autob) {
using T = typename element_type_traits<ET>::value_type;
const auto& in_cond = input_values[0];
const auto& in_then = input_values[1];
const auto& in_else = input_values[2];
const auto& out = output_values[0];
reference::select<T>(in_cond->get_data_ptr<char>(),
in_then->get_data_ptr<T>(),
in_else->get_data_ptr<T>(),
out->get_data_ptr<T>(),
in_cond->get_shape(),
in_then->get_shape(),
in_else->get_shape(),
autob);
return true;
}
bool evaluate_select(const HostTensorVector& output_values,
const HostTensorVector& input_values,
const AutoBroadcastSpec& autob,
const element::Type_t& et) {
bool rc = false;
switch (et) {
OPENVINO_TYPE_CASE(evaluate_select, i8, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, i16, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, i32, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, i64, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, u8, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, u16, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, u32, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, u64, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, bf16, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, f16, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, f32, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, f64, output_values, input_values, autob);
OPENVINO_TYPE_CASE(evaluate_select, boolean, output_values, input_values, autob);
default:
rc = false;
break;
}
return rc;
}
} // namespace
} // namespace detail
bool Select::evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const {
bool Select::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_Select_evaluate);
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(validate_host_tensor_vector(input_values, 3));
OPENVINO_ASSERT(validate_host_tensor_vector(output_values, 1));
OPENVINO_SUPPRESS_DEPRECATED_END
const auto autob = get_auto_broadcast();
return detail::evaluate_select(output_values, input_values, autob, output_values[0]->get_element_type());
OPENVINO_ASSERT(inputs.size() == 3);
OPENVINO_ASSERT(outputs.size() == 1);
const auto& cond_input = inputs[0];
const auto& then_input = inputs[1];
const auto& else_input = inputs[2];
const auto output_shape = shape_infer(this, ov::util::get_tensors_partial_shapes(inputs)).front().to_shape();
auto& output = outputs[0];
output.set_shape(output_shape);
using namespace ov::element;
return IfTypeOf<boolean, bf16, f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64>::apply<select::Evaluate>(
then_input.get_element_type(),
cond_input,
then_input,
else_input,
output,
cond_input.get_shape(),
then_input.get_shape(),
else_input.get_shape(),
m_auto_broadcast);
}
bool Select::evaluate_lower(TensorVector& output_values) const {