[core] Migrate ScatterNDUpdate operator to new API (#21231)
* Drop legacy stuff * Repalce HostTensor with ov::Tensor
This commit is contained in:
parent
9f87f72ca6
commit
8231d57c38
@ -22,9 +22,7 @@ public:
|
||||
: util::ScatterNDBase(inputs, indices, updates) {}
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool evaluate_lower(TensorVector& output_values) const override;
|
||||
bool evaluate_upper(TensorVector& output_values) const override;
|
||||
bool evaluate_label(TensorLabelVector& output_labels) const override;
|
||||
|
@ -2,134 +2,134 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/scatter_nd_update.hpp"
|
||||
#include "openvino/op/scatter_nd_update.hpp"
|
||||
|
||||
#include "bound_evaluate.hpp"
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "openvino/reference/scatter_nd_update.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace scatter_nd_update {
|
||||
struct Evaluate : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
shared_ptr<Node> op::v3::ScatterNDUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
template <element::Type_t DATA_ET, class DT = fundamental_type_for<DATA_ET>>
|
||||
static result_type visit(const Tensor& data,
|
||||
const Tensor& indices,
|
||||
const Tensor& updates,
|
||||
Tensor& output,
|
||||
const Shape& data_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& updates_shape) {
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<i32, i64>::apply<EvaluateByIndicesType>(indices.get_element_type(),
|
||||
data.data<const DT>(),
|
||||
indices,
|
||||
updates.data<const DT>(),
|
||||
output.data<DT>(),
|
||||
data_shape,
|
||||
indices_shape,
|
||||
updates_shape);
|
||||
}
|
||||
|
||||
private:
|
||||
struct EvaluateByIndicesType : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t INDICES_ET, class DT, class IT = fundamental_type_for<INDICES_ET>>
|
||||
static result_type visit(const DT* const data,
|
||||
const Tensor& indices,
|
||||
const DT* const updates,
|
||||
DT* const output,
|
||||
const Shape& data_shape,
|
||||
const Shape& indices_shape,
|
||||
const Shape& updates_shape) {
|
||||
reference::scatterNdUpdate(data,
|
||||
indices.data<IT>(),
|
||||
updates,
|
||||
output,
|
||||
data_shape,
|
||||
indices_shape,
|
||||
updates_shape);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
};
|
||||
} // namespace scatter_nd_update
|
||||
|
||||
namespace v3 {
|
||||
std::shared_ptr<Node> ScatterNDUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v3_ScatterNDUpdate_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::v3::ScatterNDUpdate>(new_args.at(op::util::ScatterNDBase::INPUTS),
|
||||
new_args.at(op::util::ScatterNDBase::INDICES),
|
||||
new_args.at(op::util::ScatterNDBase::UPDATES));
|
||||
return std::make_shared<ScatterNDUpdate>(new_args.at(util::ScatterNDBase::INPUTS),
|
||||
new_args.at(util::ScatterNDBase::INDICES),
|
||||
new_args.at(util::ScatterNDBase::UPDATES));
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace scatter {
|
||||
namespace {
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& arg2,
|
||||
const HostTensorPtr& out) {
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
out->set_shape(arg0->get_shape());
|
||||
|
||||
if (arg1->get_element_type() == element::i64) {
|
||||
ov::reference::scatterNdUpdate<T, int64_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int64_t>(),
|
||||
arg2->get_data_ptr<ET>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
arg2->get_shape());
|
||||
} else if (arg1->get_element_type() == element::i32) {
|
||||
ov::reference::scatterNdUpdate<T, int32_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int32_t>(),
|
||||
arg2->get_data_ptr<ET>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
arg2->get_shape());
|
||||
} else {
|
||||
OPENVINO_THROW("Unexpected type ",
|
||||
arg1->get_element_type().c_type_string(),
|
||||
" for ScatterNDUpdate evaluate method.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_scatter(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& arg2,
|
||||
const HostTensorPtr& out) {
|
||||
bool rc = true;
|
||||
|
||||
switch (out->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter, i32, arg0, arg1, arg2, out);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter, i64, arg0, arg1, arg2, out);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter, u32, arg0, arg1, arg2, out);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter, u64, arg0, arg1, arg2, out);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter, f16, arg0, arg1, arg2, out);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter, f32, arg0, arg1, arg2, out);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter, boolean, arg0, arg1, arg2, out);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace scatter
|
||||
|
||||
bool op::v3::ScatterNDUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool ScatterNDUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate);
|
||||
OPENVINO_ASSERT(!inputs.empty());
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
OPENVINO_ASSERT(validate_host_tensor_vector(inputs, 3));
|
||||
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
OPENVINO_ASSERT(inputs.size() == 3);
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
|
||||
return scatter::evaluate_scatter(inputs[0], inputs[1], inputs[2], outputs[0]);
|
||||
const auto& data = inputs[0];
|
||||
const auto& indices = inputs[1];
|
||||
const auto& updates = inputs[2];
|
||||
auto& output = outputs[0];
|
||||
const auto& data_shape = data.get_shape();
|
||||
const auto& indices_shape = indices.get_shape();
|
||||
const auto& updates_shape = updates.get_shape();
|
||||
output.set_shape(data_shape);
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<boolean, f16, f32, i32, i64, u32, u64>::apply<scatter_nd_update::Evaluate>(data.get_element_type(),
|
||||
data,
|
||||
indices,
|
||||
updates,
|
||||
output,
|
||||
data_shape,
|
||||
indices_shape,
|
||||
updates_shape);
|
||||
}
|
||||
|
||||
bool op::v3::ScatterNDUpdate::has_evaluate() const {
|
||||
bool ScatterNDUpdate::has_evaluate() const {
|
||||
OV_OP_SCOPE(v3_ScatterNDUpdate_has_evaluate);
|
||||
|
||||
switch (get_output_element_type(0)) {
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
case ngraph::element::u32:
|
||||
case ngraph::element::u64:
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
case ngraph::element::boolean:
|
||||
case element::boolean:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
case element::u32:
|
||||
case element::u64:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
switch (get_input_element_type(1)) {
|
||||
case ngraph::element::i32:
|
||||
case ngraph::element::i64:
|
||||
break;
|
||||
case element::i32:
|
||||
case element::i64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool op::v3::ScatterNDUpdate::evaluate_lower(ov::TensorVector& output_values) const {
|
||||
bool ScatterNDUpdate::evaluate_lower(TensorVector& output_values) const {
|
||||
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_lower);
|
||||
return get_input_tensor(1).has_and_set_bound() && ov::default_lower_bound_evaluator(this, output_values);
|
||||
return get_input_tensor(1).has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v3::ScatterNDUpdate::evaluate_upper(ov::TensorVector& output_values) const {
|
||||
bool ScatterNDUpdate::evaluate_upper(TensorVector& output_values) const {
|
||||
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_upper);
|
||||
return get_input_tensor(1).has_and_set_bound() && ov::default_upper_bound_evaluator(this, output_values);
|
||||
return get_input_tensor(1).has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v3::ScatterNDUpdate::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
bool ScatterNDUpdate::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_label);
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
return ov::default_label_evaluator(this, {0, 2}, output_labels);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
return default_label_evaluator(this, {0, 2}, output_labels);
|
||||
}
|
||||
} // namespace v3
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user