Migrate Split operator to new API (#20263)

This commit is contained in:
Pawel Raasz
2023-10-16 06:16:43 +02:00
committed by GitHub
parent 6f6017724f
commit f107b7663f
4 changed files with 96 additions and 76 deletions

View File

@@ -39,9 +39,8 @@ public:
void set_num_splits(const size_t num_splits) {
m_num_splits = num_splits;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& outputs) const override;
bool evaluate_upper(TensorVector& outputs) const override;
bool has_evaluate() const override;

View File

@@ -4,17 +4,28 @@
#pragma once
#include <cmath>
#include <cstddef>
#include "openvino/reference/slice.hpp"
#include "openvino/core/shape.hpp"
namespace ov {
namespace reference {
/**
* @brief Reference implementation of Split operator.
*
* @param data Pointer to input data.
* @param data_shape Input data shape.
* @param elem_size Size of single element type.
* @param axis Axis used for split input data.
* @param num_splits Number of splits
* @param out_data Pointer to output data pointers (must have size of num_splits)
*/
void split(const char* data,
const Shape& data_shape,
size_t elem_size,
int64_t axis,
size_t num_splits,
char** out_data);
}
} // namespace reference
} // namespace ov

View File

@@ -6,35 +6,43 @@
#include <stdio.h>
#include <cmath>
#include <iterator>
using namespace ov;
#include "openvino/core/coordinate.hpp"
#include "openvino/reference/slice.hpp"
void reference::split(const char* data,
const Shape& data_shape,
size_t elem_size,
int64_t axis,
size_t num_splits,
char** out_data) {
namespace ov {
namespace reference {
void split(const char* data,
const Shape& data_shape,
const size_t elem_size,
const int64_t axis,
const size_t num_splits,
char** out_data) {
const size_t part_length = data_shape.at(axis) / num_splits;
Shape output_shape = data_shape;
output_shape.at(axis) = part_length;
auto output_shape = data_shape;
output_shape[axis] = part_length;
std::vector<size_t> lower_bounds(data_shape.size(), 0);
std::vector<size_t> upper_bounds = data_shape;
upper_bounds.at(axis) = part_length;
Coordinate lower_bounds(data_shape.size(), 0);
Coordinate upper_bounds = output_shape;
auto& lb_at_axis = lower_bounds[axis];
auto& ub_at_axis = upper_bounds[axis];
for (size_t i = 0; i < num_splits; ++i) {
const auto out_last = std::next(out_data, num_splits);
for (auto out_first = out_data; out_first != out_last; ++out_first) {
reference::slice(data,
out_data[i],
*out_first,
data_shape,
lower_bounds,
upper_bounds,
Strides(lower_bounds.size(), 1),
output_shape,
elem_size);
lower_bounds.at(axis) += part_length;
upper_bounds.at(axis) += part_length;
lb_at_axis += part_length;
ub_at_axis += part_length;
}
}
} // namespace reference
} // namespace ov

View File

@@ -2,42 +2,46 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/reference/split.hpp"
#include "openvino/op/split.hpp"
#include <numeric>
#include <split_shape_inference.hpp>
#include "bound_evaluate.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/split.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/reference/split.hpp"
#include "split_shape_inference.hpp"
using namespace std;
using namespace ngraph;
namespace ov {
namespace op {
op::v1::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits)
namespace v1 {
namespace validate {
namespace {
bool axis_type(const element::Type& et) {
return et.is_integral_number();
}
} // namespace
} // namespace validate
Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits)
: Op({data, axis}),
m_num_splits{num_splits} {
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::Split::visit_attributes(AttributeVisitor& visitor) {
bool Split::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v1_Split_visit_attributes);
visitor.on_attribute("num_splits", m_num_splits);
return true;
}
void op::v1::Split::validate_and_infer_types() {
void Split::validate_and_infer_types() {
OV_OP_SCOPE(v1_Split_validate_and_infer_types);
const auto& axis_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
axis_et.is_integral_number(),
validate::axis_type(axis_et),
"Element type of 'axis' input must be integer. Got: ",
axis_et);
@@ -58,72 +62,70 @@ void op::v1::Split::validate_and_infer_types() {
set_input_is_relevant_to_shape(0);
}
shared_ptr<Node> op::v1::Split::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Split::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_Split_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v1::Split>(new_args.at(0), new_args.at(1), m_num_splits);
return std::make_shared<Split>(new_args.at(0), new_args.at(1), m_num_splits);
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool op::v1::Split::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool Split::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_Split_evaluate);
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, m_num_splits) && validate_host_tensor_vector(inputs, 2));
OPENVINO_SUPPRESS_DEPRECATED_END
OPENVINO_ASSERT(outputs.size() == m_num_splits);
if (has_evaluate()) {
const auto output_shapes =
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs));
const auto& axis_tensor = inputs[1];
const auto result = validate::axis_type(axis_tensor.get_element_type());
if (result) {
const auto& data_tensor = inputs[0];
const auto& axis_tensor = inputs[1];
const auto input_shapes =
std::vector<PartialShape>{data_tensor->get_partial_shape(), axis_tensor->get_partial_shape()};
auto output_shapes = shape_infer(this, input_shapes, make_tensor_accessor(inputs));
auto outputs_data = std::vector<char*>(m_num_splits);
for (size_t i = 0; i < m_num_splits; ++i) {
outputs[i]->set_shape(output_shapes[i].get_shape());
outputs_data[i] = outputs[i]->get_data_ptr<char>();
{
auto outputs_it = outputs.begin();
auto outputs_data_it = outputs_data.begin();
for (const auto& p_shape : output_shapes) {
outputs_it->set_shape(p_shape.get_shape());
*outputs_data_it = static_cast<char*>(outputs_it->data());
++outputs_it, ++outputs_data_it;
}
}
OPENVINO_SUPPRESS_DEPRECATED_START
auto axis = host_tensor_2_vector<int64_t>(axis_tensor)[0];
axis = normalize_axis(this, axis, data_tensor->get_partial_shape().rank());
OPENVINO_SUPPRESS_DEPRECATED_END
auto axis = get_tensor_data_as<int64_t>(axis_tensor).front();
axis = ov::util::normalize(axis, data_tensor.get_shape().size());
ov::reference::split(data_tensor->get_data_ptr<char>(),
data_tensor->get_shape(),
data_tensor->get_element_type().size(),
ov::reference::split(static_cast<char*>(data_tensor.data()),
data_tensor.get_shape(),
data_tensor.get_element_type().size(),
axis,
m_num_splits,
outputs_data.data());
return true;
}
return false;
}
OPENVINO_SUPPRESS_DEPRECATED_END
bool op::v1::Split::has_evaluate() const {
return result;
}
bool Split::has_evaluate() const {
OV_OP_SCOPE(v1_Split_has_evaluate);
return get_input_element_type(1).is_integral_number();
return validate::axis_type(get_input_element_type(1));
}
bool op::v1::Split::evaluate_lower(ov::TensorVector& output_values) const {
bool Split::evaluate_lower(ov::TensorVector& output_values) const {
OV_OP_SCOPE(v1_Split_evaluate_lower);
return input(1).get_tensor().has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
return get_input_tensor(1).has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
}
bool op::v1::Split::evaluate_upper(ov::TensorVector& output_values) const {
bool Split::evaluate_upper(ov::TensorVector& output_values) const {
OV_OP_SCOPE(v1_Split_evaluate_upper);
return input(1).get_tensor().has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
return get_input_tensor(1).has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
}
bool op::v1::Split::evaluate_label(TensorLabelVector& output_labels) const {
bool Split::evaluate_label(TensorLabelVector& output_labels) const {
OPENVINO_ASSERT(output_labels.size() == get_num_splits());
OPENVINO_SUPPRESS_DEPRECATED_START
return input(1).get_tensor().has_and_set_bound() && default_label_evaluator(this, output_labels);
return get_input_tensor(1).has_and_set_bound() && default_label_evaluator(this, output_labels);
OPENVINO_SUPPRESS_DEPRECATED_END
}
} // namespace v1
} // namespace op
} // namespace ov