[core] Migrate ScatterElementsUpdate operator to new API (#21212)

* Remove redundant code

* Repalce HostTensor with ov::Tensor for v12

* Repalce HostTensor with ov::Tensor for v3

* Add Tensors count assertion

* Rename

* Revert axis normalization

* Don't duplicate the code
This commit is contained in:
Tomasz Jankowski 2023-11-28 09:32:30 +01:00 committed by GitHub
parent 37bac6ebcd
commit cf58a83094
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 292 deletions

View File

@ -28,16 +28,9 @@ public:
const Output<Node>& updates,
const Output<Node>& axis);
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
private:
bool evaluate_scatter_elements_update(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
};
} // namespace v3
namespace v12 {
@ -87,12 +80,9 @@ public:
bool has_evaluate() const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
private:
bool evaluate_scatter_elements_update(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
Reduction m_reduction = Reduction::NONE;
bool m_use_init_val = true;
};

View File

@ -35,7 +35,7 @@ public:
protected:
bool is_supported_index_input_element_type() const;
int64_t get_normalized_axis(const HostTensorVector& inputs) const;
int64_t get_normalized_axis(const TensorVector& inputs) const;
};
} // namespace util
} // namespace op

View File

@ -4,13 +4,11 @@
#include "openvino/op/scatter_elements_update.hpp"
#include <scatter_elements_update_shape_inference.hpp>
#include "element_visitor.hpp"
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/reference/scatter_elements_update.hpp"
using namespace std;
#include "scatter_elements_update_shape_inference.hpp"
namespace ov {
op::v3::ScatterElementsUpdate::ScatterElementsUpdate(const Output<Node>& data,
@ -21,12 +19,7 @@ op::v3::ScatterElementsUpdate::ScatterElementsUpdate(const Output<Node>& data,
constructor_validate_and_infer_types();
}
bool op::v3::ScatterElementsUpdate::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v3_ScatterElementsUpdate_visit_attributes);
return true;
}
shared_ptr<Node> op::v3::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
std::shared_ptr<Node> op::v3::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
OV_OP_SCOPE(v3_ScatterElementsUpdate_clone_with_new_inputs);
NODE_VALIDATION_CHECK(this,
inputs.size() == get_input_size(),
@ -35,7 +28,7 @@ shared_ptr<Node> op::v3::ScatterElementsUpdate::clone_with_new_inputs(const Outp
"Got: ",
inputs.size());
return make_shared<v3::ScatterElementsUpdate>(inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3));
return std::make_shared<v3::ScatterElementsUpdate>(inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3));
}
op::v12::ScatterElementsUpdate::ScatterElementsUpdate(const Output<Node>& data,
@ -69,7 +62,7 @@ void op::v12::ScatterElementsUpdate::validate_and_infer_types() {
ScatterElementsUpdateBase::validate_and_infer_types();
}
shared_ptr<Node> op::v12::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
std::shared_ptr<Node> op::v12::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
OV_OP_SCOPE(v12_ScatterElementsUpdate_clone_with_new_inputs);
NODE_VALIDATION_CHECK(this,
inputs.size() == get_input_size(),
@ -78,12 +71,12 @@ shared_ptr<Node> op::v12::ScatterElementsUpdate::clone_with_new_inputs(const Out
"Got: ",
inputs.size());
return make_shared<v12::ScatterElementsUpdate>(inputs.at(0),
inputs.at(1),
inputs.at(2),
inputs.at(3),
m_reduction,
m_use_init_val);
return std::make_shared<v12::ScatterElementsUpdate>(inputs.at(0),
inputs.at(1),
inputs.at(2),
inputs.at(3),
m_reduction,
m_use_init_val);
}
bool op::v12::ScatterElementsUpdate::has_evaluate() const {
@ -91,243 +84,104 @@ bool op::v12::ScatterElementsUpdate::has_evaluate() const {
(get_output_element_type(0) == element::boolean && is_supported_index_input_element_type());
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace scatter_elements_update {
struct Evaluate : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t DATA_ET, class DT = fundamental_type_for<DATA_ET>>
static result_type visit(const Tensor& data,
const Tensor& indices,
const Tensor& updates,
Tensor& output,
const Shape& data_shape,
const Shape& indices_shape,
const int64_t axis,
const op::v12::ScatterElementsUpdate::Reduction reduction,
const bool use_init_value
) {
using namespace ov::element;
return IfTypeOf<i8, i16, i32, i64, u8, u16, u32, u64>::apply<EvaluateByIndicesType>(indices.get_element_type(),
data.data<const DT>(),
indices,
updates.data<const DT>(),
output.data<DT>(),
data_shape,
indices_shape,
axis,
reduction,
use_init_value);
}
private:
struct EvaluateByIndicesType : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t INDEX_ET, class DT, class IT = fundamental_type_for<INDEX_ET>>
static result_type visit(const DT* const data,
const Tensor& indices,
const DT* const updates,
DT* const output,
const Shape& data_shape,
const Shape& indices_shape,
const int64_t axis,
const op::v12::ScatterElementsUpdate::Reduction reduction,
const bool use_init_value) {
reference::scatter_elem_update(data,
indices.data<IT>(),
updates,
axis,
output,
data_shape,
indices_shape,
reduction,
use_init_value);
return true;
}
};
};
namespace {
template <element::Type_t DT, element::Type_t IT, element::Type_t AT>
bool evaluate(const ngraph::HostTensorPtr& data,
const ngraph::HostTensorPtr& indices,
const ngraph::HostTensorPtr& updates,
const ngraph::HostTensorPtr& axis,
const ngraph::HostTensorPtr& out,
const int64_t normalized_axis,
const op::v12::ScatterElementsUpdate::Reduction reduction_type,
bool evaluate(TensorVector& outputs,
const TensorVector& inputs,
const int64_t axis,
const op::v12::ScatterElementsUpdate::Reduction reduction,
const bool use_init_value) {
using DataType = typename element_type_traits<DT>::value_type;
using IndicesType = typename element_type_traits<IT>::value_type;
OPENVINO_ASSERT(inputs.size() == 4);
OPENVINO_ASSERT(outputs.size() == 1);
out->set_shape(data->get_shape());
ov::reference::scatter_elem_update<DataType, IndicesType>(data->get_data_ptr<DT>(),
indices->get_data_ptr<IT>(),
updates->get_data_ptr<DT>(),
normalized_axis,
out->get_data_ptr<DT>(),
data->get_shape(),
indices->get_shape(),
reduction_type,
use_init_value);
return true;
}
#define TYPE_AXS_CASE(a, ...) \
case element::Type_t::a: { \
OV_OP_SCOPE(OV_PP_CAT3(scatter_element_update_axs, _, a)); \
rc = evaluate<DT, IT, element::Type_t::a>(__VA_ARGS__); \
} break;
template <element::Type_t DT, element::Type_t IT>
bool evaluate(const ngraph::HostTensorPtr& arg0,
const ngraph::HostTensorPtr& arg1,
const ngraph::HostTensorPtr& arg2,
const ngraph::HostTensorPtr& arg3,
const ngraph::HostTensorPtr& out,
const int64_t normalized_axis,
const op::v12::ScatterElementsUpdate::Reduction reduction_type,
const bool use_init_value) {
auto axis_type = arg3->get_element_type();
// Dispatch specialization based on axis data type.
bool rc = true;
switch (axis_type) {
TYPE_AXS_CASE(i8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_AXS_CASE(i16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_AXS_CASE(i32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_AXS_CASE(i64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_AXS_CASE(u8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_AXS_CASE(u16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_AXS_CASE(u32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_AXS_CASE(u64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
default:
rc = false;
break;
}
return rc;
}
#define TYPE_IND_CASE(a, ...) \
case element::Type_t::a: { \
OV_OP_SCOPE(OV_PP_CAT3(scatter_element_update_ind, _, a)); \
rc = evaluate<DT, element::Type_t::a>(__VA_ARGS__); \
} break;
template <element::Type_t DT>
bool evaluate(const ngraph::HostTensorPtr& arg0,
const ngraph::HostTensorPtr& arg1,
const ngraph::HostTensorPtr& arg2,
const ngraph::HostTensorPtr& arg3,
const ngraph::HostTensorPtr& out,
const int64_t normalized_axis,
const op::v12::ScatterElementsUpdate::Reduction reduction_type,
const bool use_init_value) {
auto indices_type = arg1->get_element_type();
// Dispatch specialization based on indicies data type.
bool rc = true;
switch (indices_type) {
TYPE_IND_CASE(i8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_IND_CASE(i16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_IND_CASE(i32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_IND_CASE(i64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_IND_CASE(u8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_IND_CASE(u16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_IND_CASE(u32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
TYPE_IND_CASE(u64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
default:
rc = false;
break;
}
return rc;
}
bool evaluate_scatter_elements_update(
const ngraph::HostTensorPtr& arg0,
const ngraph::HostTensorPtr& arg1,
const ngraph::HostTensorPtr& arg2,
const ngraph::HostTensorPtr& arg3,
const ngraph::HostTensorPtr& out,
const int64_t normalized_axis,
const op::v12::ScatterElementsUpdate::Reduction reduction_type = op::v12::ScatterElementsUpdate::Reduction::NONE,
const bool use_init_value = false) {
bool rc = true;
switch (out->get_element_type()) {
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
i16,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
i32,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
i64,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
u32,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
u64,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
f16,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
f32,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
boolean,
arg0,
arg1,
arg2,
arg3,
out,
normalized_axis,
reduction_type,
use_init_value);
default:
rc = false;
break;
}
return rc;
const auto& data = inputs[0];
const auto& indices = inputs[1];
const auto& updates = inputs[2];
auto& output = outputs[0];
const auto& data_shape = data.get_shape();
const auto& indices_shape = indices.get_shape();
output.set_shape(data_shape);
using namespace ov::element;
return IfTypeOf<boolean, f16, f32, i16, i32, i64, u32, u64>::apply<scatter_elements_update::Evaluate>(
data.get_element_type(),
data,
indices,
updates,
output,
data_shape,
indices_shape,
axis,
reduction,
use_init_value);
}
} // namespace
} // namespace scatter_elements_update
bool op::v3::ScatterElementsUpdate::evaluate_scatter_elements_update(const HostTensorVector& outputs,
const HostTensorVector& inputs) const {
const auto normalized_axis = get_normalized_axis(inputs);
return scatter_elements_update::evaluate_scatter_elements_update(inputs[0],
inputs[1],
inputs[2],
inputs[3],
outputs[0],
normalized_axis);
}
bool op::v12::ScatterElementsUpdate::evaluate_scatter_elements_update(const HostTensorVector& outputs,
const HostTensorVector& inputs) const {
const auto normalized_axis = get_normalized_axis(inputs);
return scatter_elements_update::evaluate_scatter_elements_update(inputs[0],
inputs[1],
inputs[2],
inputs[3],
outputs[0],
normalized_axis,
m_reduction,
m_use_init_val);
}
bool op::v3::ScatterElementsUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool op::v3::ScatterElementsUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v3_ScatterElementsUpdate_evaluate);
return evaluate_scatter_elements_update(outputs, inputs);
constexpr auto reduction = op::v12::ScatterElementsUpdate::Reduction::NONE;
constexpr auto use_init_value = false;
return scatter_elements_update::evaluate(outputs, inputs, get_normalized_axis(inputs), reduction, use_init_value);
}
bool op::v12::ScatterElementsUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool op::v12::ScatterElementsUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v12_ScatterElementsUpdate_evaluate);
return evaluate_scatter_elements_update(outputs, inputs);
return scatter_elements_update::evaluate(outputs, inputs, get_normalized_axis(inputs), m_reduction, m_use_init_val);
}
template <>

View File

@ -8,25 +8,26 @@
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "scatter_elements_update_shape_inference.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
ov::op::util::ScatterElementsUpdateBase::ScatterElementsUpdateBase(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& updates,
const Output<Node>& axis)
util::ScatterElementsUpdateBase::ScatterElementsUpdateBase(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& updates,
const Output<Node>& axis)
: Op({data, indices, updates, axis}) {
constructor_validate_and_infer_types();
}
void ov::op::util::ScatterElementsUpdateBase::validate_and_infer_types() {
void util::ScatterElementsUpdateBase::validate_and_infer_types() {
OV_OP_SCOPE(util_ScatterElementsUpdateBase_validate_and_infer_types);
OPENVINO_SUPPRESS_DEPRECATED_START
const element::Type& data_et = get_input_element_type(0);
const element::Type& indices_et = get_input_element_type(1);
const element::Type& updates_et = get_input_element_type(2);
const element::Type& axis_et = get_input_element_type(3);
const auto& data_et = get_input_element_type(0);
const auto& indices_et = get_input_element_type(1);
const auto& updates_et = get_input_element_type(2);
const auto& axis_et = get_input_element_type(3);
NODE_VALIDATION_CHECK(this,
indices_et.is_integral(),
@ -45,7 +46,7 @@ void ov::op::util::ScatterElementsUpdateBase::validate_and_infer_types() {
updates_et);
const auto output_shape = shape_infer(this, get_node_input_partial_shapes(*this)).front();
OPENVINO_SUPPRESS_DEPRECATED_END
element::Type out_et = get_input_element_type(0);
auto out_et = get_input_element_type(0);
std::ignore = element::Type::merge(out_et, get_input_element_type(0), get_input_element_type(2));
set_output_type(0, out_et, output_shape);
if (output_shape.is_dynamic()) {
@ -53,7 +54,7 @@ void ov::op::util::ScatterElementsUpdateBase::validate_and_infer_types() {
}
}
bool op::util::ScatterElementsUpdateBase::has_evaluate() const {
bool util::ScatterElementsUpdateBase::has_evaluate() const {
OV_OP_SCOPE(util_ScatterElementsUpdateBase_has_evaluate);
switch (get_output_element_type(0)) {
@ -72,7 +73,7 @@ bool op::util::ScatterElementsUpdateBase::has_evaluate() const {
return is_supported_index_input_element_type();
}
bool op::util::ScatterElementsUpdateBase::is_supported_index_input_element_type() const {
bool util::ScatterElementsUpdateBase::is_supported_index_input_element_type() const {
switch (get_input_element_type(1)) {
case element::i8:
case element::i16:
@ -88,46 +89,31 @@ bool op::util::ScatterElementsUpdateBase::is_supported_index_input_element_type(
}
}
bool op::util::ScatterElementsUpdateBase::evaluate_lower(ov::TensorVector& output_values) const {
bool util::ScatterElementsUpdateBase::evaluate_lower(ov::TensorVector& output_values) const {
OV_OP_SCOPE(util_ScatterNDUpdate_evaluate_lower);
return get_input_tensor(1).has_and_set_bound() && ov::default_lower_bound_evaluator(this, output_values);
}
bool op::util::ScatterElementsUpdateBase::evaluate_upper(ov::TensorVector& output_values) const {
bool util::ScatterElementsUpdateBase::evaluate_upper(ov::TensorVector& output_values) const {
OV_OP_SCOPE(util_ScatterNDUpdate_evaluate_upper);
return get_input_tensor(1).has_and_set_bound() && ov::default_upper_bound_evaluator(this, output_values);
}
bool op::util::ScatterElementsUpdateBase::evaluate_label(TensorLabelVector& output_labels) const {
bool util::ScatterElementsUpdateBase::evaluate_label(TensorLabelVector& output_labels) const {
OV_OP_SCOPE(util_ScatterNDUpdate_evaluate_label);
OPENVINO_SUPPRESS_DEPRECATED_START
return ov::default_label_evaluator(this, {0, 2}, output_labels);
OPENVINO_SUPPRESS_DEPRECATED_END
}
OPENVINO_SUPPRESS_DEPRECATED_START
int64_t op::util::ScatterElementsUpdateBase::get_normalized_axis(const HostTensorVector& inputs) const {
OPENVINO_ASSERT(inputs[3]->get_element_type().is_integral_number(), "axis element type is not integral data type");
int64_t util::ScatterElementsUpdateBase::get_normalized_axis(const TensorVector& inputs) const {
const auto& axis_input = inputs[3];
OPENVINO_ASSERT(axis_input.get_element_type().is_integral_number(), "axis element type is not integral data type");
const auto axis = get_tensor_data_as<int64_t>(axis_input)[0];
const auto data_rank = static_cast<int64_t>(inputs[0].get_shape().size());
OPENVINO_SUPPRESS_DEPRECATED_START
int64_t axis = host_tensor_2_vector<int64_t>(inputs[3])[0];
return ov::normalize_axis(this, axis, data_rank);
OPENVINO_SUPPRESS_DEPRECATED_END
const auto& input_rank = get_input_partial_shape(0).rank();
int64_t normalized_axis = axis;
if (normalized_axis < 0) {
if (input_rank.is_static()) {
OPENVINO_SUPPRESS_DEPRECATED_START
normalized_axis = ov::normalize_axis(this, axis, input_rank);
OPENVINO_SUPPRESS_DEPRECATED_END
} else {
OPENVINO_SUPPRESS_DEPRECATED_START
normalized_axis = ov::normalize_axis(this, axis, static_cast<int64_t>(inputs[0]->get_shape().size()));
OPENVINO_SUPPRESS_DEPRECATED_END
}
}
return normalized_axis;
}
} // namespace op
} // namespace ov