[core] Migrate ScatterNDUpdate operator to new API (#21231)

* Drop legacy stuff

* Repalce HostTensor with ov::Tensor
This commit is contained in:
Tomasz Jankowski 2023-11-25 08:33:28 +01:00 committed by GitHub
parent 9f87f72ca6
commit 8231d57c38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 99 deletions

View File

@ -22,9 +22,7 @@ public:
: util::ScatterNDBase(inputs, indices, updates) {}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& output_values) const override;
bool evaluate_upper(TensorVector& output_values) const override;
bool evaluate_label(TensorLabelVector& output_labels) const override;

View File

@ -2,134 +2,134 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/scatter_nd_update.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "bound_evaluate.hpp"
#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/reference/scatter_nd_update.hpp"
using namespace std;
using namespace ngraph;
namespace ov {
namespace op {
namespace scatter_nd_update {
struct Evaluate : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
shared_ptr<Node> op::v3::ScatterNDUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
template <element::Type_t DATA_ET, class DT = fundamental_type_for<DATA_ET>>
static result_type visit(const Tensor& data,
const Tensor& indices,
const Tensor& updates,
Tensor& output,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape) {
using namespace ov::element;
return IfTypeOf<i32, i64>::apply<EvaluateByIndicesType>(indices.get_element_type(),
data.data<const DT>(),
indices,
updates.data<const DT>(),
output.data<DT>(),
data_shape,
indices_shape,
updates_shape);
}
private:
struct EvaluateByIndicesType : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t INDICES_ET, class DT, class IT = fundamental_type_for<INDICES_ET>>
static result_type visit(const DT* const data,
const Tensor& indices,
const DT* const updates,
DT* const output,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape) {
reference::scatterNdUpdate(data,
indices.data<IT>(),
updates,
output,
data_shape,
indices_shape,
updates_shape);
return true;
}
};
};
} // namespace scatter_nd_update
namespace v3 {
std::shared_ptr<Node> ScatterNDUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<op::v3::ScatterNDUpdate>(new_args.at(op::util::ScatterNDBase::INPUTS),
new_args.at(op::util::ScatterNDBase::INDICES),
new_args.at(op::util::ScatterNDBase::UPDATES));
return std::make_shared<ScatterNDUpdate>(new_args.at(util::ScatterNDBase::INPUTS),
new_args.at(util::ScatterNDBase::INDICES),
new_args.at(util::ScatterNDBase::UPDATES));
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace scatter {
namespace {
template <element::Type_t ET>
bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& arg2,
const HostTensorPtr& out) {
using T = typename element_type_traits<ET>::value_type;
out->set_shape(arg0->get_shape());
if (arg1->get_element_type() == element::i64) {
ov::reference::scatterNdUpdate<T, int64_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int64_t>(),
arg2->get_data_ptr<ET>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
arg2->get_shape());
} else if (arg1->get_element_type() == element::i32) {
ov::reference::scatterNdUpdate<T, int32_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int32_t>(),
arg2->get_data_ptr<ET>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
arg2->get_shape());
} else {
OPENVINO_THROW("Unexpected type ",
arg1->get_element_type().c_type_string(),
" for ScatterNDUpdate evaluate method.");
}
return true;
}
bool evaluate_scatter(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& arg2,
const HostTensorPtr& out) {
bool rc = true;
switch (out->get_element_type()) {
OPENVINO_TYPE_CASE(evaluate_scatter, i32, arg0, arg1, arg2, out);
OPENVINO_TYPE_CASE(evaluate_scatter, i64, arg0, arg1, arg2, out);
OPENVINO_TYPE_CASE(evaluate_scatter, u32, arg0, arg1, arg2, out);
OPENVINO_TYPE_CASE(evaluate_scatter, u64, arg0, arg1, arg2, out);
OPENVINO_TYPE_CASE(evaluate_scatter, f16, arg0, arg1, arg2, out);
OPENVINO_TYPE_CASE(evaluate_scatter, f32, arg0, arg1, arg2, out);
OPENVINO_TYPE_CASE(evaluate_scatter, boolean, arg0, arg1, arg2, out);
default:
rc = false;
break;
}
return rc;
}
} // namespace
} // namespace scatter
bool op::v3::ScatterNDUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool ScatterNDUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate);
OPENVINO_ASSERT(!inputs.empty());
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(validate_host_tensor_vector(inputs, 3));
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1));
OPENVINO_SUPPRESS_DEPRECATED_END
OPENVINO_ASSERT(inputs.size() == 3);
OPENVINO_ASSERT(outputs.size() == 1);
return scatter::evaluate_scatter(inputs[0], inputs[1], inputs[2], outputs[0]);
const auto& data = inputs[0];
const auto& indices = inputs[1];
const auto& updates = inputs[2];
auto& output = outputs[0];
const auto& data_shape = data.get_shape();
const auto& indices_shape = indices.get_shape();
const auto& updates_shape = updates.get_shape();
output.set_shape(data_shape);
using namespace ov::element;
return IfTypeOf<boolean, f16, f32, i32, i64, u32, u64>::apply<scatter_nd_update::Evaluate>(data.get_element_type(),
data,
indices,
updates,
output,
data_shape,
indices_shape,
updates_shape);
}
bool op::v3::ScatterNDUpdate::has_evaluate() const {
bool ScatterNDUpdate::has_evaluate() const {
OV_OP_SCOPE(v3_ScatterNDUpdate_has_evaluate);
switch (get_output_element_type(0)) {
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u32:
case ngraph::element::u64:
case ngraph::element::f16:
case ngraph::element::f32:
case ngraph::element::boolean:
case element::boolean:
case element::f16:
case element::f32:
case element::i32:
case element::i64:
case element::u32:
case element::u64:
break;
default:
return false;
}
switch (get_input_element_type(1)) {
case ngraph::element::i32:
case ngraph::element::i64:
break;
case element::i32:
case element::i64:
return true;
default:
return false;
}
return true;
}
bool op::v3::ScatterNDUpdate::evaluate_lower(ov::TensorVector& output_values) const {
bool ScatterNDUpdate::evaluate_lower(TensorVector& output_values) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_lower);
return get_input_tensor(1).has_and_set_bound() && ov::default_lower_bound_evaluator(this, output_values);
return get_input_tensor(1).has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
}
bool op::v3::ScatterNDUpdate::evaluate_upper(ov::TensorVector& output_values) const {
bool ScatterNDUpdate::evaluate_upper(TensorVector& output_values) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_upper);
return get_input_tensor(1).has_and_set_bound() && ov::default_upper_bound_evaluator(this, output_values);
return get_input_tensor(1).has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
}
bool op::v3::ScatterNDUpdate::evaluate_label(TensorLabelVector& output_labels) const {
bool ScatterNDUpdate::evaluate_label(TensorLabelVector& output_labels) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_label);
OPENVINO_SUPPRESS_DEPRECATED_START
return ov::default_label_evaluator(this, {0, 2}, output_labels);
OPENVINO_SUPPRESS_DEPRECATED_END
return default_label_evaluator(this, {0, 2}, output_labels);
}
} // namespace v3
} // namespace op
} // namespace ov