[core]Migrate Slice to new API (#20417)
* Migrate slice to new API * Remove visit_attributes, is same as base class * Move shape checks to shape_infer - minor refactor Slice op * Move `get_tensors_partial_shapes` to dev API * Correct comment Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com> --------- Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com>
This commit is contained in:
parent
7874adb58e
commit
e1a33f10d5
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/op/constant.hpp"
|
||||||
#include "openvino/op/op.hpp"
|
#include "openvino/op/op.hpp"
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
@ -40,14 +41,10 @@ public:
|
|||||||
const Output<Node>& axes);
|
const Output<Node>& axes);
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
|
||||||
|
|
||||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
// TODO: Update to use new evaluate with TensorVector
|
bool evaluate(TensorVector&, const TensorVector&) const override;
|
||||||
bool evaluate(const HostTensorVector&, const HostTensorVector&) const override;
|
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
|
||||||
bool evaluate_lower(TensorVector& outputs) const override;
|
bool evaluate_lower(TensorVector& outputs) const override;
|
||||||
bool evaluate_upper(TensorVector& outputs) const override;
|
bool evaluate_upper(TensorVector& outputs) const override;
|
||||||
bool evaluate_label(TensorLabelVector& output_labels) const override;
|
bool evaluate_label(TensorLabelVector& output_labels) const override;
|
||||||
|
@ -57,6 +57,14 @@ std::vector<TRShape> shape_infer(const Slice* op,
|
|||||||
const auto& input_shape = input_shapes[0];
|
const auto& input_shape = input_shapes[0];
|
||||||
const auto& input_rank = input_shape.rank();
|
const auto& input_rank = input_shape.rank();
|
||||||
|
|
||||||
|
// it is not possible to define output shape if input data shape rank is undefined
|
||||||
|
// even if lengths of begin, end, or strides are defined
|
||||||
|
if (input_rank.is_dynamic()) {
|
||||||
|
return {PartialShape::dynamic()};
|
||||||
|
} else {
|
||||||
|
NODE_SHAPE_INFER_CHECK(op, input_shapes, input_rank.get_length() > 0, "Slice `data` input can't be a scalar.");
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 1; i < input_shapes.size(); ++i) {
|
for (size_t i = 1; i < input_shapes.size(); ++i) {
|
||||||
const auto& shape = input_shapes[i];
|
const auto& shape = input_shapes[i];
|
||||||
const auto& shape_rank = shape.rank();
|
const auto& shape_rank = shape.rank();
|
||||||
@ -87,12 +95,6 @@ std::vector<TRShape> shape_infer(const Slice* op,
|
|||||||
"Slice `start`, `stop`, `step` inputs must have compatible shapes.");
|
"Slice `start`, `stop`, `step` inputs must have compatible shapes.");
|
||||||
|
|
||||||
auto output_shapes = std::vector<TRShape>(1);
|
auto output_shapes = std::vector<TRShape>(1);
|
||||||
// it is not possible to define output shape if input data shape rank is undefined
|
|
||||||
// even the lengths of begin, end, or strides are defined
|
|
||||||
if (input_rank.is_dynamic()) {
|
|
||||||
output_shapes[0] = PartialShape::dynamic();
|
|
||||||
return output_shapes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute constant values of begin, end, and strides if possible
|
// compute constant values of begin, end, and strides if possible
|
||||||
const auto start = get_input_bounds<TRShape, int64_t>(op, 1, ta);
|
const auto start = get_input_bounds<TRShape, int64_t>(op, 1, ta);
|
||||||
|
@ -2,223 +2,156 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "ngraph/op/slice.hpp"
|
#include "openvino/op/slice.hpp"
|
||||||
|
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
#include "bound_evaluate.hpp"
|
#include "bound_evaluate.hpp"
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "ngraph/attribute_visitor.hpp"
|
|
||||||
#include "ngraph/graph_util.hpp"
|
|
||||||
#include "ngraph/op/constant.hpp"
|
|
||||||
#include "openvino/reference/slice.hpp"
|
#include "openvino/reference/slice.hpp"
|
||||||
#include "slice_shape_inference.hpp"
|
#include "slice_shape_inference.hpp"
|
||||||
|
|
||||||
using namespace std;
|
namespace ov {
|
||||||
using namespace ngraph;
|
namespace op {
|
||||||
|
namespace {
|
||||||
|
std::vector<int64_t> default_axes(const size_t n) {
|
||||||
|
std::vector<int64_t> axes;
|
||||||
|
axes.reserve(n);
|
||||||
|
std::generate_n(std::back_inserter(axes), n, SeqGen<int64_t>(0));
|
||||||
|
return axes;
|
||||||
|
}
|
||||||
|
|
||||||
op::v8::Slice::Slice(const Output<Node>& data,
|
bool slice_bound_check(const ov::Node* const node) {
|
||||||
const Output<Node>& start,
|
return ov::have_node_inputs_bounds_set(node, 1, node->get_input_size() - 1);
|
||||||
const Output<Node>& stop,
|
}
|
||||||
const Output<Node>& step)
|
|
||||||
|
bool slice_no_axes(const Node* const node) {
|
||||||
|
return node->get_input_size() < 5;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace v8 {
|
||||||
|
using ov::op::v0::Constant;
|
||||||
|
|
||||||
|
Slice::Slice(const Output<Node>& data, const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step)
|
||||||
: Op({data, start, stop, step}) {
|
: Op({data, start, stop, step}) {
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
op::v8::Slice::Slice(const Output<Node>& data,
|
Slice::Slice(const Output<Node>& data,
|
||||||
const Output<Node>& start,
|
const Output<Node>& start,
|
||||||
const Output<Node>& stop,
|
const Output<Node>& stop,
|
||||||
const Output<Node>& step,
|
const Output<Node>& step,
|
||||||
const Output<Node>& axes)
|
const Output<Node>& axes)
|
||||||
: Op({data, start, stop, step, axes}) {
|
: Op({data, start, stop, step, axes}) {
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool op::v8::Slice::visit_attributes(AttributeVisitor& visitor) {
|
std::shared_ptr<Constant> Slice::get_default_const_axes(const Output<Node>& start) const {
|
||||||
OV_OP_SCOPE(v8_Slice_visit_attributes);
|
const auto& start_pshape = start.get_partial_shape();
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<op::v0::Constant> op::v8::Slice::get_default_const_axes(const Output<Node>& start) const {
|
|
||||||
const auto start_pshape = start.get_partial_shape();
|
|
||||||
// Static case
|
// Static case
|
||||||
if (start_pshape.rank().is_static() && start_pshape.rank().get_length() == 1 && start_pshape[0].is_static()) {
|
if (start_pshape.is_static() && start_pshape.size() == 1) {
|
||||||
size_t axes_length = start_pshape[0].get_length();
|
const auto axes = default_axes(static_cast<size_t>(start_pshape[0].get_length()));
|
||||||
std::vector<int64_t> axes(axes_length);
|
return Constant::create(element::i64, start_pshape.get_shape(), axes);
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
|
||||||
return v0::Constant::create(element::i64, Shape{axes_length}, axes);
|
|
||||||
} else {
|
} else {
|
||||||
// Dynamic case
|
// Dynamic case
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
}
|
} // namespace ov
|
||||||
|
|
||||||
void op::v8::Slice::validate_and_infer_types() {
|
void Slice::validate_and_infer_types() {
|
||||||
OV_OP_SCOPE(v8_Slice_validate_and_infer_types);
|
OV_OP_SCOPE(v8_Slice_validate_and_infer_types);
|
||||||
|
|
||||||
const auto inputs_size = get_input_size();
|
if (slice_no_axes(this)) {
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
inputs_size == 4 || inputs_size == 5,
|
|
||||||
"Slice has to have 4 or 5 inputs. Got: ",
|
|
||||||
inputs_size);
|
|
||||||
|
|
||||||
const PartialShape& data_shape = get_input_partial_shape(0);
|
|
||||||
const auto& data_rank = data_shape.rank();
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
data_rank.is_dynamic() || data_rank.get_length() > 0,
|
|
||||||
"Slice `data` input can't be a scalar.");
|
|
||||||
|
|
||||||
if (get_input_size() < 5) {
|
|
||||||
if (auto axes_const = get_default_const_axes(input_value(1))) {
|
if (auto axes_const = get_default_const_axes(input_value(1))) {
|
||||||
set_argument(4, axes_const);
|
set_argument(4, axes_const);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < get_input_size(); ++i) {
|
|
||||||
if (i > 0) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
get_input_element_type(i).is_integral_number(),
|
|
||||||
"Slice `",
|
|
||||||
slice::shape_names[i - 1],
|
|
||||||
"` input type must be integer.");
|
|
||||||
}
|
|
||||||
|
|
||||||
set_input_is_relevant_to_shape(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
const auto output_shapes = shape_infer(this, input_shapes);
|
const auto output_shapes = shape_infer(this, input_shapes);
|
||||||
|
|
||||||
|
set_input_is_relevant_to_shape(0);
|
||||||
|
for (size_t i = 1; i < get_input_size(); ++i) {
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
get_input_element_type(i).is_integral_number(),
|
||||||
|
"Slice `",
|
||||||
|
slice::shape_names[i - 1],
|
||||||
|
"` input type must be integer.");
|
||||||
|
set_input_is_relevant_to_shape(i);
|
||||||
|
}
|
||||||
|
|
||||||
set_output_type(0, get_input_element_type(0), output_shapes.front());
|
set_output_type(0, get_input_element_type(0), output_shapes.front());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Node> op::v8::Slice::clone_with_new_inputs(const OutputVector& new_args) const {
|
std::shared_ptr<Node> Slice::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||||
OV_OP_SCOPE(v8_Slice_clone_with_new_inputs);
|
OV_OP_SCOPE(v8_Slice_clone_with_new_inputs);
|
||||||
check_new_args_count(this, new_args);
|
check_new_args_count(this, new_args);
|
||||||
if (new_args.size() == 4) {
|
if (new_args.size() == 4) {
|
||||||
return std::make_shared<v8::Slice>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
|
return std::make_shared<Slice>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
|
||||||
} else {
|
} else {
|
||||||
return std::make_shared<v8::Slice>(new_args.at(0),
|
return std::make_shared<Slice>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
|
||||||
new_args.at(1),
|
|
||||||
new_args.at(2),
|
|
||||||
new_args.at(3),
|
|
||||||
new_args.at(4));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool op::v8::Slice::has_evaluate() const {
|
bool Slice::has_evaluate() const {
|
||||||
OV_OP_SCOPE(v8_Slice_has_evaluate);
|
OV_OP_SCOPE(v8_Slice_has_evaluate);
|
||||||
switch (get_input_element_type(1)) {
|
|
||||||
case ngraph::element::i8:
|
|
||||||
case ngraph::element::i16:
|
|
||||||
case ngraph::element::i32:
|
|
||||||
case ngraph::element::i64:
|
|
||||||
case ngraph::element::u8:
|
|
||||||
case ngraph::element::u16:
|
|
||||||
case ngraph::element::u32:
|
|
||||||
case ngraph::element::u64:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (get_input_size() > 4) {
|
const auto valid_integral_type = [](const element::Type& et) -> bool {
|
||||||
switch (get_input_element_type(4)) {
|
switch (et) {
|
||||||
case ngraph::element::i8:
|
case element::i8:
|
||||||
case ngraph::element::i16:
|
case element::i16:
|
||||||
case ngraph::element::i32:
|
case element::i32:
|
||||||
case ngraph::element::i64:
|
case element::i64:
|
||||||
case ngraph::element::u8:
|
case element::u8:
|
||||||
case ngraph::element::u16:
|
case element::u16:
|
||||||
case ngraph::element::u32:
|
case element::u32:
|
||||||
case ngraph::element::u64:
|
case element::u64:
|
||||||
break;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
return true;
|
return valid_integral_type(get_input_element_type(1)) &&
|
||||||
|
(slice_no_axes(this) || valid_integral_type(get_input_element_type(4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
bool Slice::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
|
||||||
bool op::v8::Slice::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
|
||||||
OV_OP_SCOPE(v8_Slice_evaluate);
|
OV_OP_SCOPE(v8_Slice_evaluate);
|
||||||
OPENVINO_ASSERT(inputs.size() >= 4, "Slice evaluate needs at least 4 inputs.");
|
|
||||||
|
|
||||||
// Static HostTensor data shape is needed to clamp and normalize `start` values
|
const auto output_shapes =
|
||||||
OPENVINO_ASSERT(inputs[0]->get_partial_shape().is_static(),
|
shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs));
|
||||||
"Can't evaluate Slice elements without static HostTensor data shape.");
|
outputs[0].set_shape(output_shapes.front().to_shape());
|
||||||
|
|
||||||
auto input_shapes = std::vector<PartialShape>();
|
const auto starts = ov::get_tensor_data_as<int64_t>(inputs[1]);
|
||||||
input_shapes.reserve(inputs.size());
|
const auto steps = ov::get_tensor_data_as<int64_t>(inputs[3]);
|
||||||
|
const auto axes = slice_no_axes(this) ? default_axes(starts.size()) : ov::get_tensor_data_as<int64_t>(inputs[4]);
|
||||||
|
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
reference::slice(static_cast<const char*>(inputs[0].data()),
|
||||||
auto&& tensor = inputs[i];
|
inputs[0].get_shape(),
|
||||||
input_shapes.push_back(tensor->get_partial_shape());
|
static_cast<char*>(outputs[0].data()),
|
||||||
}
|
outputs[0].get_shape(),
|
||||||
|
inputs[0].get_element_type().size(),
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
starts,
|
||||||
const auto starts = host_tensor_2_vector<int64_t>(inputs[1]);
|
steps,
|
||||||
const auto stops = host_tensor_2_vector<int64_t>(inputs[2]);
|
axes);
|
||||||
const auto steps = host_tensor_2_vector<int64_t>(inputs[3]);
|
|
||||||
|
|
||||||
std::vector<int64_t> axes;
|
|
||||||
if (inputs.size() < 5) {
|
|
||||||
axes.reserve(starts.size());
|
|
||||||
std::generate_n(std::back_inserter(axes), starts.size(), SeqGen<int64_t>(0));
|
|
||||||
} else {
|
|
||||||
axes = host_tensor_2_vector<int64_t>(inputs[4]);
|
|
||||||
}
|
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
|
||||||
|
|
||||||
const auto output_shapes = shape_infer(this, input_shapes, make_tensor_accessor(inputs));
|
|
||||||
OPENVINO_ASSERT(output_shapes.front().is_static(), "Can't calculate static output shape for Slice evaluation.");
|
|
||||||
|
|
||||||
outputs[0]->set_shape(output_shapes.front().to_shape());
|
|
||||||
outputs[0]->set_element_type(inputs[0]->get_element_type());
|
|
||||||
|
|
||||||
ov::reference::slice(inputs[0]->get_data_ptr<char>(),
|
|
||||||
inputs[0]->get_shape(),
|
|
||||||
outputs[0]->get_data_ptr<char>(),
|
|
||||||
outputs[0]->get_shape(),
|
|
||||||
inputs[0]->get_element_type().size(),
|
|
||||||
starts,
|
|
||||||
steps,
|
|
||||||
axes);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
|
||||||
|
|
||||||
namespace {
|
bool Slice::evaluate_lower(ov::TensorVector& output_values) const {
|
||||||
bool slice_input_check(const ov::Node* node) {
|
return slice_bound_check(this) && default_lower_bound_evaluator(this, output_values);
|
||||||
if (!node->get_input_tensor(1).has_and_set_bound())
|
|
||||||
return false;
|
|
||||||
if (!node->get_input_tensor(2).has_and_set_bound())
|
|
||||||
return false;
|
|
||||||
if (!node->get_input_tensor(3).has_and_set_bound())
|
|
||||||
return false;
|
|
||||||
if (node->get_input_size() == 5 && !node->get_input_tensor(4).has_and_set_bound())
|
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
bool op::v8::Slice::evaluate_lower(ov::TensorVector& output_values) const {
|
|
||||||
return slice_input_check(this) && default_lower_bound_evaluator(this, output_values);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool op::v8::Slice::evaluate_upper(ov::TensorVector& output_values) const {
|
bool Slice::evaluate_upper(ov::TensorVector& output_values) const {
|
||||||
return slice_input_check(this) && default_upper_bound_evaluator(this, output_values);
|
return slice_bound_check(this) && default_upper_bound_evaluator(this, output_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool op::v8::Slice::evaluate_label(TensorLabelVector& output_labels) const {
|
bool Slice::evaluate_label(TensorLabelVector& output_labels) const {
|
||||||
if (!slice_input_check(this))
|
|
||||||
return false;
|
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
return default_label_evaluator(this, output_labels);
|
return slice_bound_check(this) && default_label_evaluator(this, output_labels);
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
}
|
}
|
||||||
|
} // namespace v8
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
||||||
|
Loading…
Reference in New Issue
Block a user