Migrate Split operator to new API (#20263)
This commit is contained in:
@@ -39,9 +39,8 @@ public:
|
||||
void set_num_splits(const size_t num_splits) {
|
||||
m_num_splits = num_splits;
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
|
||||
bool evaluate_lower(TensorVector& outputs) const override;
|
||||
bool evaluate_upper(TensorVector& outputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
|
||||
@@ -4,17 +4,28 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "openvino/reference/slice.hpp"
|
||||
#include "openvino/core/shape.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace reference {
|
||||
|
||||
/**
|
||||
* @brief Reference implementation of Split operator.
|
||||
*
|
||||
* @param data Pointer to input data.
|
||||
* @param data_shape Input data shape.
|
||||
* @param elem_size Size of single element type.
|
||||
* @param axis Axis used for split input data.
|
||||
* @param num_splits Number of splits
|
||||
* @param out_data Pointer to output data pointers (must have size of num_splits)
|
||||
*/
|
||||
void split(const char* data,
|
||||
const Shape& data_shape,
|
||||
size_t elem_size,
|
||||
int64_t axis,
|
||||
size_t num_splits,
|
||||
char** out_data);
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace ov
|
||||
|
||||
@@ -6,35 +6,43 @@
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <iterator>
|
||||
|
||||
using namespace ov;
|
||||
#include "openvino/core/coordinate.hpp"
|
||||
#include "openvino/reference/slice.hpp"
|
||||
|
||||
void reference::split(const char* data,
|
||||
const Shape& data_shape,
|
||||
size_t elem_size,
|
||||
int64_t axis,
|
||||
size_t num_splits,
|
||||
char** out_data) {
|
||||
namespace ov {
|
||||
namespace reference {
|
||||
|
||||
void split(const char* data,
|
||||
const Shape& data_shape,
|
||||
const size_t elem_size,
|
||||
const int64_t axis,
|
||||
const size_t num_splits,
|
||||
char** out_data) {
|
||||
const size_t part_length = data_shape.at(axis) / num_splits;
|
||||
|
||||
Shape output_shape = data_shape;
|
||||
output_shape.at(axis) = part_length;
|
||||
auto output_shape = data_shape;
|
||||
output_shape[axis] = part_length;
|
||||
|
||||
std::vector<size_t> lower_bounds(data_shape.size(), 0);
|
||||
std::vector<size_t> upper_bounds = data_shape;
|
||||
upper_bounds.at(axis) = part_length;
|
||||
Coordinate lower_bounds(data_shape.size(), 0);
|
||||
Coordinate upper_bounds = output_shape;
|
||||
auto& lb_at_axis = lower_bounds[axis];
|
||||
auto& ub_at_axis = upper_bounds[axis];
|
||||
|
||||
for (size_t i = 0; i < num_splits; ++i) {
|
||||
const auto out_last = std::next(out_data, num_splits);
|
||||
for (auto out_first = out_data; out_first != out_last; ++out_first) {
|
||||
reference::slice(data,
|
||||
out_data[i],
|
||||
*out_first,
|
||||
data_shape,
|
||||
lower_bounds,
|
||||
upper_bounds,
|
||||
Strides(lower_bounds.size(), 1),
|
||||
output_shape,
|
||||
elem_size);
|
||||
lower_bounds.at(axis) += part_length;
|
||||
upper_bounds.at(axis) += part_length;
|
||||
lb_at_axis += part_length;
|
||||
ub_at_axis += part_length;
|
||||
}
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace ov
|
||||
|
||||
@@ -2,42 +2,46 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/reference/split.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
|
||||
#include <numeric>
|
||||
#include <split_shape_inference.hpp>
|
||||
|
||||
#include "bound_evaluate.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/builder/split.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/split.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/reference/split.hpp"
|
||||
#include "split_shape_inference.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
namespace ov {
|
||||
namespace op {
|
||||
|
||||
op::v1::Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits)
|
||||
namespace v1 {
|
||||
namespace validate {
|
||||
namespace {
|
||||
bool axis_type(const element::Type& et) {
|
||||
return et.is_integral_number();
|
||||
}
|
||||
} // namespace
|
||||
} // namespace validate
|
||||
|
||||
Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits)
|
||||
: Op({data, axis}),
|
||||
m_num_splits{num_splits} {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool ngraph::op::v1::Split::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool Split::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v1_Split_visit_attributes);
|
||||
visitor.on_attribute("num_splits", m_num_splits);
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v1::Split::validate_and_infer_types() {
|
||||
void Split::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v1_Split_validate_and_infer_types);
|
||||
const auto& axis_et = get_input_element_type(1);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis_et.is_integral_number(),
|
||||
validate::axis_type(axis_et),
|
||||
"Element type of 'axis' input must be integer. Got: ",
|
||||
axis_et);
|
||||
|
||||
@@ -58,72 +62,70 @@ void op::v1::Split::validate_and_infer_types() {
|
||||
set_input_is_relevant_to_shape(0);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::Split::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
std::shared_ptr<Node> Split::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v1_Split_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v1::Split>(new_args.at(0), new_args.at(1), m_num_splits);
|
||||
return std::make_shared<Split>(new_args.at(0), new_args.at(1), m_num_splits);
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool op::v1::Split::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
bool Split::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v1_Split_evaluate);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, m_num_splits) && validate_host_tensor_vector(inputs, 2));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
OPENVINO_ASSERT(outputs.size() == m_num_splits);
|
||||
|
||||
if (has_evaluate()) {
|
||||
const auto output_shapes =
|
||||
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs));
|
||||
const auto& axis_tensor = inputs[1];
|
||||
const auto result = validate::axis_type(axis_tensor.get_element_type());
|
||||
if (result) {
|
||||
const auto& data_tensor = inputs[0];
|
||||
const auto& axis_tensor = inputs[1];
|
||||
|
||||
const auto input_shapes =
|
||||
std::vector<PartialShape>{data_tensor->get_partial_shape(), axis_tensor->get_partial_shape()};
|
||||
|
||||
auto output_shapes = shape_infer(this, input_shapes, make_tensor_accessor(inputs));
|
||||
|
||||
auto outputs_data = std::vector<char*>(m_num_splits);
|
||||
for (size_t i = 0; i < m_num_splits; ++i) {
|
||||
outputs[i]->set_shape(output_shapes[i].get_shape());
|
||||
outputs_data[i] = outputs[i]->get_data_ptr<char>();
|
||||
{
|
||||
auto outputs_it = outputs.begin();
|
||||
auto outputs_data_it = outputs_data.begin();
|
||||
for (const auto& p_shape : output_shapes) {
|
||||
outputs_it->set_shape(p_shape.get_shape());
|
||||
*outputs_data_it = static_cast<char*>(outputs_it->data());
|
||||
++outputs_it, ++outputs_data_it;
|
||||
}
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
auto axis = host_tensor_2_vector<int64_t>(axis_tensor)[0];
|
||||
axis = normalize_axis(this, axis, data_tensor->get_partial_shape().rank());
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
auto axis = get_tensor_data_as<int64_t>(axis_tensor).front();
|
||||
axis = ov::util::normalize(axis, data_tensor.get_shape().size());
|
||||
|
||||
ov::reference::split(data_tensor->get_data_ptr<char>(),
|
||||
data_tensor->get_shape(),
|
||||
data_tensor->get_element_type().size(),
|
||||
ov::reference::split(static_cast<char*>(data_tensor.data()),
|
||||
data_tensor.get_shape(),
|
||||
data_tensor.get_element_type().size(),
|
||||
axis,
|
||||
m_num_splits,
|
||||
outputs_data.data());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
bool op::v1::Split::has_evaluate() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool Split::has_evaluate() const {
|
||||
OV_OP_SCOPE(v1_Split_has_evaluate);
|
||||
return get_input_element_type(1).is_integral_number();
|
||||
return validate::axis_type(get_input_element_type(1));
|
||||
}
|
||||
|
||||
bool op::v1::Split::evaluate_lower(ov::TensorVector& output_values) const {
|
||||
bool Split::evaluate_lower(ov::TensorVector& output_values) const {
|
||||
OV_OP_SCOPE(v1_Split_evaluate_lower);
|
||||
|
||||
return input(1).get_tensor().has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
|
||||
return get_input_tensor(1).has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::Split::evaluate_upper(ov::TensorVector& output_values) const {
|
||||
bool Split::evaluate_upper(ov::TensorVector& output_values) const {
|
||||
OV_OP_SCOPE(v1_Split_evaluate_upper);
|
||||
|
||||
return input(1).get_tensor().has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
|
||||
return get_input_tensor(1).has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::Split::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
bool Split::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
OPENVINO_ASSERT(output_labels.size() == get_num_splits());
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
return input(1).get_tensor().has_and_set_bound() && default_label_evaluator(this, output_labels);
|
||||
return get_input_tensor(1).has_and_set_bound() && default_label_evaluator(this, output_labels);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
||||
Reference in New Issue
Block a user