[core] Migrate ScatterUpdate operator to new API (#21241)

* Drop legacy stuff

* Repalce HostTensor with ov::Tensor

* Use dedicated function to obtain Tensor data
This commit is contained in:
Tomasz Jankowski 2023-11-28 09:34:43 +01:00 committed by GitHub
parent cf58a83094
commit 21201833ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 91 deletions

View File

@ -32,16 +32,11 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& outputs) const override;
bool evaluate_upper(TensorVector& outputs) const override;
bool evaluate_label(TensorLabelVector& output_labels) const override;
bool has_evaluate() const override;
private:
bool evaluate_scatter_update(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
};
} // namespace v3
} // namespace op

View File

@ -2,134 +2,116 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/scatter_update.hpp"
#include "openvino/op/scatter_update.hpp"
#include "bound_evaluate.hpp"
#include "itt.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/core/type/element_type_traits.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/reference/scatter_update.hpp"
#include "utils.hpp"
#include "validation_util.hpp"
using namespace std;
using namespace ngraph;
op::v3::ScatterUpdate::ScatterUpdate(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& updates,
const Output<Node>& axis)
namespace ov {
namespace op {
namespace v3 {
ScatterUpdate::ScatterUpdate(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& updates,
const Output<Node>& axis)
: util::ScatterBase(data, indices, updates, axis) {}
shared_ptr<Node> op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> ScatterUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v3_ScatterUpdate_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v3::ScatterUpdate>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
return std::make_shared<ScatterUpdate>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace scatter_update {
namespace {
template <element::Type_t ET>
std::vector<int64_t> get_indices(const HostTensorPtr& in) {
auto data_ptr = in->get_data_ptr<ET>();
return std::vector<int64_t>(data_ptr, data_ptr + in->get_element_count());
}
} // namespace
} // namespace scatter_update
bool ScatterUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v3_ScatterUpdate_evaluate);
OPENVINO_ASSERT(inputs.size() == 4);
OPENVINO_ASSERT(outputs.size() == 1);
#define GET_INDICES(a, ...) \
case element::Type_t::a: { \
OV_OP_SCOPE(OV_PP_CAT3(get_scatter_update_indices, _, a)); \
indices_casted_vector = scatter_update::get_indices<element::Type_t::a>(__VA_ARGS__); \
} break;
bool op::v3::ScatterUpdate::evaluate_scatter_update(const HostTensorVector& outputs,
const HostTensorVector& inputs) const {
const auto& data = inputs[0];
const auto& indices = inputs[1];
const auto& updates = inputs[2];
const auto& axis = inputs[3];
const auto& out = outputs[0];
auto& output = outputs[0];
const auto elem_size = data->get_element_type().size();
out->set_shape(data->get_shape());
OPENVINO_ASSERT(axis.get_element_type().is_integral_number(), "axis element type is not integral data type");
OPENVINO_ASSERT(axis->get_element_type().is_integral_number(), "axis element type is not integral data type");
OPENVINO_SUPPRESS_DEPRECATED_START
int64_t axis_val = host_tensor_2_vector<int64_t>(axis)[0];
if (axis_val < 0) {
axis_val = ngraph::normalize_axis(this, axis_val, static_cast<int64_t>(data->get_shape().size()));
}
OPENVINO_SUPPRESS_DEPRECATED_END
std::vector<int64_t> indices_casted_vector;
switch (indices->get_element_type()) {
GET_INDICES(i8, indices);
GET_INDICES(i16, indices);
GET_INDICES(i32, indices);
GET_INDICES(i64, indices);
GET_INDICES(u8, indices);
GET_INDICES(u16, indices);
GET_INDICES(u32, indices);
GET_INDICES(u64, indices);
switch (indices.get_element_type()) {
case element::i8:
case element::i16:
case element::i32:
case element::i64:
case element::u8:
case element::u16:
case element::u32:
case element::u64:
break;
default:
return false;
}
ov::reference::scatter_update(data->get_data_ptr<char>(),
indices_casted_vector.data(),
updates->get_data_ptr<char>(),
axis_val,
out->get_data_ptr<char>(),
elem_size,
data->get_shape(),
indices->get_shape(),
updates->get_shape());
const auto& data_shape = data.get_shape();
output.set_shape(data_shape);
auto axis_val = get_tensor_data_as<int64_t>(axis)[0];
OPENVINO_SUPPRESS_DEPRECATED_START
axis_val = ov::normalize_axis(this, axis_val, static_cast<int64_t>(data_shape.size()));
OPENVINO_SUPPRESS_DEPRECATED_END
const auto indices_casted_vector = get_tensor_data_as<int64_t>(indices);
reference::scatter_update(static_cast<const char*>(data.data()),
indices_casted_vector.data(),
static_cast<const char*>(updates.data()),
axis_val,
static_cast<char*>(output.data()),
data.get_element_type().size(),
data.get_shape(),
indices.get_shape(),
updates.get_shape());
return true;
}
bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
OV_OP_SCOPE(v3_ScatterUpdate_evaluate);
return evaluate_scatter_update(outputs, inputs);
}
bool op::v3::ScatterUpdate::evaluate_lower(ov::TensorVector& outputs) const {
bool ScatterUpdate::evaluate_lower(TensorVector& outputs) const {
OV_OP_SCOPE(v3_ScatterUpdate_evaluate_lower);
return get_input_tensor(1).has_and_set_bound() && get_input_tensor(3).has_and_set_bound() &&
default_lower_bound_evaluator(this, outputs);
}
bool op::v3::ScatterUpdate::evaluate_upper(ov::TensorVector& outputs) const {
bool ScatterUpdate::evaluate_upper(TensorVector& outputs) const {
OV_OP_SCOPE(v3_ScatterUpdate_evaluate_upper);
return get_input_tensor(1).has_and_set_bound() && get_input_tensor(3).has_and_set_bound() &&
default_upper_bound_evaluator(this, outputs);
}
bool op::v3::ScatterUpdate::has_evaluate() const {
bool ScatterUpdate::has_evaluate() const {
OV_OP_SCOPE(v3_ScatterUpdate_has_evaluate);
switch (get_input_element_type(1)) {
case ngraph::element::i8:
case ngraph::element::i16:
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u8:
case ngraph::element::u16:
case ngraph::element::u32:
case ngraph::element::u64:
case element::i8:
case element::i16:
case element::i32:
case element::i64:
case element::u8:
case element::u16:
case element::u32:
case element::u64:
return true;
default:
break;
return false;
}
return false;
}
bool op::v3::ScatterUpdate::evaluate_label(TensorLabelVector& output_labels) const {
bool ScatterUpdate::evaluate_label(TensorLabelVector& output_labels) const {
OV_OP_SCOPE(v3_ScatterUpdate_evaluate_label);
OPENVINO_SUPPRESS_DEPRECATED_START
return ov::default_label_evaluator(this, {0, 2}, output_labels);
OPENVINO_SUPPRESS_DEPRECATED_END
return default_label_evaluator(this, {0, 2}, output_labels);
}
} // namespace v3
} // namespace op
} // namespace ov