Repalce HostTensor with ov::Tensor
This commit is contained in:
parent
afbf8461ed
commit
19c2da736e
@ -60,9 +60,7 @@ public:
|
||||
const AutoBroadcastSpec& get_autob() const override {
|
||||
return m_auto_broadcast;
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool evaluate_upper(TensorVector& outputs) const override;
|
||||
bool evaluate_lower(TensorVector& outputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
|
@ -7,16 +7,43 @@
|
||||
#include <memory>
|
||||
|
||||
#include "bound_evaluate.hpp"
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/validation_util.hpp" // tbr
|
||||
#include "openvino/core/attribute_visitor.hpp"
|
||||
#include "openvino/reference/select.hpp"
|
||||
#include "select_shape_inference.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace select {
|
||||
struct Evaluate : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t DATA_ET,
|
||||
class DT = fundamental_type_for<DATA_ET>,
|
||||
class BT = fundamental_type_for<element::Type_t::boolean>>
|
||||
static result_type visit(const Tensor& cond_input,
|
||||
const Tensor& then_input,
|
||||
const Tensor& else_input,
|
||||
Tensor& output,
|
||||
const Shape& cond_shape,
|
||||
const Shape& then_shape,
|
||||
const Shape& else_shape,
|
||||
const AutoBroadcastSpec& auto_broadcast) {
|
||||
using namespace ov::element;
|
||||
reference::select(cond_input.data<const BT>(),
|
||||
then_input.data<const DT>(),
|
||||
else_input.data<const DT>(),
|
||||
output.data<DT>(),
|
||||
cond_shape,
|
||||
then_shape,
|
||||
else_shape,
|
||||
auto_broadcast);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace select
|
||||
|
||||
namespace v1 {
|
||||
Select::Select(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
@ -61,70 +88,30 @@ bool Select::visit_attributes(AttributeVisitor& visitor) {
|
||||
return true;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace detail {
|
||||
namespace {
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const HostTensorVector& output_values,
|
||||
const HostTensorVector& input_values,
|
||||
const AutoBroadcastSpec& autob) {
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
|
||||
const auto& in_cond = input_values[0];
|
||||
const auto& in_then = input_values[1];
|
||||
const auto& in_else = input_values[2];
|
||||
|
||||
const auto& out = output_values[0];
|
||||
|
||||
reference::select<T>(in_cond->get_data_ptr<char>(),
|
||||
in_then->get_data_ptr<T>(),
|
||||
in_else->get_data_ptr<T>(),
|
||||
out->get_data_ptr<T>(),
|
||||
in_cond->get_shape(),
|
||||
in_then->get_shape(),
|
||||
in_else->get_shape(),
|
||||
autob);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_select(const HostTensorVector& output_values,
|
||||
const HostTensorVector& input_values,
|
||||
const AutoBroadcastSpec& autob,
|
||||
const element::Type_t& et) {
|
||||
bool rc = false;
|
||||
|
||||
switch (et) {
|
||||
OPENVINO_TYPE_CASE(evaluate_select, i8, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, i16, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, i32, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, i64, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, u8, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, u16, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, u32, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, u64, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, bf16, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, f16, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, f32, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, f64, output_values, input_values, autob);
|
||||
OPENVINO_TYPE_CASE(evaluate_select, boolean, output_values, input_values, autob);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace detail
|
||||
|
||||
bool Select::evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const {
|
||||
bool Select::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v1_Select_evaluate);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
OPENVINO_ASSERT(validate_host_tensor_vector(input_values, 3));
|
||||
OPENVINO_ASSERT(validate_host_tensor_vector(output_values, 1));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
const auto autob = get_auto_broadcast();
|
||||
return detail::evaluate_select(output_values, input_values, autob, output_values[0]->get_element_type());
|
||||
OPENVINO_ASSERT(inputs.size() == 3);
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
|
||||
const auto& cond_input = inputs[0];
|
||||
const auto& then_input = inputs[1];
|
||||
const auto& else_input = inputs[2];
|
||||
|
||||
const auto output_shape = shape_infer(this, ov::util::get_tensors_partial_shapes(inputs)).front().to_shape();
|
||||
auto& output = outputs[0];
|
||||
output.set_shape(output_shape);
|
||||
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<boolean, bf16, f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64>::apply<select::Evaluate>(
|
||||
then_input.get_element_type(),
|
||||
cond_input,
|
||||
then_input,
|
||||
else_input,
|
||||
output,
|
||||
cond_input.get_shape(),
|
||||
then_input.get_shape(),
|
||||
else_input.get_shape(),
|
||||
m_auto_broadcast);
|
||||
}
|
||||
|
||||
bool Select::evaluate_lower(TensorVector& output_values) const {
|
||||
|
Loading…
Reference in New Issue
Block a user