[core] Migrate ScatterElementsUpdate operator to new API (#21212)
* Remove redundant code * Repalce HostTensor with ov::Tensor for v12 * Repalce HostTensor with ov::Tensor for v3 * Add Tensors count assertion * Rename * Revert axis normalization * Don't duplicate the code
This commit is contained in:
parent
37bac6ebcd
commit
cf58a83094
@ -28,16 +28,9 @@ public:
|
||||
const Output<Node>& updates,
|
||||
const Output<Node>& axis);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
private:
|
||||
bool evaluate_scatter_elements_update(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
};
|
||||
} // namespace v3
|
||||
namespace v12 {
|
||||
@ -87,12 +80,9 @@ public:
|
||||
|
||||
bool has_evaluate() const override;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
|
||||
private:
|
||||
bool evaluate_scatter_elements_update(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
|
||||
Reduction m_reduction = Reduction::NONE;
|
||||
bool m_use_init_val = true;
|
||||
};
|
||||
|
@ -35,7 +35,7 @@ public:
|
||||
|
||||
protected:
|
||||
bool is_supported_index_input_element_type() const;
|
||||
int64_t get_normalized_axis(const HostTensorVector& inputs) const;
|
||||
int64_t get_normalized_axis(const TensorVector& inputs) const;
|
||||
};
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
|
@ -4,13 +4,11 @@
|
||||
|
||||
#include "openvino/op/scatter_elements_update.hpp"
|
||||
|
||||
#include <scatter_elements_update_shape_inference.hpp>
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/reference/scatter_elements_update.hpp"
|
||||
|
||||
using namespace std;
|
||||
#include "scatter_elements_update_shape_inference.hpp"
|
||||
|
||||
namespace ov {
|
||||
op::v3::ScatterElementsUpdate::ScatterElementsUpdate(const Output<Node>& data,
|
||||
@ -21,12 +19,7 @@ op::v3::ScatterElementsUpdate::ScatterElementsUpdate(const Output<Node>& data,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool op::v3::ScatterElementsUpdate::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v3_ScatterElementsUpdate_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v3::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
std::shared_ptr<Node> op::v3::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
OV_OP_SCOPE(v3_ScatterElementsUpdate_clone_with_new_inputs);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
inputs.size() == get_input_size(),
|
||||
@ -35,7 +28,7 @@ shared_ptr<Node> op::v3::ScatterElementsUpdate::clone_with_new_inputs(const Outp
|
||||
"Got: ",
|
||||
inputs.size());
|
||||
|
||||
return make_shared<v3::ScatterElementsUpdate>(inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3));
|
||||
return std::make_shared<v3::ScatterElementsUpdate>(inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3));
|
||||
}
|
||||
|
||||
op::v12::ScatterElementsUpdate::ScatterElementsUpdate(const Output<Node>& data,
|
||||
@ -69,7 +62,7 @@ void op::v12::ScatterElementsUpdate::validate_and_infer_types() {
|
||||
ScatterElementsUpdateBase::validate_and_infer_types();
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v12::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
std::shared_ptr<Node> op::v12::ScatterElementsUpdate::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
OV_OP_SCOPE(v12_ScatterElementsUpdate_clone_with_new_inputs);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
inputs.size() == get_input_size(),
|
||||
@ -78,12 +71,12 @@ shared_ptr<Node> op::v12::ScatterElementsUpdate::clone_with_new_inputs(const Out
|
||||
"Got: ",
|
||||
inputs.size());
|
||||
|
||||
return make_shared<v12::ScatterElementsUpdate>(inputs.at(0),
|
||||
inputs.at(1),
|
||||
inputs.at(2),
|
||||
inputs.at(3),
|
||||
m_reduction,
|
||||
m_use_init_val);
|
||||
return std::make_shared<v12::ScatterElementsUpdate>(inputs.at(0),
|
||||
inputs.at(1),
|
||||
inputs.at(2),
|
||||
inputs.at(3),
|
||||
m_reduction,
|
||||
m_use_init_val);
|
||||
}
|
||||
|
||||
bool op::v12::ScatterElementsUpdate::has_evaluate() const {
|
||||
@ -91,243 +84,104 @@ bool op::v12::ScatterElementsUpdate::has_evaluate() const {
|
||||
(get_output_element_type(0) == element::boolean && is_supported_index_input_element_type());
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
namespace scatter_elements_update {
|
||||
struct Evaluate : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t DATA_ET, class DT = fundamental_type_for<DATA_ET>>
|
||||
static result_type visit(const Tensor& data,
|
||||
const Tensor& indices,
|
||||
const Tensor& updates,
|
||||
Tensor& output,
|
||||
const Shape& data_shape,
|
||||
const Shape& indices_shape,
|
||||
const int64_t axis,
|
||||
const op::v12::ScatterElementsUpdate::Reduction reduction,
|
||||
const bool use_init_value
|
||||
|
||||
) {
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<i8, i16, i32, i64, u8, u16, u32, u64>::apply<EvaluateByIndicesType>(indices.get_element_type(),
|
||||
data.data<const DT>(),
|
||||
indices,
|
||||
updates.data<const DT>(),
|
||||
output.data<DT>(),
|
||||
data_shape,
|
||||
indices_shape,
|
||||
axis,
|
||||
reduction,
|
||||
use_init_value);
|
||||
}
|
||||
|
||||
private:
|
||||
struct EvaluateByIndicesType : public element::NoAction<bool> {
|
||||
using element::NoAction<bool>::visit;
|
||||
|
||||
template <element::Type_t INDEX_ET, class DT, class IT = fundamental_type_for<INDEX_ET>>
|
||||
static result_type visit(const DT* const data,
|
||||
const Tensor& indices,
|
||||
const DT* const updates,
|
||||
DT* const output,
|
||||
const Shape& data_shape,
|
||||
const Shape& indices_shape,
|
||||
const int64_t axis,
|
||||
const op::v12::ScatterElementsUpdate::Reduction reduction,
|
||||
const bool use_init_value) {
|
||||
reference::scatter_elem_update(data,
|
||||
indices.data<IT>(),
|
||||
updates,
|
||||
axis,
|
||||
output,
|
||||
data_shape,
|
||||
indices_shape,
|
||||
reduction,
|
||||
use_init_value);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
};
|
||||
namespace {
|
||||
template <element::Type_t DT, element::Type_t IT, element::Type_t AT>
|
||||
bool evaluate(const ngraph::HostTensorPtr& data,
|
||||
const ngraph::HostTensorPtr& indices,
|
||||
const ngraph::HostTensorPtr& updates,
|
||||
const ngraph::HostTensorPtr& axis,
|
||||
const ngraph::HostTensorPtr& out,
|
||||
const int64_t normalized_axis,
|
||||
const op::v12::ScatterElementsUpdate::Reduction reduction_type,
|
||||
bool evaluate(TensorVector& outputs,
|
||||
const TensorVector& inputs,
|
||||
const int64_t axis,
|
||||
const op::v12::ScatterElementsUpdate::Reduction reduction,
|
||||
const bool use_init_value) {
|
||||
using DataType = typename element_type_traits<DT>::value_type;
|
||||
using IndicesType = typename element_type_traits<IT>::value_type;
|
||||
OPENVINO_ASSERT(inputs.size() == 4);
|
||||
OPENVINO_ASSERT(outputs.size() == 1);
|
||||
|
||||
out->set_shape(data->get_shape());
|
||||
|
||||
ov::reference::scatter_elem_update<DataType, IndicesType>(data->get_data_ptr<DT>(),
|
||||
indices->get_data_ptr<IT>(),
|
||||
updates->get_data_ptr<DT>(),
|
||||
normalized_axis,
|
||||
out->get_data_ptr<DT>(),
|
||||
data->get_shape(),
|
||||
indices->get_shape(),
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#define TYPE_AXS_CASE(a, ...) \
|
||||
case element::Type_t::a: { \
|
||||
OV_OP_SCOPE(OV_PP_CAT3(scatter_element_update_axs, _, a)); \
|
||||
rc = evaluate<DT, IT, element::Type_t::a>(__VA_ARGS__); \
|
||||
} break;
|
||||
|
||||
template <element::Type_t DT, element::Type_t IT>
|
||||
bool evaluate(const ngraph::HostTensorPtr& arg0,
|
||||
const ngraph::HostTensorPtr& arg1,
|
||||
const ngraph::HostTensorPtr& arg2,
|
||||
const ngraph::HostTensorPtr& arg3,
|
||||
const ngraph::HostTensorPtr& out,
|
||||
const int64_t normalized_axis,
|
||||
const op::v12::ScatterElementsUpdate::Reduction reduction_type,
|
||||
const bool use_init_value) {
|
||||
auto axis_type = arg3->get_element_type();
|
||||
|
||||
// Dispatch specialization based on axis data type.
|
||||
bool rc = true;
|
||||
|
||||
switch (axis_type) {
|
||||
TYPE_AXS_CASE(i8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_AXS_CASE(i16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_AXS_CASE(i32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_AXS_CASE(i64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_AXS_CASE(u8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_AXS_CASE(u16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_AXS_CASE(u32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_AXS_CASE(u64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
#define TYPE_IND_CASE(a, ...) \
|
||||
case element::Type_t::a: { \
|
||||
OV_OP_SCOPE(OV_PP_CAT3(scatter_element_update_ind, _, a)); \
|
||||
rc = evaluate<DT, element::Type_t::a>(__VA_ARGS__); \
|
||||
} break;
|
||||
|
||||
template <element::Type_t DT>
|
||||
bool evaluate(const ngraph::HostTensorPtr& arg0,
|
||||
const ngraph::HostTensorPtr& arg1,
|
||||
const ngraph::HostTensorPtr& arg2,
|
||||
const ngraph::HostTensorPtr& arg3,
|
||||
const ngraph::HostTensorPtr& out,
|
||||
const int64_t normalized_axis,
|
||||
const op::v12::ScatterElementsUpdate::Reduction reduction_type,
|
||||
const bool use_init_value) {
|
||||
auto indices_type = arg1->get_element_type();
|
||||
|
||||
// Dispatch specialization based on indicies data type.
|
||||
bool rc = true;
|
||||
|
||||
switch (indices_type) {
|
||||
TYPE_IND_CASE(i8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_IND_CASE(i16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_IND_CASE(i32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_IND_CASE(i64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_IND_CASE(u8, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_IND_CASE(u16, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_IND_CASE(u32, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
TYPE_IND_CASE(u64, arg0, arg1, arg2, arg3, out, normalized_axis, reduction_type, use_init_value);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
bool evaluate_scatter_elements_update(
|
||||
const ngraph::HostTensorPtr& arg0,
|
||||
const ngraph::HostTensorPtr& arg1,
|
||||
const ngraph::HostTensorPtr& arg2,
|
||||
const ngraph::HostTensorPtr& arg3,
|
||||
const ngraph::HostTensorPtr& out,
|
||||
const int64_t normalized_axis,
|
||||
const op::v12::ScatterElementsUpdate::Reduction reduction_type = op::v12::ScatterElementsUpdate::Reduction::NONE,
|
||||
const bool use_init_value = false) {
|
||||
bool rc = true;
|
||||
|
||||
switch (out->get_element_type()) {
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
i16,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
i32,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
i64,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
u32,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
u64,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
f16,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
f32,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
OPENVINO_TYPE_CASE(evaluate_scatter_element_update,
|
||||
boolean,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
out,
|
||||
normalized_axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
}
|
||||
return rc;
|
||||
const auto& data = inputs[0];
|
||||
const auto& indices = inputs[1];
|
||||
const auto& updates = inputs[2];
|
||||
auto& output = outputs[0];
|
||||
const auto& data_shape = data.get_shape();
|
||||
const auto& indices_shape = indices.get_shape();
|
||||
output.set_shape(data_shape);
|
||||
using namespace ov::element;
|
||||
return IfTypeOf<boolean, f16, f32, i16, i32, i64, u32, u64>::apply<scatter_elements_update::Evaluate>(
|
||||
data.get_element_type(),
|
||||
data,
|
||||
indices,
|
||||
updates,
|
||||
output,
|
||||
data_shape,
|
||||
indices_shape,
|
||||
axis,
|
||||
reduction,
|
||||
use_init_value);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace scatter_elements_update
|
||||
|
||||
bool op::v3::ScatterElementsUpdate::evaluate_scatter_elements_update(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const {
|
||||
const auto normalized_axis = get_normalized_axis(inputs);
|
||||
|
||||
return scatter_elements_update::evaluate_scatter_elements_update(inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
outputs[0],
|
||||
normalized_axis);
|
||||
}
|
||||
|
||||
bool op::v12::ScatterElementsUpdate::evaluate_scatter_elements_update(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const {
|
||||
const auto normalized_axis = get_normalized_axis(inputs);
|
||||
|
||||
return scatter_elements_update::evaluate_scatter_elements_update(inputs[0],
|
||||
inputs[1],
|
||||
inputs[2],
|
||||
inputs[3],
|
||||
outputs[0],
|
||||
normalized_axis,
|
||||
m_reduction,
|
||||
m_use_init_val);
|
||||
}
|
||||
|
||||
bool op::v3::ScatterElementsUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool op::v3::ScatterElementsUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v3_ScatterElementsUpdate_evaluate);
|
||||
return evaluate_scatter_elements_update(outputs, inputs);
|
||||
constexpr auto reduction = op::v12::ScatterElementsUpdate::Reduction::NONE;
|
||||
constexpr auto use_init_value = false;
|
||||
return scatter_elements_update::evaluate(outputs, inputs, get_normalized_axis(inputs), reduction, use_init_value);
|
||||
}
|
||||
|
||||
bool op::v12::ScatterElementsUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool op::v12::ScatterElementsUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v12_ScatterElementsUpdate_evaluate);
|
||||
return evaluate_scatter_elements_update(outputs, inputs);
|
||||
return scatter_elements_update::evaluate(outputs, inputs, get_normalized_axis(inputs), m_reduction, m_use_init_val);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -8,25 +8,26 @@
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "scatter_elements_update_shape_inference.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
|
||||
ov::op::util::ScatterElementsUpdateBase::ScatterElementsUpdateBase(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& updates,
|
||||
const Output<Node>& axis)
|
||||
util::ScatterElementsUpdateBase::ScatterElementsUpdateBase(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& updates,
|
||||
const Output<Node>& axis)
|
||||
: Op({data, indices, updates, axis}) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void ov::op::util::ScatterElementsUpdateBase::validate_and_infer_types() {
|
||||
void util::ScatterElementsUpdateBase::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(util_ScatterElementsUpdateBase_validate_and_infer_types);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const element::Type& data_et = get_input_element_type(0);
|
||||
const element::Type& indices_et = get_input_element_type(1);
|
||||
const element::Type& updates_et = get_input_element_type(2);
|
||||
const element::Type& axis_et = get_input_element_type(3);
|
||||
const auto& data_et = get_input_element_type(0);
|
||||
const auto& indices_et = get_input_element_type(1);
|
||||
const auto& updates_et = get_input_element_type(2);
|
||||
const auto& axis_et = get_input_element_type(3);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_et.is_integral(),
|
||||
@ -45,7 +46,7 @@ void ov::op::util::ScatterElementsUpdateBase::validate_and_infer_types() {
|
||||
updates_et);
|
||||
const auto output_shape = shape_infer(this, get_node_input_partial_shapes(*this)).front();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
element::Type out_et = get_input_element_type(0);
|
||||
auto out_et = get_input_element_type(0);
|
||||
std::ignore = element::Type::merge(out_et, get_input_element_type(0), get_input_element_type(2));
|
||||
set_output_type(0, out_et, output_shape);
|
||||
if (output_shape.is_dynamic()) {
|
||||
@ -53,7 +54,7 @@ void ov::op::util::ScatterElementsUpdateBase::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
bool op::util::ScatterElementsUpdateBase::has_evaluate() const {
|
||||
bool util::ScatterElementsUpdateBase::has_evaluate() const {
|
||||
OV_OP_SCOPE(util_ScatterElementsUpdateBase_has_evaluate);
|
||||
|
||||
switch (get_output_element_type(0)) {
|
||||
@ -72,7 +73,7 @@ bool op::util::ScatterElementsUpdateBase::has_evaluate() const {
|
||||
return is_supported_index_input_element_type();
|
||||
}
|
||||
|
||||
bool op::util::ScatterElementsUpdateBase::is_supported_index_input_element_type() const {
|
||||
bool util::ScatterElementsUpdateBase::is_supported_index_input_element_type() const {
|
||||
switch (get_input_element_type(1)) {
|
||||
case element::i8:
|
||||
case element::i16:
|
||||
@ -88,46 +89,31 @@ bool op::util::ScatterElementsUpdateBase::is_supported_index_input_element_type(
|
||||
}
|
||||
}
|
||||
|
||||
bool op::util::ScatterElementsUpdateBase::evaluate_lower(ov::TensorVector& output_values) const {
|
||||
bool util::ScatterElementsUpdateBase::evaluate_lower(ov::TensorVector& output_values) const {
|
||||
OV_OP_SCOPE(util_ScatterNDUpdate_evaluate_lower);
|
||||
return get_input_tensor(1).has_and_set_bound() && ov::default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::util::ScatterElementsUpdateBase::evaluate_upper(ov::TensorVector& output_values) const {
|
||||
bool util::ScatterElementsUpdateBase::evaluate_upper(ov::TensorVector& output_values) const {
|
||||
OV_OP_SCOPE(util_ScatterNDUpdate_evaluate_upper);
|
||||
return get_input_tensor(1).has_and_set_bound() && ov::default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::util::ScatterElementsUpdateBase::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
bool util::ScatterElementsUpdateBase::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
OV_OP_SCOPE(util_ScatterNDUpdate_evaluate_label);
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
return ov::default_label_evaluator(this, {0, 2}, output_labels);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
int64_t op::util::ScatterElementsUpdateBase::get_normalized_axis(const HostTensorVector& inputs) const {
|
||||
OPENVINO_ASSERT(inputs[3]->get_element_type().is_integral_number(), "axis element type is not integral data type");
|
||||
int64_t util::ScatterElementsUpdateBase::get_normalized_axis(const TensorVector& inputs) const {
|
||||
const auto& axis_input = inputs[3];
|
||||
OPENVINO_ASSERT(axis_input.get_element_type().is_integral_number(), "axis element type is not integral data type");
|
||||
|
||||
const auto axis = get_tensor_data_as<int64_t>(axis_input)[0];
|
||||
const auto data_rank = static_cast<int64_t>(inputs[0].get_shape().size());
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
int64_t axis = host_tensor_2_vector<int64_t>(inputs[3])[0];
|
||||
return ov::normalize_axis(this, axis, data_rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
const auto& input_rank = get_input_partial_shape(0).rank();
|
||||
int64_t normalized_axis = axis;
|
||||
|
||||
if (normalized_axis < 0) {
|
||||
if (input_rank.is_static()) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
normalized_axis = ov::normalize_axis(this, axis, input_rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
} else {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
normalized_axis = ov::normalize_axis(this, axis, static_cast<int64_t>(inputs[0]->get_shape().size()));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
}
|
||||
return normalized_axis;
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user