[core] Migrate StridedSlice operator to new API (#21342)
* Drop legacy stuff
* Replace HostTensor with ov::Tensor
parent fa1cc89cf3
commit eec370a88b

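For context, the core of this change is the switch from the deprecated HostTensor-based evaluate() overload to the ov::Tensor-based one declared in the header hunk below. The following is a minimal, hypothetical sketch (not part of this patch) of how a caller exercises the new overload; the shapes, values, and the make_i64_tensor helper are illustrative assumptions only.

#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>

#include "openvino/op/parameter.hpp"
#include "openvino/op/strided_slice.hpp"
#include "openvino/runtime/tensor.hpp"

// Illustrative helper: pack a vector of indices into an i64 ov::Tensor.
static ov::Tensor make_i64_tensor(const std::vector<int64_t>& v) {
    ov::Tensor t(ov::element::i64, ov::Shape{v.size()});
    std::copy(v.begin(), v.end(), t.data<int64_t>());
    return t;
}

bool evaluate_strided_slice_example() {
    using namespace ov;
    // Slice a 2x4 tensor down to row 0 and columns [1, 3).
    const auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 4});
    const auto begin = std::make_shared<op::v0::Parameter>(element::i64, Shape{2});
    const auto end = std::make_shared<op::v0::Parameter>(element::i64, Shape{2});
    const auto strides = std::make_shared<op::v0::Parameter>(element::i64, Shape{2});
    const auto ss = std::make_shared<op::v1::StridedSlice>(data,
                                                           begin,
                                                           end,
                                                           strides,
                                                           std::vector<int64_t>{0, 0},   // begin_mask
                                                           std::vector<int64_t>{0, 0});  // end_mask

    Tensor data_t(element::f32, Shape{2, 4});
    std::iota(data_t.data<float>(), data_t.data<float>() + data_t.get_size(), 0.0f);

    // New API: plain ov::Tensor objects instead of HostTensor pointers.
    TensorVector inputs{data_t, make_i64_tensor({0, 1}), make_i64_tensor({1, 3}), make_i64_tensor({1, 1})};
    TensorVector outputs{Tensor(element::f32, Shape{1, 2})};
    return ss->has_evaluate() && ss->evaluate(outputs, inputs);  // outputs[0] now holds {1, 2}
}
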
@@ -108,9 +108,7 @@ public:

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
bool evaluate_lower(TensorVector& outputs) const override;
bool evaluate_upper(TensorVector& outputs) const override;

@@ -31,7 +31,7 @@ void strided_slice(const char* arg,
}

ov::AlignedBuffer slice_out_buffer(shape_size(sp.reshape_in_shape) * elem_type);
slice(reinterpret_cast<const char*>(arg),
slice(arg,
slice_out_buffer.get_ptr<char>(),
arg_shape,
Coordinate(sp.begins.begin(), sp.begins.end()),

@@ -2,55 +2,53 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph/op/strided_slice.hpp"
#include "openvino/op/strided_slice.hpp"

#include <algorithm>

#include "bound_evaluate.hpp"
#include "compare.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/shape_of.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"
#include "openvino/op/util/slice_plan.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/reference/strided_slice.hpp"
#include "strided_slice_shape_inference.hpp"
#include "utils.hpp"
#include "validation_util.hpp"

using namespace std;
using namespace ngraph;

op::v1::StridedSlice::StridedSlice(const Output<Node>& data,
const Output<Node>& begin,
const Output<Node>& end,
const Output<Node>& strides,
const std::vector<int64_t>& begin_mask,
const std::vector<int64_t>& end_mask,
const std::vector<int64_t>& new_axis_mask,
const std::vector<int64_t>& shrink_axis_mask,
const std::vector<int64_t>& ellipsis_mask)
namespace ov {
namespace op {
namespace v1 {
StridedSlice::StridedSlice(const Output<Node>& data,
const Output<Node>& begin,
const Output<Node>& end,
const Output<Node>& strides,
const std::vector<int64_t>& begin_mask,
const std::vector<int64_t>& end_mask,
const std::vector<int64_t>& new_axis_mask,
const std::vector<int64_t>& shrink_axis_mask,
const std::vector<int64_t>& ellipsis_mask)
: Op({data, begin, end, strides}),
m_begin_mask{begin_mask},
m_end_mask{end_mask},
m_new_axis_mask{new_axis_mask},
m_shrink_axis_mask{shrink_axis_mask},
m_ellipsis_mask{ellipsis_mask} {
ov::mark_as_precision_sensitive(input(1));
ov::mark_as_precision_sensitive(input(2));
ov::mark_as_precision_sensitive(input(3));
mark_as_precision_sensitive(input(1));
mark_as_precision_sensitive(input(2));
mark_as_precision_sensitive(input(3));
constructor_validate_and_infer_types();
}

namespace {
shared_ptr<Node> calculate_default_strides(const Output<Node>& begin, const Output<Node>& end) {
std::shared_ptr<Node> calculate_default_strides(const Output<Node>& begin, const Output<Node>& end) {
const auto begin_pshape = begin.get_partial_shape();
const auto end_pshape = end.get_partial_shape();

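The hunk above also moves the implementation out of `using namespace ngraph;` plus fully qualified op::v1:: definitions and into explicit nested namespaces. A tiny, self-contained illustration of that definition style, using a hypothetical demo namespace and Example class rather than OpenVINO code:

#include <cstdio>

namespace demo {
namespace op {
namespace v1 {

struct Example {
    void hello() const;
};

// Inside the namespace block the definition needs no demo::op::v1:: prefix,
// mirroring how StridedSlice members are now defined inside namespace ov.
void Example::hello() const {
    std::puts("hello from demo::op::v1::Example");
}

}  // namespace v1
}  // namespace op
}  // namespace demo

int main() {
    demo::op::v1::Example{}.hello();
    return 0;
}
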
@@ -63,11 +61,11 @@ shared_ptr<Node> calculate_default_strides(const Output<Node>& begin, const Outp
{
OPENVINO_ASSERT(begin_pshape.rank().is_static() && begin_pshape.rank().get_length() == 1,
"Begin input must be 1D");
return std::make_shared<op::v1::Broadcast>(op::Constant::create(element::i64, {}, {1}),
std::make_shared<op::ShapeOf>(begin));
return std::make_shared<v1::Broadcast>(v0::Constant::create(element::i64, {}, {1}),
std::make_shared<v0::ShapeOf>(begin));
}

return op::Constant::create(element::i64, ov::Shape{strides_length}, vector<int64_t>(strides_length, 1));
return v0::Constant::create(element::i64, Shape{strides_length}, std::vector<int64_t>(strides_length, 1));
}

/**

@@ -77,8 +75,8 @@ shared_ptr<Node> calculate_default_strides(const Output<Node>& begin, const Outp
* @param ignored_mask Axis set of ignored indices.
* @return True if all ignored other wise false.
*/
bool all_indices_ignored(const ov::PartialShape& shape, const std::vector<int64_t>& ignore_mask) {
auto ignored = shape.rank().is_static() && ov::cmp::le(shape[0].get_interval().get_max_val(), ignore_mask.size());
bool all_indices_ignored(const PartialShape& shape, const std::vector<int64_t>& ignore_mask) {
auto ignored = shape.rank().is_static() && cmp::le(shape[0].get_interval().get_max_val(), ignore_mask.size());
for (size_t i = 0; ignored && i < static_cast<size_t>(shape[0].get_interval().get_max_val()); ++i) {
ignored = static_cast<bool>(ignore_mask[i]);
}

@@ -86,14 +84,14 @@ bool all_indices_ignored(const ov::PartialShape& shape, const std::vector<int64_
}
} // namespace

op::v1::StridedSlice::StridedSlice(const Output<Node>& data,
const Output<Node>& begin,
const Output<Node>& end,
const std::vector<int64_t>& begin_mask,
const std::vector<int64_t>& end_mask,
const std::vector<int64_t>& new_axis_mask,
const std::vector<int64_t>& shrink_axis_mask,
const std::vector<int64_t>& ellipsis_mask)
StridedSlice::StridedSlice(const Output<Node>& data,
const Output<Node>& begin,
const Output<Node>& end,
const std::vector<int64_t>& begin_mask,
const std::vector<int64_t>& end_mask,
const std::vector<int64_t>& new_axis_mask,
const std::vector<int64_t>& shrink_axis_mask,
const std::vector<int64_t>& ellipsis_mask)
: StridedSlice(data,
begin,
end,

@@ -104,7 +102,7 @@ op::v1::StridedSlice::StridedSlice(const Output<Node>& data,
shrink_axis_mask,
ellipsis_mask) {}

bool op::v1::StridedSlice::visit_attributes(AttributeVisitor& visitor) {
bool StridedSlice::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v1_StridedSlice_visit_attributes);
visitor.on_attribute("begin_mask", m_begin_mask);
visitor.on_attribute("end_mask", m_end_mask);

@@ -114,7 +112,7 @@ bool op::v1::StridedSlice::visit_attributes(AttributeVisitor& visitor) {
return true;
}

void op::v1::StridedSlice::validate_and_infer_types() {
void StridedSlice::validate_and_infer_types() {
OV_OP_SCOPE(v1_StridedSlice_validate_and_infer_types);
const auto& begin_mask_et = get_input_element_type(1);
const auto& end_mask_et = get_input_element_type(2);

@@ -137,11 +135,11 @@ void op::v1::StridedSlice::validate_and_infer_types() {
std::all_of(m_ellipsis_mask.begin(), m_ellipsis_mask.end(), are_mask_elem_in_range),
"All masks of StridedSlice must have be 0 or 1");

const vector<size_t> attr_sizes = {m_begin_mask.size(),
m_end_mask.size(),
m_new_axis_mask.size(),
m_shrink_axis_mask.size(),
m_ellipsis_mask.size()};
const std::vector<size_t> attr_sizes = {m_begin_mask.size(),
m_end_mask.size(),
m_new_axis_mask.size(),
m_shrink_axis_mask.size(),
m_ellipsis_mask.size()};
const auto are_attr_sizes_eq = std::all_of(attr_sizes.begin(), attr_sizes.end(), [&attr_sizes](size_t s) {
return (s == 0) || (attr_sizes[0] == s);
});

@@ -165,7 +163,7 @@ void op::v1::StridedSlice::validate_and_infer_types() {
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}

AxisSet op::v1::StridedSlice::convert_mask_to_axis_set(const std::vector<int64_t>& mask) const {
AxisSet StridedSlice::convert_mask_to_axis_set(const std::vector<int64_t>& mask) const {
AxisSet axis_set{};
for (size_t i = 0; i < static_cast<size_t>(mask.size()); ++i) {
if (mask[i] == 1) {

@@ -175,89 +173,79 @@ AxisSet op::v1::StridedSlice::convert_mask_to_axis_set(const std::vector<int64_t
return axis_set;
}

shared_ptr<Node> op::v1::StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> StridedSlice::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_StridedSlice_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v1::StridedSlice>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
m_begin_mask,
m_end_mask,
m_new_axis_mask,
m_shrink_axis_mask,
m_ellipsis_mask);
return std::make_shared<v1::StridedSlice>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
m_begin_mask,
m_end_mask,
m_new_axis_mask,
m_shrink_axis_mask,
m_ellipsis_mask);
}

namespace strided_slice {
namespace {
OPENVINO_SUPPRESS_DEPRECATED_START
inline bool evaluate(const HostTensorPtr& in, const ov::op::util::SlicePlan& sp, const HostTensorPtr& out)
{
auto in_shape = in->get_shape();
out->set_shape(sp.reshape_out_shape);
ov::reference::strided_slice(in->get_data_ptr<char>(),
out->get_data_ptr<char>(),
in_shape,
sp,
in->get_element_type().size());
return true;
}

bool evaluate_strided_slice(const HostTensorPtr& in,
const HostTensorPtr& begin,
const HostTensorPtr& end,
const HostTensorPtr& stride,
bool evaluate_strided_slice(const Tensor& data,
const Tensor& begin,
const Tensor& end,
const Tensor& stride,
const AxisSet& begin_mask,
const AxisSet& end_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask,
const HostTensorPtr& out) {
std::vector<int64_t> begin_const = host_tensor_2_vector<int64_t>(begin);
std::vector<int64_t> end_const = host_tensor_2_vector<int64_t>(end);
std::vector<int64_t> stride_const = host_tensor_2_vector<int64_t>(stride);
const auto slice_plan = ov::op::util::make_slice_plan(in->get_shape(),
begin_const,
end_const,
stride_const,
begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
ellipsis_mask);
return evaluate(in, slice_plan, out);
Tensor& output) {
const auto begin_const = get_tensor_data_as<int64_t>(begin);
const auto end_const = get_tensor_data_as<int64_t>(end);
const auto stride_const = get_tensor_data_as<int64_t>(stride);
const auto& data_shape = data.get_shape();
const auto slice_plan = util::make_slice_plan(data_shape,
begin_const,
end_const,
stride_const,
begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
ellipsis_mask);
output.set_shape(slice_plan.reshape_out_shape);
reference::strided_slice(reinterpret_cast<const char*>(data.data()),
reinterpret_cast<char*>(output.data()),
data_shape,
slice_plan,
data.get_element_type().size());
return true;
}
OPENVINO_SUPPRESS_DEPRECATED_END
} // namespace
} // namespace strided_slice

bool op::v1::StridedSlice::evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const {
bool StridedSlice::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v1_StridedSlice_evaluate);
// FIXME: 4th input is optional, but it is required by the following code
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(validate_host_tensor_vector(input_values, 4));
OPENVINO_ASSERT(validate_host_tensor_vector(output_values, 1));
OPENVINO_SUPPRESS_DEPRECATED_END
return strided_slice::evaluate_strided_slice(input_values[0],
input_values[1],
input_values[2],
input_values[3],
OPENVINO_ASSERT(inputs.size() == 4);
OPENVINO_ASSERT(outputs.size() == 1);
return strided_slice::evaluate_strided_slice(inputs[0],
inputs[1],
inputs[2],
inputs[3],
convert_mask_to_axis_set(get_begin_mask()),
convert_mask_to_axis_set(get_end_mask()),
convert_mask_to_axis_set(get_new_axis_mask()),
convert_mask_to_axis_set(get_shrink_axis_mask()),
convert_mask_to_axis_set(get_ellipsis_mask()),
output_values[0]);
outputs[0]);
}

bool op::v1::StridedSlice::has_evaluate() const {
bool StridedSlice::has_evaluate() const {
OV_OP_SCOPE(v1_StridedSlice_has_evaluate);
return get_input_size() == 4;
}

bool op::v1::StridedSlice::indices_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& mask) const {
bool StridedSlice::indices_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& mask) const {
const auto& lb_t = get_input_tensor(port).get_lower_value();
const auto& ub_t = get_input_tensor(port).get_upper_value();

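The rewritten evaluate_strided_slice() above follows a simple flow: read begin/end/stride with get_tensor_data_as<int64_t>, build a SlicePlan via make_slice_plan, size the output from reshape_out_shape, then run the byte-wise reference kernel. Below is a hedged sketch of that same flow as a standalone helper; the slice_with_plan name and the all-empty masks are assumptions for illustration, not code from this patch.

#include <cstdint>
#include <vector>

#include "openvino/core/axis_set.hpp"
#include "openvino/op/util/slice_plan.hpp"
#include "openvino/reference/strided_slice.hpp"
#include "openvino/runtime/tensor.hpp"

ov::Tensor slice_with_plan(const ov::Tensor& data,
                           const std::vector<int64_t>& begin,
                           const std::vector<int64_t>& end,
                           const std::vector<int64_t>& stride) {
    using namespace ov;
    // Empty masks: every begin/end/stride value is taken literally.
    const auto plan = op::util::make_slice_plan(data.get_shape(),
                                                begin,
                                                end,
                                                stride,
                                                AxisSet{},   // begin_mask
                                                AxisSet{},   // end_mask
                                                AxisSet{},   // new_axis_mask
                                                AxisSet{},   // shrink_axis_mask
                                                AxisSet{});  // ellipsis_mask
    // The plan's reshape_out_shape already accounts for new/shrink axes.
    Tensor output(data.get_element_type(), plan.reshape_out_shape);
    reference::strided_slice(static_cast<const char*>(data.data()),
                             static_cast<char*>(output.data()),
                             data.get_shape(),
                             plan,
                             data.get_element_type().size());
    return output;
}
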
@@ -267,8 +255,8 @@ bool op::v1::StridedSlice::indices_input_has_and_set_bounds(const size_t port, c
if (!valid_bounds && lb_t && ub_t) {
using TCast = int64_t;
constexpr auto i64_cast = ov::util::Cast<TCast>();
const auto lb = ov::get_tensor_data_as<TCast>(lb_t, i64_cast);
const auto ub = ov::get_tensor_data_as<TCast>(ub_t, i64_cast);
const auto lb = get_tensor_data_as<TCast>(lb_t, i64_cast);
const auto ub = get_tensor_data_as<TCast>(ub_t, i64_cast);

size_t axis = 0;
valid_bounds =

@@ -280,25 +268,25 @@ bool op::v1::StridedSlice::indices_input_has_and_set_bounds(const size_t port, c
return valid_bounds;
}

bool op::v1::StridedSlice::evaluate_lower(ov::TensorVector& output_values) const {
bool StridedSlice::evaluate_lower(TensorVector& output_values) const {
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
default_lower_bound_evaluator(this, output_values);
}

bool op::v1::StridedSlice::evaluate_upper(ov::TensorVector& output_values) const {
bool StridedSlice::evaluate_upper(TensorVector& output_values) const {
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
default_upper_bound_evaluator(this, output_values);
}

bool op::v1::StridedSlice::evaluate_label(TensorLabelVector& output_labels) const {
bool StridedSlice::evaluate_label(TensorLabelVector& output_labels) const {
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
default_label_evaluator(this, {0}, output_labels);
}

bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
bool StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
auto is_folded = Node::constant_fold(output_values, inputs_values);
if (!is_const_fold_disabled() && !is_folded) {
// If all ignored mask are set for all begin or end then replace this input by dummy constant

@@ -316,7 +304,7 @@ bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const Outp
size = mask.size();

const auto& zero_constant =
make_shared<ov::opset1::Constant>(inputs_values[port].get_element_type(), ov::Shape{size}, 0);
std::make_shared<v0::Constant>(inputs_values[port].get_element_type(), Shape{size}, 0);
return all_indices_ignored(inputs_values[port].get_partial_shape(), mask) ? zero_constant
: inputs_values[port];
};

@@ -331,12 +319,10 @@ bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const Outp

std::vector<Node*> nodes;
// Check if bounds can be evaluated and none of output nodes have disabled constant folding.
if (ov::could_propagate(output, nodes) && std::none_of(nodes.begin(), nodes.end(), [](const Node* n) {
return ov::pass::constant_folding_is_disabled(n);
if (could_propagate(output, nodes) && std::none_of(nodes.begin(), nodes.end(), [](const Node* n) {
return pass::constant_folding_is_disabled(n);
})) {
OPENVINO_SUPPRESS_DEPRECATED_START
if (const auto c = ov::get_constant_from_source(output)) {
OPENVINO_SUPPRESS_DEPRECATED_END
if (const auto c = ov::util::get_constant_from_source(output)) {
output_values[0] = c;
auto output_ptr = output_values[0].get_node_shared_ptr();
for (const auto& n : nodes) {

@@ -348,3 +334,6 @@ bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const Outp
}
return is_folded;
}
} // namespace v1
} // namespace op
} // namespace ov

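The constant_fold() override touched above is normally reached through the regular constant-folding pass. A hypothetical sketch of folding a StridedSlice whose data and indices are all constants follows; the names, shapes, and values are illustrative and not part of this patch.

#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/strided_slice.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"

std::shared_ptr<ov::Model> build_and_fold_strided_slice() {
    using namespace ov;
    // All four inputs are constants, so the whole StridedSlice can be folded away.
    const auto data = op::v0::Constant::create(element::f32, Shape{4, 4}, std::vector<float>(16, 1.0f));
    const auto begin = op::v0::Constant::create(element::i64, Shape{2}, {0, 0});
    const auto end = op::v0::Constant::create(element::i64, Shape{2}, {2, 2});
    const auto strides = op::v0::Constant::create(element::i64, Shape{2}, {1, 1});
    const auto ss = std::make_shared<op::v1::StridedSlice>(data,
                                                           begin,
                                                           end,
                                                           strides,
                                                           std::vector<int64_t>{0, 0},   // begin_mask
                                                           std::vector<int64_t>{0, 0});  // end_mask
    auto model = std::make_shared<Model>(OutputVector{ss->output(0)}, ParameterVector{});

    // ConstantFolding gives each node a chance to fold through Node::constant_fold()
    // and evaluate(); afterwards the model output should be a single 2x2 Constant.
    pass::Manager manager;
    manager.register_pass<pass::ConstantFolding>();
    manager.run_passes(model);
    return model;
}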