[core]Migrate Gathers operators to new API (#21390)
* Migrate Gather operators to new API * Remove redundant code form reference * Use IF_TYPE_OF macro * Remove unused include * Use common utils in gather base * Fix normalize after merge issues
This commit is contained in:
parent
e4c38e3afd
commit
635f5d373d
@ -27,10 +27,7 @@ public:
|
||||
void validate_and_infer_types() override;
|
||||
virtual int64_t get_axis() const;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool evaluate_lower(TensorVector& outputs) const override;
|
||||
bool evaluate_upper(TensorVector& outputs) const override;
|
||||
bool evaluate_label(TensorLabelVector& output_labels) const override;
|
||||
|
@ -28,7 +28,6 @@ void gather(const T* const data,
|
||||
|
||||
int64_t batch_data_mul = shape_size(span(data_shape).subspan(batch_dims));
|
||||
int64_t batch_out_mul = shape_size(span(out_shape).subspan(batch_dims));
|
||||
int64_t batch_indices_mul = shape_size(span(indices_shape).subspan(batch_dims));
|
||||
|
||||
int64_t axis_size = data_shape[axis];
|
||||
int64_t data_offset, out_offset, idx;
|
||||
@ -40,7 +39,7 @@ void gather(const T* const data,
|
||||
data_offset = batch_data_mul * batch + inner_size * axis_size * outer_idx;
|
||||
out_offset = batch_out_mul * batch + indices_size * inner_size * outer_idx;
|
||||
for (int64_t i = 0; i < indices_size; i++) {
|
||||
idx = indices[i + batch_indices_mul * batch];
|
||||
idx = indices[i + indices_size * batch];
|
||||
if (idx < 0)
|
||||
idx += axis_size;
|
||||
// for out of bound values have to be filled with zeros
|
||||
@ -48,9 +47,8 @@ void gather(const T* const data,
|
||||
continue;
|
||||
|
||||
const auto src_begin = std::next(data, data_offset + inner_size * idx);
|
||||
const auto src_end = std::next(src_begin, inner_size);
|
||||
const auto out_ptr = std::next(out, out_offset + inner_size * i);
|
||||
std::copy(src_begin, src_end, out_ptr);
|
||||
std::copy_n(src_begin, inner_size, out_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,39 +5,37 @@
|
||||
#include "openvino/op/gather.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "validation_util.hpp"
|
||||
|
||||
namespace ov {
|
||||
|
||||
op::v1::Gather::Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axes)
|
||||
namespace op {
|
||||
namespace v1 {
|
||||
Gather::Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axes)
|
||||
: GatherBase(params, indices, axes) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
int64_t op::v1::Gather::get_axis() const {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
if (!get_constant_from_source(input_value(2))) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
return AXIS_NOT_SET_VALUE;
|
||||
}
|
||||
return GatherBase::get_axis();
|
||||
int64_t Gather::get_axis() const {
|
||||
return ov::util::get_constant_from_source(input_value(2)) ? GatherBase::get_axis() : AXIS_NOT_SET_VALUE;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v1_Gather_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<v1::Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
} // namespace v1
|
||||
|
||||
op::v7::Gather::Gather(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
namespace v7 {
|
||||
Gather::Gather(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
: GatherBase(data, indices, axis, batch_dims) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v7::Gather::validate_and_infer_types() {
|
||||
void Gather::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v7_Gather_validate_and_infer_types);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_element_type(1).is_integral_number(),
|
||||
@ -47,37 +45,39 @@ void op::v7::Gather::validate_and_infer_types() {
|
||||
get_input_element_type(2).is_integral_number(),
|
||||
"Axis element type must be of an integral number type.");
|
||||
|
||||
op::util::GatherBase::validate_and_infer_types();
|
||||
util::GatherBase::validate_and_infer_types();
|
||||
}
|
||||
|
||||
int64_t op::v7::Gather::get_batch_dims() const {
|
||||
int64_t Gather::get_batch_dims() const {
|
||||
if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
|
||||
return m_batch_dims + get_input_partial_shape(1).rank().get_length();
|
||||
else
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
bool op::v7::Gather::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool Gather::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v7_Gather_visit_attributes);
|
||||
visitor.on_attribute("batch_dims", m_batch_dims);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v7_Gather_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
|
||||
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
|
||||
}
|
||||
} // namespace v7
|
||||
|
||||
op::v8::Gather::Gather(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
namespace v8 {
|
||||
Gather::Gather(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
: GatherBase(data, indices, axis, batch_dims) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v8::Gather::validate_and_infer_types() {
|
||||
void Gather::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v8_Gather_validate_and_infer_types);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_element_type(1).is_integral_number(),
|
||||
@ -90,22 +90,24 @@ void op::v8::Gather::validate_and_infer_types() {
|
||||
op::util::GatherBase::validate_and_infer_types();
|
||||
}
|
||||
|
||||
int64_t op::v8::Gather::get_batch_dims() const {
|
||||
int64_t Gather::get_batch_dims() const {
|
||||
if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
|
||||
return m_batch_dims + get_input_partial_shape(1).rank().get_length();
|
||||
else
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
bool op::v8::Gather::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool Gather::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v8_Gather_visit_attributes);
|
||||
visitor.on_attribute("batch_dims", m_batch_dims);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v8::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v8_Gather_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<v8::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
|
||||
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
|
||||
}
|
||||
} // namespace v8
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
@ -5,145 +5,40 @@
|
||||
#include "openvino/op/util/gather_base.hpp"
|
||||
|
||||
#include "bound_evaluate.hpp"
|
||||
#include "element_visitor.hpp"
|
||||
#include "gather_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/reference/gather.hpp"
|
||||
#include "validation_util.hpp"
|
||||
|
||||
ov::op::util::GatherBase::GatherBase(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
: Op({data, indices, axis}),
|
||||
m_batch_dims(batch_dims) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void ov::op::util::GatherBase::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(util_GatherBase_validate_and_infer_types);
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
|
||||
const auto& data_pshape = get_input_partial_shape(0);
|
||||
const auto& indices_pshape = get_input_partial_shape(1);
|
||||
const auto& axis_pshape = get_input_partial_shape(2);
|
||||
std::vector<PartialShape> input_shapes = {data_pshape, indices_pshape, axis_pshape};
|
||||
const auto output_shapes = shape_infer(this, input_shapes);
|
||||
set_output_type(0, data_type, output_shapes[0]);
|
||||
}
|
||||
|
||||
int64_t ov::op::util::GatherBase::get_axis() const {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto& const_op = get_constant_from_source(input_value(2));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
OPENVINO_ASSERT(const_op, "axis value is not set");
|
||||
|
||||
int64_t axis = const_op->cast_vector<int64_t>()[0];
|
||||
if (axis < 0) {
|
||||
const auto& data_rank = get_input_partial_shape(0).rank();
|
||||
if (data_rank.is_static()) {
|
||||
axis += data_rank.get_length();
|
||||
}
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
|
||||
const int64_t& ov::op::util::GatherBase::get_batch_dims() const {
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
void ov::op::util::GatherBase::set_batch_dims(int64_t batch_dims) {
|
||||
m_batch_dims = batch_dims;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace gather {
|
||||
namespace {
|
||||
template <ov::element::Type_t ET>
|
||||
bool evaluate(const ngraph::HostTensorPtr& arg0,
|
||||
const ngraph::HostTensorPtr& arg1,
|
||||
const ngraph::HostTensorPtr& out,
|
||||
int64_t axis,
|
||||
int64_t batch_dims) {
|
||||
using T = typename ov::element_type_traits<ET>::value_type;
|
||||
ov::Shape params_shape = arg0->get_shape();
|
||||
ov::Shape indices_shape = arg1->get_shape();
|
||||
ov::Shape out_shape(params_shape.size() + indices_shape.size() - 1 - batch_dims);
|
||||
int64_t i = 0;
|
||||
for (; i < axis; i++) {
|
||||
out_shape[i] = params_shape[i];
|
||||
}
|
||||
for (int64_t j = batch_dims; j < static_cast<int64_t>(indices_shape.size()); i++, j++) {
|
||||
out_shape[i] = indices_shape[j];
|
||||
}
|
||||
for (int64_t j = axis + 1; j < static_cast<int64_t>(params_shape.size()); i++, j++) {
|
||||
out_shape[i] = params_shape[j];
|
||||
}
|
||||
|
||||
out->set_shape(out_shape);
|
||||
Shape out_shape_infer(const Shape& data_shape, const Shape& indices_shape, int64_t axis, int64_t batch_dims) {
|
||||
Shape out_shape;
|
||||
out_shape.reserve(data_shape.size() + indices_shape.size() - 1 - batch_dims);
|
||||
auto out_dim_inserter = std::copy_n(data_shape.begin(), axis, std::back_inserter(out_shape));
|
||||
out_dim_inserter = std::copy(indices_shape.begin() + batch_dims, indices_shape.end(), out_dim_inserter);
|
||||
std::copy(std::next(data_shape.begin(), axis + 1), data_shape.end(), out_dim_inserter);
|
||||
|
||||
if (arg1->get_element_type() == ov::element::i64) {
|
||||
ov::reference::gather<T, int64_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int64_t>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
out->get_shape(),
|
||||
axis,
|
||||
batch_dims);
|
||||
} else if (arg1->get_element_type() == ov::element::i32) {
|
||||
ov::reference::gather<T, int32_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int32_t>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
out->get_shape(),
|
||||
axis,
|
||||
batch_dims);
|
||||
} else {
|
||||
OPENVINO_THROW("Unexpected type ", arg1->get_element_type().c_type_string(), " for Gather evaluate method.");
|
||||
}
|
||||
|
||||
return true;
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
bool evaluate_gather(const ngraph::HostTensorPtr& arg0,
|
||||
const ngraph::HostTensorPtr& arg1,
|
||||
const ngraph::HostTensorPtr& out,
|
||||
int64_t axis,
|
||||
int64_t batch_dims = 0) {
|
||||
bool rc = true;
|
||||
|
||||
using ov::element::Type_t;
|
||||
switch (out->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, i32, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, i64, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, i8, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, u8, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, u32, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, u64, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, f16, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, f32, arg0, arg1, out, axis, batch_dims);
|
||||
OPENVINO_TYPE_CASE(evaluate_gather, boolean, arg0, arg1, out, axis, batch_dims);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
bool cf_gather_with_subgraph(ov::OutputVector& output_values,
|
||||
const ov::OutputVector& input_values,
|
||||
const ov::PartialShape& gather_ps) {
|
||||
bool cf_gather_with_subgraph(OutputVector& output_values,
|
||||
const OutputVector& input_values,
|
||||
const PartialShape& gather_ps) {
|
||||
if (gather_ps.is_dynamic() || input_values.size() != 3) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto concat = std::dynamic_pointer_cast<ov::op::v0::Concat>(input_values[0].get_node_shared_ptr());
|
||||
const auto indices = std::dynamic_pointer_cast<ov::op::v0::Constant>(input_values[1].get_node_shared_ptr());
|
||||
const auto axis = std::dynamic_pointer_cast<ov::op::v0::Constant>(input_values[2].get_node_shared_ptr());
|
||||
const auto concat = std::dynamic_pointer_cast<v0::Concat>(input_values[0].get_node_shared_ptr());
|
||||
const auto indices = std::dynamic_pointer_cast<v0::Constant>(input_values[1].get_node_shared_ptr());
|
||||
const auto axis = std::dynamic_pointer_cast<v0::Constant>(input_values[2].get_node_shared_ptr());
|
||||
|
||||
if (!concat || !indices || !axis) {
|
||||
return false;
|
||||
@ -169,8 +64,8 @@ bool cf_gather_with_subgraph(ov::OutputVector& output_values,
|
||||
}
|
||||
|
||||
const int64_t rank = concat->get_shape()[0];
|
||||
const int64_t raw_index = indices->cast_vector<int64_t>()[0];
|
||||
const int64_t positive_index = raw_index < 0 ? rank + raw_index : raw_index;
|
||||
const auto raw_index = indices->cast_vector<int64_t>()[0];
|
||||
const auto positive_index = ov::util::normalize(raw_index, rank);
|
||||
OPENVINO_ASSERT(positive_index >= 0 && positive_index < rank);
|
||||
|
||||
// gather takes exactly one element out of the Concat output
|
||||
@ -179,88 +74,164 @@ bool cf_gather_with_subgraph(ov::OutputVector& output_values,
|
||||
auto gathered = gathered_concat_input;
|
||||
if (indices_shape.empty()) {
|
||||
// gathering a scalar
|
||||
const auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
|
||||
gathered = std::make_shared<ov::op::v0::Squeeze>(gathered_concat_input, axis_const);
|
||||
const auto axis_const = v0::Constant::create(element::i64, Shape{1}, {0});
|
||||
gathered = std::make_shared<v0::Squeeze>(gathered_concat_input, axis_const);
|
||||
}
|
||||
|
||||
output_values[0] = gathered;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool have_indices_and_axis_bound_set(const util::GatherBase* const gather) {
|
||||
return ov::have_node_inputs_bounds_set(gather, 1, 2);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
struct Evaluate : element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t DATA_ET, class DT = fundamental_type_for<DATA_ET>>
|
||||
static result_type visit(const Tensor& data,
|
||||
const Tensor& indices,
|
||||
Tensor& out,
|
||||
const Shape& data_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& out_shape,
|
||||
const int64_t axis,
|
||||
const int64_t batch_dims) {
|
||||
using namespace ov::element;
|
||||
return IF_TYPE_OF(util_GatherBase_indices_type,
|
||||
OV_PP_ET_LIST(i32, i64),
|
||||
EvaluateByIndicesType,
|
||||
indices.get_element_type(),
|
||||
data.data<const DT>(),
|
||||
indices,
|
||||
out.data<DT>(),
|
||||
data_shape,
|
||||
indices_shape,
|
||||
out_shape,
|
||||
axis,
|
||||
batch_dims);
|
||||
}
|
||||
|
||||
private:
|
||||
struct EvaluateByIndicesType : element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t INDICES_ET, class DT, class IT = fundamental_type_for<INDICES_ET>>
|
||||
static result_type visit(const DT* const data,
|
||||
const Tensor& indices,
|
||||
DT* const output,
|
||||
const Shape& data_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& out_shape,
|
||||
const int64_t axis,
|
||||
const int64_t batch_dims) {
|
||||
reference::gather(data,
|
||||
indices.data<const IT>(),
|
||||
output,
|
||||
data_shape,
|
||||
indices_shape,
|
||||
out_shape,
|
||||
axis,
|
||||
batch_dims);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
};
|
||||
} // namespace gather
|
||||
|
||||
bool ov::op::util::GatherBase::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
namespace util {
|
||||
|
||||
GatherBase::GatherBase(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
: Op({data, indices, axis}),
|
||||
m_batch_dims(batch_dims) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void GatherBase::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(util_GatherBase_validate_and_infer_types);
|
||||
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
const auto output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this));
|
||||
|
||||
set_output_type(0, data_type, output_shapes[0]);
|
||||
}
|
||||
|
||||
int64_t GatherBase::get_axis() const {
|
||||
const auto& const_op = ov::util::get_constant_from_source(input_value(2));
|
||||
OPENVINO_ASSERT(const_op, "axis value is not set");
|
||||
|
||||
const auto axis = const_op->cast_vector<int64_t>()[0];
|
||||
if (axis < 0 && get_input_partial_shape(0).rank().is_static()) {
|
||||
return axis + get_input_partial_shape(0).rank().get_length();
|
||||
} else {
|
||||
return axis;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t& GatherBase::get_batch_dims() const {
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
void GatherBase::set_batch_dims(int64_t batch_dims) {
|
||||
m_batch_dims = batch_dims;
|
||||
}
|
||||
|
||||
bool GatherBase::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(util_GatherBase_evaluate);
|
||||
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
OPENVINO_ASSERT(inputs.size() == 3);
|
||||
|
||||
OPENVINO_ASSERT(inputs[2].get_element_type().is_integral_number(), "axis must be of integral data type.");
|
||||
|
||||
const auto& data = inputs[0];
|
||||
const auto& data_shape = data.get_shape();
|
||||
const auto& indices = inputs[1];
|
||||
const auto& indices_shape = indices.get_shape();
|
||||
|
||||
const auto axis = ov::util::normalize(get_tensor_data_as<int64_t>(inputs[2])[0], data_shape.size());
|
||||
const auto batch_dims = ov::util::normalize(m_batch_dims, indices_shape.size());
|
||||
|
||||
const auto out_shape = gather::out_shape_infer(data_shape, indices_shape, axis, batch_dims);
|
||||
auto& output = outputs[0];
|
||||
output.set_shape(out_shape);
|
||||
|
||||
using namespace ov::element;
|
||||
return IF_TYPE_OF(util_GatherBase_evaluate,
|
||||
OV_PP_ET_LIST(boolean, f16, f32, i8, i32, i64, u8, u32, u64),
|
||||
gather::Evaluate,
|
||||
data.get_element_type(),
|
||||
data,
|
||||
indices,
|
||||
output,
|
||||
data_shape,
|
||||
indices_shape,
|
||||
out_shape,
|
||||
axis,
|
||||
batch_dims);
|
||||
}
|
||||
|
||||
bool GatherBase::evaluate_lower(TensorVector& output_values) const {
|
||||
return gather::have_indices_and_axis_bound_set(this) && default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool GatherBase::evaluate_upper(TensorVector& output_values) const {
|
||||
return gather::have_indices_and_axis_bound_set(this) && default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool GatherBase::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
OPENVINO_ASSERT(ngraph::validate_host_tensor_vector(inputs, 3));
|
||||
OPENVINO_ASSERT(ngraph::validate_host_tensor_vector(outputs, 1));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
int64_t axis = 0;
|
||||
switch (inputs[2]->get_element_type()) {
|
||||
case element::Type_t::i32:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0];
|
||||
break;
|
||||
case element::Type_t::i64:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0];
|
||||
break;
|
||||
case element::Type_t::i8:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::i8>()[0];
|
||||
break;
|
||||
case element::Type_t::i16:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::i16>()[0];
|
||||
break;
|
||||
case element::Type_t::u8:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::u8>()[0];
|
||||
break;
|
||||
case element::Type_t::u16:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::u16>()[0];
|
||||
break;
|
||||
case element::Type_t::u32:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::u32>()[0];
|
||||
break;
|
||||
case element::Type_t::u64:
|
||||
axis = inputs[2]->get_data_ptr<element::Type_t::u64>()[0];
|
||||
break;
|
||||
default:
|
||||
OPENVINO_THROW("axis must be of integral data type.");
|
||||
}
|
||||
|
||||
if (axis < 0) {
|
||||
const auto input_rank = inputs[0]->get_shape().size();
|
||||
axis += input_rank;
|
||||
}
|
||||
|
||||
int64_t batch_dims = m_batch_dims;
|
||||
if (batch_dims < 0) {
|
||||
const auto indices_rank = inputs[1]->get_shape().size();
|
||||
batch_dims += indices_rank;
|
||||
}
|
||||
|
||||
return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis, batch_dims);
|
||||
}
|
||||
|
||||
bool ov::op::util::GatherBase::evaluate_lower(ov::TensorVector& output_values) const {
|
||||
if (!get_input_tensor(1).has_and_set_bound() || !get_input_tensor(2).has_and_set_bound())
|
||||
return false;
|
||||
return default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool ov::op::util::GatherBase::evaluate_upper(ov::TensorVector& output_values) const {
|
||||
if (!get_input_tensor(1).has_and_set_bound() || !get_input_tensor(2).has_and_set_bound())
|
||||
return false;
|
||||
return default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool ov::op::util::GatherBase::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
if (!get_input_tensor(1).has_and_set_bound() || !get_input_tensor(2).has_and_set_bound())
|
||||
return false;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
return default_label_evaluator(this, output_labels);
|
||||
return gather::have_indices_and_axis_bound_set(this) && default_label_evaluator(this, output_labels);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
bool ov::op::util::GatherBase::constant_fold(OutputVector& output_values, const OutputVector& input_values) {
|
||||
bool GatherBase::constant_fold(OutputVector& output_values, const OutputVector& input_values) {
|
||||
// try the regular constant folding just for the Gather node
|
||||
if (Node::constant_fold(output_values, input_values)) {
|
||||
return true;
|
||||
@ -268,3 +239,6 @@ bool ov::op::util::GatherBase::constant_fold(OutputVector& output_values, const
|
||||
return gather::cf_gather_with_subgraph(output_values, input_values, get_output_partial_shape(0));
|
||||
}
|
||||
}
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user