Migrate Reverse operator to new API (#21277)

This commit is contained in:
Pawel Raasz 2023-11-28 09:41:56 +01:00 committed by GitHub
parent 21201833ec
commit 685ac0d0a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 91 deletions

View File

@ -41,9 +41,7 @@ public:
void set_mode(const Mode mode) {
m_mode = mode;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
protected:
@ -56,9 +54,6 @@ protected:
/// Alternatively it can contain a boolean mask that indicates which axes should be
/// reversed.
Mode m_mode;
private:
bool evaluate_reverse(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
};
} // namespace v1
} // namespace op

View File

@ -2,39 +2,55 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/reverse.hpp"
#include "openvino/op/reverse.hpp"
#include <algorithm>
#include <iterator>
#include <sstream>
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "openvino/op/util/axes_util.hpp"
#include "openvino/reference/reverse.hpp"
#include "reverse_shape_inference.hpp"
ov::op::v1::Reverse::Reverse(const Output<Node>& data, const Output<Node>& reversed_axes, const std::string& mode)
namespace ov {
namespace op {
namespace v1 {
namespace {
bool validate_axes_indices_et(const element::Type& et) {
switch (et) {
case element::i8:
case element::i16:
case element::i32:
case element::i64:
case element::u8:
case element::u16:
case element::u32:
case element::u64:
return true;
default:
return false;
}
}
} // namespace
Reverse::Reverse(const Output<Node>& data, const Output<Node>& reversed_axes, const std::string& mode)
: Op({data, reversed_axes}),
m_mode{mode_from_string(mode)} {
constructor_validate_and_infer_types();
}
ov::op::v1::Reverse::Reverse(const Output<Node>& data, const Output<Node>& reversed_axes, const Mode mode)
Reverse::Reverse(const Output<Node>& data, const Output<Node>& reversed_axes, const Mode mode)
: Op({data, reversed_axes}),
m_mode{mode} {
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::Reverse::visit_attributes(AttributeVisitor& visitor) {
bool Reverse::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v1_Reverse_visit_attributes);
visitor.on_attribute("mode", m_mode);
return true;
}
void ov::op::v1::Reverse::validate_and_infer_types() {
void Reverse::validate_and_infer_types() {
OV_OP_SCOPE(v1_Reverse_validate_and_infer_types);
if (m_mode == Mode::MASK) {
NODE_VALIDATION_CHECK(this,
@ -53,13 +69,13 @@ void ov::op::v1::Reverse::validate_and_infer_types() {
set_output_type(0, get_input_element_type(0), output_shape);
}
std::shared_ptr<ov::Node> ov::op::v1::Reverse::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<ov::Node> Reverse::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_Reverse_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<op::v1::Reverse>(new_args.at(0), new_args.at(1), m_mode);
return std::make_shared<Reverse>(new_args.at(0), new_args.at(1), m_mode);
}
ov::op::v1::Reverse::Mode ov::op::v1::Reverse::mode_from_string(const std::string& mode) const {
Reverse::Mode Reverse::mode_from_string(const std::string& mode) const {
static const std::map<std::string, Mode> allowed_values = {{"index", Mode::INDEX}, {"mask", Mode::MASK}};
NODE_VALIDATION_CHECK(this, allowed_values.count(mode) > 0, "Invalid 'mode' value passed in.");
@ -67,93 +83,56 @@ ov::op::v1::Reverse::Mode ov::op::v1::Reverse::mode_from_string(const std::strin
return allowed_values.at(mode);
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace reverseop {
template <ov::element::Type_t ET>
void get_axes(ov::AxisSet& axes, const ngraph::HostTensorPtr& in) {
auto axes_indices = in->get_data_ptr<ET>();
size_t axes_rank = in->get_element_count();
std::copy(axes_indices, axes_indices + axes_rank, std::inserter(axes, axes.end()));
}
} // namespace reverseop
bool Reverse::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_Reverse_evaluate);
OPENVINO_ASSERT(outputs.size() == 1);
OPENVINO_ASSERT(inputs.size() == 2);
#define GET_AXES(a, ...) \
case element::Type_t::a: { \
OV_OP_SCOPE(OV_PP_CAT3(get_reverse_axes, _, a)); \
reverseop::get_axes<element::Type_t::a>(__VA_ARGS__); \
} break;
const auto& data = inputs[0];
const auto& axes = inputs[1];
const auto& data_shape = data.get_shape();
bool ov::op::v1::Reverse::evaluate_reverse(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
AxisSet axes{};
if (get_mode() == op::v1::Reverse::Mode::INDEX) {
switch (inputs[1]->get_element_type()) {
GET_AXES(i8, axes, inputs[1]);
GET_AXES(i16, axes, inputs[1]);
GET_AXES(i32, axes, inputs[1]);
GET_AXES(i64, axes, inputs[1]);
GET_AXES(u8, axes, inputs[1]);
GET_AXES(u16, axes, inputs[1]);
GET_AXES(u32, axes, inputs[1]);
GET_AXES(u64, axes, inputs[1]);
default:
OPENVINO_ASSERT(false, "Not supported axes type", inputs[1]->get_element_type());
}
} else // Mode::MASK
{
auto axes_mask = inputs[1]->get_data_ptr<bool>();
for (size_t i = 0; i < inputs[1]->get_element_count(); ++i) {
if (axes_mask[i]) {
axes.emplace(i);
AxisSet reversed_axes{};
if (get_mode() == Reverse::Mode::MASK) {
auto axes_mask = axes.data<const fundamental_type_for<element::boolean>>();
for (size_t i = 0; i < axes.get_size(); ++i, ++axes_mask) {
if (*axes_mask) {
reversed_axes.emplace(i);
}
}
} else if (validate_axes_indices_et(axes.get_element_type())) {
reversed_axes = util::get_normalized_axes_from_tensor(this, axes, data_shape.size());
} else {
return false;
}
ov::reference::reverse(inputs[0]->get_data_ptr<const char>(),
outputs[0]->get_data_ptr<char>(),
inputs[0]->get_shape(),
outputs[0]->get_shape(),
axes,
inputs[0]->get_element_type().size());
auto& output = outputs[0];
output.set_shape(data_shape);
reference::reverse(static_cast<const char*>(data.data()),
static_cast<char*>(output.data()),
data_shape,
output.get_shape(),
reversed_axes,
data.get_element_type().size());
return true;
}
bool ov::op::v1::Reverse::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
OV_OP_SCOPE(v1_Reverse_evaluate);
return evaluate_reverse(outputs, inputs);
}
bool ov::op::v1::Reverse::has_evaluate() const {
bool Reverse::has_evaluate() const {
OV_OP_SCOPE(v1_Reverse_has_evaluate);
if (get_mode() == op::v1::Reverse::Mode::INDEX) {
switch (get_input_element_type(1)) {
case ngraph::element::i8:
case ngraph::element::i16:
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u8:
case ngraph::element::u16:
case ngraph::element::u32:
case ngraph::element::u64:
return true;
default:
return false;
;
}
} else {
return true;
}
return (m_mode == Reverse::Mode::MASK) || validate_axes_indices_et(get_input_element_type(1));
}
} // namespace v1
} // namespace op
std::ostream& ov::operator<<(std::ostream& s, const op::v1::Reverse::Mode& type) {
std::ostream& operator<<(std::ostream& s, const op::v1::Reverse::Mode& type) {
return s << as_string(type);
}
namespace ov {
template <>
NGRAPH_API EnumNames<ngraph::op::v1::Reverse::Mode>& EnumNames<ngraph::op::v1::Reverse::Mode>::get() {
static auto enum_names = EnumNames<ngraph::op::v1::Reverse::Mode>(
OPENVINO_API EnumNames<op::v1::Reverse::Mode>& EnumNames<op::v1::Reverse::Mode>::get() {
static auto enum_names = EnumNames<op::v1::Reverse::Mode>(
"op::v1::Reverse::Mode",
{{"index", ngraph::op::v1::Reverse::Mode::INDEX}, {"mask", ngraph::op::v1::Reverse::Mode::MASK}});
{{"index", op::v1::Reverse::Mode::INDEX}, {"mask", op::v1::Reverse::Mode::MASK}});
return enum_names;
}
} // namespace ov