Review slice for shape inference aspects (#14611)
* Review slice ope for - Interval dimension and label propagation - add template shape inference with static shape test - check preserve partial values on inputs - check upper/lower evaluate * Add bounds evaluation for inputs start, stop * Share code between slice and strided slice Use same function to calculate elements in step * Add array includes * Add to int64_t strides size * Fix windows compile warnings * Fix shape inference for unknown axes * Remove empty lines in slice shape inference * Correct slice static shape tests * Use arrays of const chars to store literals Remove and update exception messages for strided slice * Fix slice test and apply review comments * Fix compilation issues * Fix ellipsis when there is not begin * Fix get element type for const inputs * Insert optional axes as const or dynamic param * Remove temp vectors for dimensions calculation * Revert set optional input in ctor * Fix forward slicing for negative start and MAX end
This commit is contained in:
parent
3094384d74
commit
a1203b931a
@ -483,7 +483,7 @@ def test_graph_preprocess_crop():
|
||||
"Relu",
|
||||
"Slice",
|
||||
]
|
||||
assert len(model_operators) == 7
|
||||
assert len(model_operators) == 8
|
||||
assert function.get_output_size() == 1
|
||||
assert list(function.get_output_shape(0)) == [1, 2, 1, 1]
|
||||
assert function.get_output_element_type(0) == Type.f32
|
||||
|
@ -30,6 +30,15 @@ void infer_auto_padding(const Shape& image_shape,
|
||||
CoordinateDiff& padding_above,
|
||||
CoordinateDiff& padding_below);
|
||||
|
||||
/// \brief Normalize value to the max if value is negative.
|
||||
///
|
||||
/// \param value Input value to normalize.
|
||||
/// \param max Value used for normalization
|
||||
///
|
||||
/// \return Value if positive otherwise return value + max
|
||||
OPENVINO_API
|
||||
int64_t normalize(const int64_t& value, const int64_t& max);
|
||||
|
||||
/// \brief Handle out of range axis.
|
||||
///
|
||||
/// \param[in] node The node with requested axis.
|
||||
@ -172,4 +181,21 @@ OPENVINO_API std::vector<PartialShape> get_node_input_partial_shapes(const ov::N
|
||||
///
|
||||
/// \return True if rank compatible to any from ranks, otherwise false.
|
||||
OPENVINO_API bool is_rank_compatible_any_of(const ov::Rank& rank, const std::vector<ov::Rank>& ranks);
|
||||
|
||||
/// \brief Check if values in vector are unique.
|
||||
///
|
||||
/// \param data Input data to check.
|
||||
///
|
||||
/// \return True if unique otherwise false.
|
||||
OPENVINO_API bool are_unique(const std::vector<int64_t>& data);
|
||||
|
||||
/// \brief Clip value to minimum if below min, or to maximum of above max.
|
||||
///
|
||||
/// \param value Value to be clipped.
|
||||
/// \param min Minimum value bound.
|
||||
/// \param max Maximum value boiund
|
||||
///
|
||||
/// \return Value if between min, max otherwise min or max.
|
||||
OPENVINO_API
|
||||
int64_t clip(const int64_t& value, const int64_t& min, const int64_t& max);
|
||||
} // namespace ov
|
||||
|
@ -54,12 +54,7 @@ public:
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
bool evaluate_label(TensorLabelVector& output_labels) const override;
|
||||
|
||||
std::shared_ptr<ngraph::op::v0::Constant> get_default_const_axes(const Output<Node>& start) const;
|
||||
PartialShape calculate_output_shape(const std::vector<int64_t>& starts,
|
||||
const std::vector<int64_t>& stops,
|
||||
const std::vector<int64_t>& steps,
|
||||
const std::vector<int64_t>& axes,
|
||||
const PartialShape& data_shape) const;
|
||||
std::shared_ptr<v0::Constant> get_default_const_axes(const Output<Node>& start) const;
|
||||
};
|
||||
} // namespace v8
|
||||
} // namespace op
|
||||
|
@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "gru_cell_shape_inference.hpp"
|
||||
#include "ov_ops/augru_cell.hpp"
|
||||
#include "ov_ops/augru_sequence.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
|
153
src/core/shape_inference/include/slice_shape_inference.hpp
Normal file
153
src/core/shape_inference/include/slice_shape_inference.hpp
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <openvino/op/slice.hpp>
|
||||
|
||||
#include "slice_shape_inference_utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
|
||||
namespace slice {
|
||||
|
||||
constexpr std::array<char const*, 4> shape_names{"start", "stop", "step", "axes"};
|
||||
|
||||
struct AxesMap {
|
||||
bool is_valid{}; //!< Flag indicates current axes map has valid data (unique).
|
||||
std::map<size_t, size_t> m{}; //!< Map axis value to index of start, stop order.
|
||||
|
||||
void add(const std::vector<int64_t>& axes) {
|
||||
const auto exp_size = std::accumulate(axes.cbegin(), axes.cend(), m.size(), [this](size_t i, int64_t axis) {
|
||||
m.emplace(static_cast<size_t>(axis), i);
|
||||
return ++i;
|
||||
});
|
||||
|
||||
is_valid = exp_size == m.size();
|
||||
}
|
||||
|
||||
void generate_n(size_t n) {
|
||||
n += m.size();
|
||||
for (size_t i = m.size(); i < n; ++i) {
|
||||
m.emplace(i, i);
|
||||
}
|
||||
is_valid = m.size() == n;
|
||||
}
|
||||
};
|
||||
} // namespace slice
|
||||
|
||||
namespace v8 {
|
||||
|
||||
template <class T>
|
||||
void shape_infer(const Slice* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
|
||||
const auto& num_of_inputs = input_shapes.size();
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
num_of_inputs == 4 || num_of_inputs == 5,
|
||||
"Slice has to have 4 or 5 inputs. Got: ",
|
||||
num_of_inputs);
|
||||
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1);
|
||||
|
||||
const auto& input_shape = input_shapes[0];
|
||||
const auto& input_rank = input_shape.rank();
|
||||
|
||||
for (size_t i = 1; i < input_shapes.size(); ++i) {
|
||||
const auto& shape = input_shapes[i];
|
||||
const auto& shape_rank = shape.rank();
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
shape_rank.compatible(1),
|
||||
"Slice `",
|
||||
slice::shape_names[i - 1],
|
||||
"` input must be a 1D tensor. Got rank: ",
|
||||
shape_rank);
|
||||
|
||||
if (input_rank.is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
shape_rank.is_dynamic() || shape[0].get_min_length() <= input_rank.get_length(),
|
||||
"Slice `",
|
||||
slice::shape_names[i - 1],
|
||||
"` input dim size can't be bigger than `data` rank.");
|
||||
}
|
||||
}
|
||||
|
||||
const auto& start_shape = input_shapes[1];
|
||||
const auto& stop_shape = input_shapes[2];
|
||||
const auto& step_shape = input_shapes[3];
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
start_shape.compatible(stop_shape) && start_shape.compatible(step_shape) && stop_shape.compatible(step_shape),
|
||||
"Slice `start`, `stop`, `step` inputs must have compatible shapes.");
|
||||
|
||||
// it is not possible to define output shape if input data shape rank is undefined
|
||||
// even the lengths of begin, end, or strides are defined
|
||||
if (input_rank.is_dynamic()) {
|
||||
output_shapes[0] = PartialShape::dynamic();
|
||||
return;
|
||||
}
|
||||
|
||||
// compute constant values of begin, end, and strides if possible
|
||||
const auto start = slice::get_input_bounds<T>(op, 1, constant_data);
|
||||
const auto stop = slice::get_input_bounds<T>(op, 2, constant_data);
|
||||
const auto steps = get_input_const_data_as<T, int64_t>(op, 3, constant_data);
|
||||
|
||||
slice::AxesMap axes_map;
|
||||
if (input_shapes.size() > 4) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_shapes[4].compatible(start_shape),
|
||||
"Slice `axes` input must have compatible shape with `start`, `stop`, `step` inputs.");
|
||||
|
||||
if (auto axes = get_input_const_data_as<T, int64_t>(op, 4, constant_data)) {
|
||||
ov::normalize_axes(op, input_shape.rank().get_length(), *axes);
|
||||
axes_map.add(*axes);
|
||||
NODE_VALIDATION_CHECK(op, axes_map.is_valid, "Slice values in `axes` input must be unique.");
|
||||
}
|
||||
} else if (start) {
|
||||
axes_map.generate_n(start->size());
|
||||
}
|
||||
|
||||
auto axis_it = axes_map.m.cbegin();
|
||||
|
||||
auto& out = output_shapes.front();
|
||||
out.resize(0);
|
||||
out.reserve(input_shape.size());
|
||||
for (size_t dim_idx = 0; dim_idx < input_shape.size(); ++dim_idx) {
|
||||
const DimType& input_dim = input_shape[dim_idx];
|
||||
|
||||
if (axes_map.is_valid && (axis_it != axes_map.m.cend()) && (axis_it->first == dim_idx)) {
|
||||
const auto& i = axis_it->second;
|
||||
|
||||
if (start && stop && steps) {
|
||||
const auto& step = (*steps)[i];
|
||||
NODE_VALIDATION_CHECK(op, step != 0, "Step must be non-zero");
|
||||
out.push_back(slice::make_dim(input_dim, (*start)[i], (*stop)[i], step));
|
||||
} else {
|
||||
out.emplace_back(0, input_dim.get_max_length());
|
||||
}
|
||||
|
||||
auto& last_dim = out[out.size() - 1];
|
||||
if (std::is_same<DimType, ov::Dimension>::value && (last_dim == input_dim)) {
|
||||
// for equal ov::Dimension do merge to get input label (always success)
|
||||
DimType::merge(last_dim, last_dim, input_dim);
|
||||
}
|
||||
++axis_it;
|
||||
} else if (axes_map.is_valid) {
|
||||
// dimension not on axes list, no change
|
||||
out.push_back(input_dim);
|
||||
} else {
|
||||
// axes are unknow so any dimension can be sliced
|
||||
out.emplace_back(0, input_dim.get_max_length());
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace v8
|
||||
} // namespace op
|
||||
} // namespace ov
|
273
src/core/shape_inference/include/slice_shape_inference_utils.hpp
Normal file
273
src/core/shape_inference/include/slice_shape_inference_utils.hpp
Normal file
@ -0,0 +1,273 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
|
||||
#include "sequnce_generator.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace internal {
|
||||
/**
|
||||
* \brief Check if value of type T has got maximum value of type U.
|
||||
*
|
||||
* \tparam T Input value type
|
||||
* \tparam U Type to get its minimum for comparision. Default same as T.
|
||||
*
|
||||
* \param value Input value.
|
||||
*
|
||||
* \return True if input value has got maximum value of type U otherwise false.
|
||||
*/
|
||||
template <class T, class U = T>
|
||||
constexpr bool is_max(const T& value) {
|
||||
return std::numeric_limits<U>::max() == value;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Check if value of type T has got minimum value of type U.
|
||||
*
|
||||
* \tparam T Input value type.
|
||||
* \tparam U Type to get its minimum for comparision. Default same as T.
|
||||
*
|
||||
* \param value Input value.
|
||||
*
|
||||
* \return True if input value has got minimum value of type U otherwise false.
|
||||
*/
|
||||
template <class T, class U = T>
|
||||
constexpr bool is_min(const T& value) {
|
||||
return std::numeric_limits<U>::min() == value;
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
namespace element {
|
||||
/**
|
||||
* \brief Check if value has got maximum value of ov::element::Type_t
|
||||
*
|
||||
* \tparam T Input value type.
|
||||
*
|
||||
* \param type ov::element type to get its maximum.
|
||||
* \param value Input value for check.
|
||||
*
|
||||
* \return True if input value has got maximum number specified by ov::element type otherwise false.
|
||||
*/
|
||||
template <class T>
|
||||
bool is_max_of(const element::Type_t& type, const T& value) {
|
||||
switch (type) {
|
||||
case element::i32:
|
||||
return internal::is_max<T, typename element_type_traits<element::i32>::value_type>(value);
|
||||
case element::i64:
|
||||
return internal::is_max<T, typename element_type_traits<element::i64>::value_type>(value);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Check if value has got minimum value of ov::element::Type_t
|
||||
*
|
||||
* \tparam T Input value type.
|
||||
*
|
||||
* \param type ov::element type to get its minimum.
|
||||
* \param value Input value for check.
|
||||
*
|
||||
* \return True if input value has got minimum number specified by ov::element type otherwise false.
|
||||
*/
|
||||
template <class T>
|
||||
bool is_min_of(const element::Type_t type, const T& value) {
|
||||
switch (type) {
|
||||
case element::i32:
|
||||
return internal::is_min<T, typename element_type_traits<element::i32>::value_type>(value);
|
||||
case element::i64:
|
||||
return internal::is_min<T, typename element_type_traits<element::i64>::value_type>(value);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Checks input value for element type maximum or minimum and return limit or value.
|
||||
*
|
||||
* \tparam T Type of input value.
|
||||
* \tparam U Type of return value. Default same as T.
|
||||
*
|
||||
* \param type Type of ov::element::Type_t
|
||||
* \param value Input value for check.
|
||||
*
|
||||
* \return If value is maximum or minimum get limit of U otherwise value as U.
|
||||
*/
|
||||
template <class T, class U = T>
|
||||
U get_value_or_limit_of(const element::Type_t& type, const T& value) {
|
||||
if (is_min_of(type, value)) {
|
||||
return std::numeric_limits<U>::min();
|
||||
} else if (is_max_of(type, value)) {
|
||||
return std::numeric_limits<U>::max();
|
||||
} else {
|
||||
return static_cast<U>(value);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace element
|
||||
|
||||
namespace op {
|
||||
namespace slice {
|
||||
|
||||
/**
|
||||
* \brief Get sliced value in step for given dimension value and start, stop, step.
|
||||
*
|
||||
* \note This function cannot be use for step 0 (division by 0)
|
||||
*
|
||||
* \param dim Dimension value.
|
||||
* \param start Start of slice.
|
||||
* \param stop Stop of slice.
|
||||
* \param step Step of slice.
|
||||
*
|
||||
* \return -1 for infinite number otherwise [0..int64_max] for finit step.
|
||||
*/
|
||||
inline int64_t get_sliced_value(const int64_t& dim, const int64_t& start, const int64_t& stop, const int64_t& step) {
|
||||
const auto is_reverse_step = step < 0;
|
||||
|
||||
constexpr int64_t min_bound = 0;
|
||||
constexpr int64_t inf_bound = -1;
|
||||
|
||||
const auto& norm_dim = dim == inf_bound ? std::numeric_limits<int64_t>::max() : dim;
|
||||
const auto is_norm_dim_max = ov::internal::is_max(norm_dim);
|
||||
const int64_t lower_max = is_reverse_step ? norm_dim - 1 : norm_dim;
|
||||
const int64_t upper_min = is_reverse_step ? inf_bound : min_bound;
|
||||
|
||||
const auto is_start_lt_min_bound = start < min_bound;
|
||||
const auto are_bounds_diff_sign = is_start_lt_min_bound != (stop < 0);
|
||||
|
||||
const auto is_start_max = ov::internal::is_max(start);
|
||||
const auto is_start_limit = is_start_max || ov::internal::is_min(start);
|
||||
const auto is_stop_max = ov::internal::is_max(stop);
|
||||
const auto any_bound_max = is_start_max || is_stop_max;
|
||||
// Prepare bounds for sliced value calculation.
|
||||
int64_t lb, ub;
|
||||
if (is_norm_dim_max && (are_bounds_diff_sign || any_bound_max || is_start_limit)) {
|
||||
if (is_reverse_step) {
|
||||
ub = (is_start_lt_min_bound || any_bound_max) ? inf_bound : inf_bound - start;
|
||||
} else if (is_start_lt_min_bound && !is_start_limit) {
|
||||
ub = is_stop_max ? -start : stop;
|
||||
} else {
|
||||
ub = inf_bound;
|
||||
}
|
||||
lb = min_bound;
|
||||
} else {
|
||||
lb = clip(normalize(start, norm_dim), min_bound, lower_max);
|
||||
ub = clip(normalize(stop, norm_dim), upper_min, norm_dim);
|
||||
}
|
||||
|
||||
// Calculate sliced value from bounds and step.
|
||||
if (is_norm_dim_max && lb == min_bound && ub == inf_bound) {
|
||||
return inf_bound;
|
||||
} else {
|
||||
// Limit sliced value to not-positive for negative step or not-negative for positive step
|
||||
auto sliced_value =
|
||||
is_reverse_step ? std::min<int64_t>(min_bound, (ub - lb)) : std::max<int64_t>(min_bound, (ub - lb));
|
||||
|
||||
if (step == -1) {
|
||||
// Sliced value is negative for negative step return opposite
|
||||
sliced_value = -sliced_value;
|
||||
} else if (sliced_value != 0 && step != 1) {
|
||||
// Need to calculate sliced value for step. Depends on step direction reduce sliced value
|
||||
// in order to calculate it in one-step division (no modulo required)
|
||||
is_reverse_step ? ++sliced_value : --sliced_value;
|
||||
sliced_value /= step;
|
||||
++sliced_value;
|
||||
} else {
|
||||
// There is no need for calculations as sliced value is 0 or step is 1.
|
||||
}
|
||||
return sliced_value;
|
||||
}
|
||||
}
|
||||
|
||||
// To get element type from constant or tensor.
|
||||
inline element::Type get_input_const_element_type(const ov::Node* op,
|
||||
size_t idx,
|
||||
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
|
||||
if (constant_data.count(idx)) {
|
||||
return constant_data.at(idx)->get_element_type();
|
||||
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
|
||||
return constant->get_element_type();
|
||||
} else {
|
||||
return element::undefined;
|
||||
}
|
||||
}
|
||||
|
||||
using Bounds = std::pair<int64_t, int64_t>; //!< Alias to dimension bounds for slice.
|
||||
|
||||
/**
|
||||
* \brief Get the input bounds from constant input (constant map) or evaluate bunds
|
||||
* and return them as vector of pairs (lower, upper).
|
||||
*
|
||||
* \tparam TShape Shape type.
|
||||
*
|
||||
* \param op Operator pointer.
|
||||
* \param idx Input index.
|
||||
* \param constant_data Map with constant data.
|
||||
*
|
||||
* \return Return vector of slice::Bounds.
|
||||
*/
|
||||
template <class TShape, class TResult = std::vector<Bounds>>
|
||||
std::unique_ptr<TResult> get_input_bounds(const ov::Node* op,
|
||||
size_t idx,
|
||||
const std::map<size_t, HostTensorPtr>& constant_data) {
|
||||
// Helper to create TResult from lowers and uppers.
|
||||
const auto make_bounds_vec =
|
||||
[](const element::Type& et, const std::vector<int64_t>& lowers, const std::vector<int64_t>& uppers) {
|
||||
TResult out;
|
||||
out.reserve(lowers.size());
|
||||
std::transform(lowers.begin(),
|
||||
lowers.end(),
|
||||
uppers.begin(),
|
||||
std::back_inserter(out),
|
||||
[&et](int64_t lb, int64_t ub) {
|
||||
return std::make_pair(element::get_value_or_limit_of(et, lb),
|
||||
element::get_value_or_limit_of(et, ub));
|
||||
});
|
||||
return out;
|
||||
};
|
||||
|
||||
std::unique_ptr<TResult> out;
|
||||
if (auto lowers = op::get_input_const_data_as<TShape, int64_t>(op, idx, constant_data)) {
|
||||
const auto& et = get_input_const_element_type(op, idx, constant_data);
|
||||
out.reset(new TResult(make_bounds_vec(et, *lowers, *lowers)));
|
||||
} else {
|
||||
auto bounds = ngraph::evaluate_both_bounds(op->get_input_source_output(idx));
|
||||
if (bounds.first && bounds.second) {
|
||||
const auto& et = op->get_input_element_type(idx);
|
||||
auto lowers = std::make_shared<op::v0::Constant>(bounds.first)->cast_vector<int64_t>();
|
||||
auto uppers = std::make_shared<op::v0::Constant>(bounds.second)->cast_vector<int64_t>();
|
||||
out.reset(new TResult(make_bounds_vec(et, lowers, uppers)));
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Make sliced dimension for input dimension by step from start to stop bounds.
|
||||
*
|
||||
* \tparam TDim Type of in/out dimension.
|
||||
*
|
||||
* \param dim
|
||||
* \param start Slice start bounds.
|
||||
* \param stop Slice stop bounds.
|
||||
* \param step Slice step.
|
||||
*
|
||||
* \return Dimension with upper/lower values set according slice inputs.
|
||||
*/
|
||||
template <class TDim>
|
||||
TDim make_dim(const TDim& dim, const Bounds& start, const Bounds& stop, int64_t step) {
|
||||
using TDimVal = typename TDim::value_type;
|
||||
auto lb = static_cast<TDimVal>(get_sliced_value(dim.get_min_length(), start.second, stop.first, step));
|
||||
auto ub = static_cast<TDimVal>(get_sliced_value(dim.get_max_length(), start.first, stop.second, step));
|
||||
|
||||
return {lb, ub};
|
||||
}
|
||||
} // namespace slice
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -4,9 +4,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <array>
|
||||
#include <openvino/op/strided_slice.hpp>
|
||||
|
||||
#include "slice_shape_inference_utils.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -19,30 +20,24 @@ void shape_infer(const StridedSlice* op,
|
||||
std::vector<T>& output_shapes,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
static constexpr std::array<char const*, 3> shape_names{"Begin", "End", "Strides"};
|
||||
|
||||
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3 || input_shapes.size() == 4) && output_shapes.size() == 1);
|
||||
|
||||
const auto& input_shape = input_shapes[0];
|
||||
|
||||
for (size_t i = 1; i < input_shapes.size(); ++i) {
|
||||
const auto& shape_rank = input_shapes[i].rank();
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
shape_rank.compatible(1),
|
||||
shape_names[i - 1],
|
||||
" input must be 1D (has rank: ",
|
||||
shape_rank,
|
||||
")");
|
||||
}
|
||||
|
||||
const auto& begin_shape = input_shapes[1];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
begin_shape.rank().compatible(1),
|
||||
"Begin input must be 1D (begin rank: ",
|
||||
begin_shape.rank(),
|
||||
").");
|
||||
|
||||
const auto& end_shape = input_shapes[2];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
end_shape.rank().compatible(1),
|
||||
"End input must be 1D (end rank: ",
|
||||
end_shape.rank(),
|
||||
").");
|
||||
|
||||
const auto& strides_shape = input_shapes.size() < 4 ? op->get_input_shape(3) : input_shapes[3];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
strides_shape.rank().compatible(1),
|
||||
"Strides input must be 1D (strides rank: ",
|
||||
strides_shape.rank(),
|
||||
").");
|
||||
|
||||
// it is not possible to define output shape if input data shape rank is undefined
|
||||
// even the lengths of begin, end, or strides are defined
|
||||
@ -52,21 +47,6 @@ void shape_infer(const StridedSlice* op,
|
||||
}
|
||||
auto input_rank = input_shape.size();
|
||||
|
||||
const auto get_input_bounds = [&](size_t idx) {
|
||||
std::vector<int64_t> lower, upper;
|
||||
if (!get_data_as_int64<T>(idx, op, lower, constant_data)) {
|
||||
// if no const data try get input bounds
|
||||
auto bounds = ngraph::evaluate_both_bounds(op->get_input_source_output(idx));
|
||||
|
||||
if (bounds.first && bounds.second) {
|
||||
lower = std::make_shared<op::v0::Constant>(bounds.first)->cast_vector<int64_t>();
|
||||
upper = std::make_shared<op::v0::Constant>(bounds.second)->cast_vector<int64_t>();
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(lower, upper);
|
||||
};
|
||||
|
||||
auto number_elements_in_1d = [](const StridedSlice* op, const T& shape_1d) -> int64_t {
|
||||
auto rank_1d = shape_1d.rank();
|
||||
if (rank_1d.is_static()) {
|
||||
@ -79,20 +59,15 @@ void shape_infer(const StridedSlice* op,
|
||||
};
|
||||
|
||||
// compute constant values of begin, end, and strides if possible
|
||||
auto begin = get_input_bounds(1);
|
||||
auto end = get_input_bounds(2);
|
||||
auto got_begin = !begin.first.empty();
|
||||
auto got_end = !end.first.empty();
|
||||
|
||||
std::vector<int64_t> strides;
|
||||
bool got_strides = false;
|
||||
const auto begin = slice::get_input_bounds<T>(op, 1, constant_data);
|
||||
const auto end = slice::get_input_bounds<T>(op, 2, constant_data);
|
||||
|
||||
std::unique_ptr<std::vector<int64_t>> strides;
|
||||
if (input_shapes.size() > 3) {
|
||||
got_strides = get_data_as_int64<T>(3, op, strides, constant_data);
|
||||
} else if (got_begin) {
|
||||
strides = get_input_const_data_as<T, int64_t>(op, 3, constant_data);
|
||||
} else if (begin) {
|
||||
// generate default strides
|
||||
strides.resize(begin.first.size(), 1);
|
||||
got_strides = true;
|
||||
strides.reset(new std::vector<int64_t>(begin->size(), 1));
|
||||
}
|
||||
|
||||
// compute and check a number of axes for which begin, end, and strides are defined
|
||||
@ -103,7 +78,7 @@ void shape_infer(const StridedSlice* op,
|
||||
} else if (end_number_axes != -1) {
|
||||
number_axes = end_number_axes;
|
||||
}
|
||||
auto strides_number_axes = number_elements_in_1d(op, strides_shape);
|
||||
auto strides_number_axes = strides ? static_cast<int64_t>(strides->size()) : static_cast<int64_t>(-1);
|
||||
if (number_axes != -1 && strides_number_axes != -1) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
number_axes == strides_number_axes,
|
||||
@ -139,7 +114,8 @@ void shape_infer(const StridedSlice* op,
|
||||
"Input rank plus number of new axis has to be at least the size of Lower "
|
||||
"and Upper bounds vector.");
|
||||
|
||||
std::vector<DimType> dims;
|
||||
auto& out = output_shapes.front();
|
||||
out.resize(0);
|
||||
int64_t input_shape_idx = 0;
|
||||
for (int64_t axis = 0; axis < number_axes; ++axis) {
|
||||
// add all dimensions hidden under the ellipsis mask if ellipsis mask is set
|
||||
@ -152,9 +128,12 @@ void shape_infer(const StridedSlice* op,
|
||||
num_input_axis_before_ellipses++;
|
||||
}
|
||||
}
|
||||
for (size_t i = axis + 1; i < begin.first.size(); ++i) {
|
||||
if (new_axis_mask.count(i)) {
|
||||
num_new_axis_after_ellipses++;
|
||||
|
||||
if (begin) {
|
||||
for (size_t i = axis + 1; i < begin->size(); ++i) {
|
||||
if (new_axis_mask.count(i)) {
|
||||
num_new_axis_after_ellipses++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -162,90 +141,49 @@ void shape_infer(const StridedSlice* op,
|
||||
(number_axes - axis - num_new_axis_after_ellipses - 1); // -1 because it's a position of ellipses
|
||||
int64_t num_of_hidden_dims = input_rank - num_input_axis_after_ellipses - num_input_axis_before_ellipses;
|
||||
for (int64_t i = 0; i < num_of_hidden_dims; ++i, ++input_shape_idx) {
|
||||
dims.emplace_back(input_shape[input_shape_idx]);
|
||||
out.emplace_back(input_shape[input_shape_idx]);
|
||||
}
|
||||
} else {
|
||||
// add new single dimension if new_axis_mask is set
|
||||
if (new_axis_mask.count(axis)) {
|
||||
dims.emplace_back(1);
|
||||
out.emplace_back(1);
|
||||
}
|
||||
// skip this dimension if shrink_axis_mask is set
|
||||
else if (shrink_axis_mask.count(axis)) {
|
||||
input_shape_idx++;
|
||||
}
|
||||
// calculating dimension (begin, end, begin_mask, end_mask, stride)
|
||||
else if (got_begin && got_end && got_strides) {
|
||||
else if (begin && end && strides) {
|
||||
// set default value for stride or use given value
|
||||
auto stride = (strides.size() > static_cast<size_t>(axis)) ? strides[axis] : static_cast<int64_t>(1);
|
||||
const auto& input_dim = input_shape[input_shape_idx];
|
||||
auto stride =
|
||||
(strides->size() > static_cast<size_t>(axis)) ? (*strides)[axis] : static_cast<int64_t>(1);
|
||||
NODE_VALIDATION_CHECK(op, stride != 0, "Stride must be non-zero");
|
||||
// normalize by add max to value if negative
|
||||
const auto normalize = [](const int64_t& value, const int64_t& max) -> int64_t {
|
||||
return (value < 0) ? value + max : value;
|
||||
};
|
||||
|
||||
// clip value to min, max
|
||||
const auto clip = [](const int64_t& value, const int64_t& min, const int64_t& max) -> int64_t {
|
||||
return std::min(std::max(value, min), max);
|
||||
};
|
||||
constexpr int64_t inf_bound = -1;
|
||||
const auto is_reverse_stride = stride < 0;
|
||||
const int64_t norm_dim = (input_dim.get_max_length() == inf_bound) ? std::numeric_limits<int64_t>::max()
|
||||
: input_dim.get_max_length();
|
||||
const slice::Bounds default_fstart = std::make_pair<int64_t, int64_t>(0, 0);
|
||||
const slice::Bounds default_rstop = std::make_pair(inf_bound - norm_dim, inf_bound - norm_dim);
|
||||
const slice::Bounds norm_dim_bounds = std::make_pair(norm_dim, norm_dim);
|
||||
|
||||
// get stride output dimension for dimension and bounds
|
||||
// may not be called for stride 0 (div by 0!!!) assert check done above
|
||||
const auto get_output_dim = [&](const int64_t& dim, const int64_t& lower, const int64_t& upper) {
|
||||
const auto is_reverse_stride = stride < 0;
|
||||
const auto& default_start = is_reverse_stride ? norm_dim_bounds : default_fstart;
|
||||
const auto& default_stop = is_reverse_stride ? default_rstop : norm_dim_bounds;
|
||||
|
||||
constexpr int64_t lower_min = 0;
|
||||
const int64_t lower_max = is_reverse_stride ? dim - 1 : dim;
|
||||
const int64_t upper_min = is_reverse_stride ? -1 : lower_min;
|
||||
const int64_t default_min = is_reverse_stride ? lower_max : lower_min;
|
||||
const int64_t default_max = is_reverse_stride ? -1 : dim;
|
||||
const auto& start = begin_mask.count(axis) ? default_start : (*begin)[axis];
|
||||
const auto& stop = end_mask.count(axis) ? default_stop : (*end)[axis];
|
||||
auto sliced_dim = slice::make_dim(input_dim, start, stop, stride);
|
||||
|
||||
auto lb = begin_mask.count(axis) ? default_min : clip(normalize(lower, dim), lower_min, lower_max);
|
||||
auto ub = end_mask.count(axis) ? default_max : clip(normalize(upper, dim), upper_min, dim);
|
||||
|
||||
// decrees range by modifing lower bound depends on stride direction
|
||||
is_reverse_stride ? --lb : ++lb;
|
||||
|
||||
if ((is_reverse_stride && lb >= ub) || (!is_reverse_stride && lb <= ub)) {
|
||||
return ((ub - lb) / stride) + 1;
|
||||
} else {
|
||||
return static_cast<int64_t>(0);
|
||||
}
|
||||
};
|
||||
|
||||
const auto& begin_lb = begin.first[axis];
|
||||
const auto& begin_ub = begin.second.empty() ? begin_lb : begin.second[axis];
|
||||
|
||||
const auto& end_lb = end.first[axis];
|
||||
const auto& end_ub = end.second.empty() ? end_lb : end.second[axis];
|
||||
|
||||
if (input_shape[input_shape_idx].is_dynamic()) {
|
||||
// the relationship between input and output length is monotonically increasing
|
||||
// so we repeat the dimension inference twice to infer dynamic dimension
|
||||
const auto& interval = input_shape[input_shape_idx].get_interval();
|
||||
auto lb = get_output_dim(interval.get_min_val(), begin_ub, end_lb);
|
||||
auto ub =
|
||||
interval.has_upper_bound() ? get_output_dim(interval.get_max_val(), begin_lb, end_ub) : -1;
|
||||
dims.emplace_back(lb, ub);
|
||||
} else {
|
||||
const auto& dimension = input_shape[input_shape_idx].get_length();
|
||||
auto lb = get_output_dim(dimension, begin_ub, end_lb);
|
||||
auto ub = get_output_dim(dimension, begin_lb, end_ub);
|
||||
dims.emplace_back(lb, ub);
|
||||
}
|
||||
|
||||
if (std::is_same<DimType, ov::Dimension>::value && dims.back() == input_shape[input_shape_idx]) {
|
||||
if (std::is_same<DimType, ov::Dimension>::value && (sliced_dim == input_dim)) {
|
||||
// for equal ov::Dimension do merge to get input label (always success)
|
||||
DimType::merge(dims.back(), dims.back(), input_shape[input_shape_idx]);
|
||||
DimType::merge(sliced_dim, sliced_dim, input_dim);
|
||||
}
|
||||
out.push_back(std::move(sliced_dim));
|
||||
|
||||
input_shape_idx++;
|
||||
} else {
|
||||
if (input_shape[input_shape_idx].is_static()) {
|
||||
auto dim_value = input_shape[input_shape_idx].get_length();
|
||||
dims.emplace_back(0, dim_value);
|
||||
} else {
|
||||
dims.emplace_back(-1);
|
||||
}
|
||||
out.emplace_back(0, input_shape[input_shape_idx].get_max_length());
|
||||
|
||||
input_shape_idx++;
|
||||
}
|
||||
@ -254,10 +192,8 @@ void shape_infer(const StridedSlice* op,
|
||||
|
||||
// get remaining values
|
||||
for (; input_shape_idx < input_shape.rank().get_length(); ++input_shape_idx) {
|
||||
dims.push_back(input_shape[input_shape_idx]);
|
||||
out.push_back(input_shape[input_shape_idx]);
|
||||
}
|
||||
|
||||
output_shapes[0] = T(std::move(dims));
|
||||
}
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/runtime/reference/slice.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "slice_shape_inference.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -33,63 +34,23 @@ op::v8::Slice::Slice(const Output<Node>& data,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
int64_t get_sliced_dim_size(int64_t start, int64_t stop, int64_t step, int64_t dim_size) {
|
||||
// Normalize index
|
||||
start = start < 0 ? dim_size + start : start;
|
||||
stop = stop < 0 ? dim_size + stop : stop;
|
||||
|
||||
// Clip normalized bounds according to the dim size
|
||||
start = std::max(int64_t(0), std::min(start, dim_size)); // inclusive
|
||||
stop = std::max(int64_t(-1), std::min(stop, dim_size)); // exclusive
|
||||
|
||||
int64_t elements_in_range = 0;
|
||||
if (step < 0) {
|
||||
// Clip max start index (last element inclusively)
|
||||
elements_in_range = std::max(int64_t(0), std::min(dim_size - 1, start) - stop);
|
||||
} else {
|
||||
// Clip max stop index (last element exclusively)
|
||||
elements_in_range = std::max(int64_t(0), std::min(dim_size, stop) - start);
|
||||
}
|
||||
const int64_t rest = elements_in_range % std::abs(step);
|
||||
const int64_t integer_div = elements_in_range / std::abs(step);
|
||||
const int64_t sliced_dim_size = !rest ? integer_div : integer_div + 1;
|
||||
return sliced_dim_size;
|
||||
}
|
||||
|
||||
bool is_max_int(element::Type_t ind_type, int64_t value) {
|
||||
int64_t max_type_value = 0;
|
||||
switch (ind_type) {
|
||||
case element::i32:
|
||||
max_type_value = std::numeric_limits<typename element_type_traits<element::i32>::value_type>::max();
|
||||
break;
|
||||
case element::i64:
|
||||
max_type_value = std::numeric_limits<typename element_type_traits<element::i64>::value_type>::max();
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return max_type_value == value;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool op::v8::Slice::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v8_Slice_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::op::v0::Constant> op::v8::Slice::get_default_const_axes(const Output<Node>& start) const {
|
||||
std::shared_ptr<op::v0::Constant> op::v8::Slice::get_default_const_axes(const Output<Node>& start) const {
|
||||
const auto start_pshape = start.get_partial_shape();
|
||||
// Static case
|
||||
if (start_pshape.rank().is_static() && start_pshape.rank().get_length() == 1 && start_pshape[0].is_static()) {
|
||||
size_t axes_length = start_pshape[0].get_length();
|
||||
std::vector<int64_t> axes(axes_length);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return op::v0::Constant::create(element::i64, Shape{axes_length}, axes);
|
||||
return v0::Constant::create(element::i64, Shape{axes_length}, axes);
|
||||
} else {
|
||||
// Dynamic case
|
||||
return {};
|
||||
}
|
||||
// Dynamic case
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void op::v8::Slice::validate_and_infer_types() {
|
||||
@ -108,136 +69,29 @@ void op::v8::Slice::validate_and_infer_types() {
|
||||
data_rank.is_dynamic() || data_rank.get_length() > 0,
|
||||
"Slice `data` input can't be a scalar.");
|
||||
|
||||
const auto start_const = get_constant_from_source(input_value(1));
|
||||
const auto stop_const = get_constant_from_source(input_value(2));
|
||||
const auto step_const = get_constant_from_source(input_value(3));
|
||||
|
||||
const auto& start_input = start_const ? start_const : input_value(1);
|
||||
const auto& stop_input = stop_const ? stop_const : input_value(2);
|
||||
const auto& step_input = step_const ? step_const : input_value(3);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
start_input.get_element_type().is_integral_number(),
|
||||
"Slice `start` input type must be integer.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
stop_input.get_element_type().is_integral_number(),
|
||||
"Slice `stop` input type must be integer.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
step_input.get_element_type().is_integral_number(),
|
||||
"Slice `step` input type must be integer.");
|
||||
|
||||
const auto& start_shape = start_input.get_partial_shape();
|
||||
const auto& stop_shape = stop_input.get_partial_shape();
|
||||
const auto& step_shape = step_input.get_partial_shape();
|
||||
|
||||
const auto& start_rank = start_shape.rank();
|
||||
const auto& stop_rank = stop_shape.rank();
|
||||
const auto& step_rank = step_shape.rank();
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
start_rank.compatible(1),
|
||||
"Slice `start` input must be a 1D tensor. Got rank: ",
|
||||
start_rank);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
stop_rank.compatible(1),
|
||||
"Slice `stop` input must be a 1D tensor. Got rank: ",
|
||||
stop_rank);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
step_rank.compatible(1),
|
||||
"Slice `step` input must be a 1D tensor. Got rank: ",
|
||||
step_rank);
|
||||
|
||||
if (data_rank.is_static()) {
|
||||
const auto data_rank_length = data_rank.get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
start_rank.is_dynamic() || start_shape[0].get_min_length() <= data_rank_length,
|
||||
"Slice `start` input dim size can't be bigger than `data` rank.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
stop_rank.is_dynamic() || stop_shape[0].get_min_length() <= data_rank_length,
|
||||
"Slice `stop` input dim size can't be bigger than `data` rank.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
step_rank.is_dynamic() || step_shape[0].get_min_length() <= data_rank_length,
|
||||
"Slice `step` input dim size can't be bigger than `data` rank.");
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
start_shape.compatible(stop_shape) && start_shape.compatible(step_shape) && stop_shape.compatible(step_shape),
|
||||
"Slice `start`, `stop`, `step` inputs must have compatible shapes.");
|
||||
|
||||
set_input_is_relevant_to_shape(0);
|
||||
set_input_is_relevant_to_shape(1);
|
||||
set_input_is_relevant_to_shape(2);
|
||||
set_input_is_relevant_to_shape(3);
|
||||
|
||||
std::shared_ptr<ngraph::op::v0::Constant> axes_const;
|
||||
if (get_input_size() > 4) {
|
||||
set_input_is_relevant_to_shape(4);
|
||||
axes_const = get_constant_from_source(input_value(4));
|
||||
const auto& axes_input = axes_const ? axes_const : input_value(4);
|
||||
const auto& axes_rank = axes_input.get_partial_shape().rank();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_rank.compatible(1),
|
||||
"Slice `axes` input must be a 1D tensor. Got rank: ",
|
||||
axes_rank);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_rank.is_dynamic() || axes_input.get_partial_shape()[0].get_max_length() <=
|
||||
data_rank.get_interval().get_max_val(),
|
||||
"Slice `axes` input dim size can't be bigger than `data` rank.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_input.get_partial_shape().compatible(start_shape),
|
||||
"Slice `axes` input must have compatible shape with `start`, `stop`, `step` inputs.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_input.get_element_type().is_integral_number(),
|
||||
"Slice `axes` input type must be integer.");
|
||||
} else {
|
||||
axes_const = get_default_const_axes(start_input);
|
||||
}
|
||||
|
||||
PartialShape output_shape(data_shape);
|
||||
|
||||
// If data_shape rank is dynamic we can't calulate output shape.
|
||||
// Even with const start/stop/step/axes, we don't know how many axes should be copied
|
||||
// as "unspefified" in the final output shape, so the output shape rank is also dynamic.
|
||||
if (data_rank.is_dynamic()) {
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
if (start_const && stop_const && step_const && axes_const) {
|
||||
const auto& starts = start_const->cast_vector<int64_t>();
|
||||
const auto& stops = stop_const->cast_vector<int64_t>();
|
||||
const auto& steps = step_const->cast_vector<int64_t>();
|
||||
const auto& axes = axes_const->cast_vector<int64_t>();
|
||||
|
||||
output_shape = calculate_output_shape(starts, stops, steps, axes, data_shape);
|
||||
} else {
|
||||
const auto data_static_rank = data_shape.rank().get_length();
|
||||
OPENVINO_ASSERT(data_static_rank >= 0);
|
||||
if (axes_const) {
|
||||
// If we know only `axes` values, we should update lower_bound to 0 value,
|
||||
// for the specified dims by the axes. For unspecified dims, bounds as in data_shape.
|
||||
for (const auto& axis : axes_const->cast_vector<int64_t>()) {
|
||||
const auto norm_axis = axis < 0 ? data_static_rank + axis : axis;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
norm_axis >= 0 && norm_axis < data_static_rank,
|
||||
"Values in the `axes` input must be in range of the `data` input rank: [-",
|
||||
data_static_rank,
|
||||
", ",
|
||||
data_static_rank - 1,
|
||||
"]. Got: ",
|
||||
axis);
|
||||
output_shape[norm_axis] = Dimension(0, data_shape[norm_axis].get_max_length());
|
||||
}
|
||||
} else {
|
||||
// Otherwise `axes` values are also unknown,
|
||||
// then all of the output dims can be 0, so have lower bound = 0.
|
||||
for (size_t i = 0; i < static_cast<size_t>(data_static_rank); ++i) {
|
||||
output_shape[i] = Dimension(0, data_shape[i].get_max_length());
|
||||
}
|
||||
if (get_input_size() < 5) {
|
||||
if (auto axes_const = get_default_const_axes(input_value(1))) {
|
||||
set_argument(4, axes_const);
|
||||
}
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
|
||||
for (size_t i = 0; i < get_input_size(); ++i) {
|
||||
if (i > 0) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_element_type(i).is_integral_number(),
|
||||
"Slice `",
|
||||
slice::shape_names[i - 1],
|
||||
"` input type must be integer.");
|
||||
}
|
||||
|
||||
set_input_is_relevant_to_shape(i);
|
||||
}
|
||||
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
|
||||
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
set_output_type(0, get_input_element_type(0), output_shapes.front());
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v8::Slice::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
@ -254,76 +108,6 @@ std::shared_ptr<Node> op::v8::Slice::clone_with_new_inputs(const OutputVector& n
|
||||
}
|
||||
}
|
||||
|
||||
PartialShape op::v8::Slice::calculate_output_shape(const std::vector<int64_t>& starts,
|
||||
const std::vector<int64_t>& stops,
|
||||
const std::vector<int64_t>& steps,
|
||||
const std::vector<int64_t>& axes,
|
||||
const PartialShape& data_shape) const {
|
||||
OV_OP_SCOPE(v8_Slice_calculate_output_shape);
|
||||
const auto ind_size = starts.size();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
stops.size() == ind_size && steps.size() == ind_size && axes.size() == ind_size,
|
||||
"Slice `start`, `stop`, `step`, `axes` inputs need to have the same size.");
|
||||
|
||||
std::unordered_set<int64_t> axes_set(axes.begin(), axes.end());
|
||||
NODE_VALIDATION_CHECK(this, axes_set.size() == axes.size(), "Slice values in `axes` input must be unique.");
|
||||
|
||||
PartialShape output_shape(data_shape);
|
||||
if (data_shape.rank().is_dynamic()) {
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
const auto data_static_rank = data_shape.rank().get_length();
|
||||
for (size_t i = 0; i < axes.size(); ++i) {
|
||||
const auto norm_axis = axes[i] < 0 ? data_static_rank + axes[i] : axes[i];
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
norm_axis >= 0 && norm_axis < data_static_rank,
|
||||
"Values in the `axes` input must be in range of the `data` input rank: [-",
|
||||
data_static_rank,
|
||||
", ",
|
||||
data_static_rank - 1,
|
||||
"]. Got: ",
|
||||
axes[i]);
|
||||
|
||||
auto start = starts[i];
|
||||
auto stop = stops[i];
|
||||
auto step = steps[i];
|
||||
|
||||
NODE_VALIDATION_CHECK(this, step != 0, "Slice 'step' value can't be zero.");
|
||||
|
||||
const auto& axis_dim = data_shape[norm_axis];
|
||||
const auto axis_min_dim_length = axis_dim.get_min_length();
|
||||
const auto min_dim_size = get_sliced_dim_size(start, stop, step, axis_min_dim_length);
|
||||
if (axis_dim.is_static()) {
|
||||
output_shape[norm_axis] = min_dim_size;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Avoid negative index normalization without upper bounds
|
||||
if (!axis_dim.get_interval().has_upper_bound()) {
|
||||
if (is_max_int(get_input_element_type(2), stop) || is_max_int(get_input_element_type(1), start)) {
|
||||
output_shape[norm_axis] = Dimension(-1);
|
||||
continue;
|
||||
} else if ((step < 0 && start < 0 && stop > 0) || (step > 0 && stop < 0 && start >= 0)) {
|
||||
output_shape[norm_axis] = Dimension(-1);
|
||||
continue;
|
||||
} else if (step < 0 && start > 0 && stop < 0) {
|
||||
output_shape[norm_axis] = Dimension(0, start + 1);
|
||||
continue;
|
||||
} else if (step > 0 && stop > 0 && start < 0) {
|
||||
output_shape[norm_axis] = Dimension(0, stop);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate max dim length (upper bound)
|
||||
auto axis_max_dim_length = axis_dim.get_interval().get_max_val();
|
||||
const auto max_dim_size = get_sliced_dim_size(start, stop, step, axis_max_dim_length);
|
||||
output_shape[norm_axis] = Dimension(min_dim_size, max_dim_size);
|
||||
}
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
bool op::v8::Slice::has_evaluate() const {
|
||||
OV_OP_SCOPE(v8_Slice_has_evaluate);
|
||||
switch (get_input_element_type(1)) {
|
||||
@ -361,33 +145,45 @@ bool op::v8::Slice::has_evaluate() const {
|
||||
|
||||
bool op::v8::Slice::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
|
||||
OV_OP_SCOPE(v8_Slice_evaluate);
|
||||
|
||||
OPENVINO_ASSERT(inputs.size() >= 4, "Slice evaluate needs at least 4 inputs.");
|
||||
std::vector<int64_t> starts = host_tensor_2_vector<int64_t>(inputs[1]);
|
||||
std::vector<int64_t> stops = host_tensor_2_vector<int64_t>(inputs[2]);
|
||||
std::vector<int64_t> steps = host_tensor_2_vector<int64_t>(inputs[3]);
|
||||
|
||||
std::vector<int64_t> axes(starts.size());
|
||||
// Static HostTensor data shape is needed to clamp and normalize `start` values
|
||||
OPENVINO_ASSERT(inputs[0]->get_partial_shape().is_static(),
|
||||
"Can't evaluate Slice elements without static HostTensor data shape.");
|
||||
|
||||
auto constant_data = std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>{};
|
||||
auto input_shapes = std::vector<PartialShape>();
|
||||
input_shapes.reserve(inputs.size());
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
auto&& tensor = inputs[i];
|
||||
input_shapes.push_back(tensor->get_partial_shape());
|
||||
constant_data.emplace(i, tensor);
|
||||
}
|
||||
|
||||
const auto starts = host_tensor_2_vector<int64_t>(inputs[1]);
|
||||
const auto stops = host_tensor_2_vector<int64_t>(inputs[2]);
|
||||
const auto steps = host_tensor_2_vector<int64_t>(inputs[3]);
|
||||
|
||||
std::vector<int64_t> axes;
|
||||
if (inputs.size() < 5) {
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
axes.reserve(starts.size());
|
||||
std::generate_n(std::back_inserter(axes), starts.size(), SeqGen<int64_t>(0));
|
||||
} else {
|
||||
axes = host_tensor_2_vector<int64_t>(inputs[4]);
|
||||
}
|
||||
|
||||
// Static HostTensor data shape is needed to clamp and normalize `start` values
|
||||
const auto& data_shape = inputs[0]->get_partial_shape();
|
||||
OPENVINO_ASSERT(data_shape.is_static(), "Can't evaluate Slice elements without static HostTensor data shape.");
|
||||
// We need calculate static output shape based on HostTensor inputs
|
||||
PartialShape output_shape = calculate_output_shape(starts, stops, steps, axes, data_shape);
|
||||
OPENVINO_ASSERT(output_shape.is_static(), "Can't calculate static output shape for Slice evaluation.");
|
||||
auto output_shapes = std::vector<PartialShape>(1);
|
||||
shape_infer(this, input_shapes, output_shapes, constant_data);
|
||||
OPENVINO_ASSERT(output_shapes.front().is_static(), "Can't calculate static output shape for Slice evaluation.");
|
||||
|
||||
outputs[0]->set_shape(output_shape.to_shape());
|
||||
outputs[0]->set_shape(output_shapes.front().to_shape());
|
||||
outputs[0]->set_element_type(inputs[0]->get_element_type());
|
||||
|
||||
ngraph::runtime::reference::slice(inputs[0]->get_data_ptr<char>(),
|
||||
data_shape.to_shape(),
|
||||
inputs[0]->get_shape(),
|
||||
outputs[0]->get_data_ptr<char>(),
|
||||
output_shape.to_shape(),
|
||||
outputs[0]->get_shape(),
|
||||
inputs[0]->get_element_type().size(),
|
||||
starts,
|
||||
steps,
|
||||
|
@ -870,6 +870,10 @@ std::string normalize_axis_error_msg(const int64_t& axis, const int64_t& lower,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int64_t ov::normalize(const int64_t& value, const int64_t& max) {
|
||||
return (value < 0) ? value + max : value;
|
||||
};
|
||||
|
||||
void ov::normalize_axes(const Node* node, const int64_t& tensor_rank, std::vector<int64_t>& axes) {
|
||||
const auto axis_checker = cmp::Between<int64_t, cmp::BOTH>(-tensor_rank, tensor_rank ? (tensor_rank - 1) : 0);
|
||||
const auto invalid_axis = std::find_if_not(axes.cbegin(), axes.cend(), axis_checker);
|
||||
@ -1659,7 +1663,7 @@ void ov::generate_transpose_default_order(std::vector<int64_t>& axes_order, cons
|
||||
}
|
||||
|
||||
bool ov::is_valid_axes_order(const std::vector<int64_t>& axes_order, const size_t size) {
|
||||
return (std::unordered_set<size_t>(axes_order.cbegin(), axes_order.cend()).size() == size) &&
|
||||
return are_unique(axes_order) &&
|
||||
std::all_of(axes_order.cbegin(), axes_order.cend(), ov::cmp::Between<int64_t, ov::cmp::LOWER>(0, size));
|
||||
}
|
||||
|
||||
@ -1677,3 +1681,12 @@ bool ov::is_rank_compatible_any_of(const ov::Rank& rank, const std::vector<Rank>
|
||||
return rank.compatible(r);
|
||||
});
|
||||
}
|
||||
|
||||
bool ov::are_unique(const std::vector<int64_t>& data) {
|
||||
return std::unordered_set<int64_t>(data.begin(), data.cend()).size() == data.size();
|
||||
}
|
||||
|
||||
// clip value to min, max
|
||||
int64_t ov::clip(const int64_t& value, const int64_t& min, const int64_t& max) {
|
||||
return std::min(std::max(value, min), max);
|
||||
};
|
||||
|
@ -5,11 +5,14 @@
|
||||
#include <dimension_tracker.hpp>
|
||||
#include <numeric>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "sequnce_generator.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
@ -140,23 +143,27 @@ TEST(type_prop, slice_v8_basic_const_inputs_unordered_axes) {
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), expected_out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_const_inputs_not_all_axes_unordered) {
|
||||
TEST(type_prop, slice_v8_const_inputs_not_all_axes_unordered_prop_labels) {
|
||||
PartialShape data_shape{10, 10, 10, 10, 10, 20, Dimension(20, 30), 30, Dimension(2, 5), Dimension(-1)};
|
||||
PartialShape expected_out_shape{4, 7, 10, 10, 9, 20, Dimension(10, 15), 30, Dimension(2, 5), Dimension(-1)};
|
||||
|
||||
set_shape_labels(data_shape, 10);
|
||||
|
||||
std::vector<int32_t> start_val{1, 1, -20, 9, 10, 9};
|
||||
std::vector<int32_t> stop_val{8, 8, 20, -11, 25, 0};
|
||||
std::vector<int32_t> step_val{1, 2, 1, -1, 1, -1};
|
||||
|
||||
std::vector<int32_t> axes_val{1, 0, 2, 3, 6, 4};
|
||||
|
||||
element::Type_t et = element::i32;
|
||||
constexpr auto et = element::i32;
|
||||
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val, axes_val};
|
||||
const auto op = make_slice_op_const_inputs(input_vals, data_shape, et);
|
||||
|
||||
EXPECT_EQ(op->get_element_type(), et);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), expected_out_shape);
|
||||
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)),
|
||||
ElementsAre(ov::no_label, ov::no_label, 12, 13, ov::no_label, 15, ov::no_label, 17, 18, 19));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_basic_const_inputs_data_dynamic_bounds_dimensions) {
|
||||
@ -200,7 +207,7 @@ TEST(type_prop, slice_v8_basic_const_inputs_data_dynamic_rank) {
|
||||
EXPECT_TRUE(op->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_basic_param_inputs_default_axes) {
|
||||
TEST(type_prop, slice_v8_basic_param_inputs_default_axes_labels_prop) {
|
||||
PartialShape data_shape{Dimension(0, 10),
|
||||
Dimension(1, 10),
|
||||
10,
|
||||
@ -219,12 +226,13 @@ TEST(type_prop, slice_v8_basic_param_inputs_default_axes) {
|
||||
Dimension(0, 8),
|
||||
Dimension(4, 8),
|
||||
16};
|
||||
set_shape_labels(data_shape, 10);
|
||||
|
||||
PartialShape start_shape{7};
|
||||
PartialShape stop_shape{7};
|
||||
PartialShape step_shape{7};
|
||||
|
||||
element::Type_t et = element::i32;
|
||||
constexpr auto et = element::i32;
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, data_shape);
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, start_shape);
|
||||
@ -235,6 +243,8 @@ TEST(type_prop, slice_v8_basic_param_inputs_default_axes) {
|
||||
|
||||
EXPECT_EQ(op->get_element_type(), et);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), expected_out_shape);
|
||||
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)),
|
||||
ElementsAre(10, ov::no_label, ov::no_label, ov::no_label, 14, ov::no_label, 16, 17, 18));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_sss_param_inputs_mixed_neg_const_axes) {
|
||||
@ -460,12 +470,12 @@ TEST(type_prop, slice_v8_basic_const_inputs_MAX_MIN_INT_dynamic_dimensions_neg_s
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_basic_const_inputs_data_full_dynamic_dims) {
|
||||
PartialShape data_shape{Dimension(-1), Dimension(-1), Dimension(-1)};
|
||||
PartialShape expected_out_shape{Dimension(0, 6), Dimension(0, 15), Dimension(0, 5)};
|
||||
PartialShape data_shape{-1, -1, -1, -1};
|
||||
PartialShape expected_out_shape{{0, 6}, {0, 15}, {0, 5}, -1};
|
||||
|
||||
std::vector<int32_t> start_val{2, 10, 35};
|
||||
std::vector<int32_t> stop_val{8, 25, 40};
|
||||
std::vector<int32_t> step_val{1, 1, 1};
|
||||
std::vector<int32_t> start_val{2, 10, 35, INT32_MIN};
|
||||
std::vector<int32_t> stop_val{8, 25, 40, -3};
|
||||
std::vector<int32_t> step_val{1, 1, 1, 1};
|
||||
|
||||
std::vector<int32_t> axes_val(start_val.size());
|
||||
std::iota(axes_val.begin(), axes_val.end(), 0);
|
||||
@ -536,12 +546,12 @@ TEST(type_prop, slice_v8_basic_const_inputs_data_full_dynamic_dims_neg_step_neg_
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_basic_const_inputs_data_full_dynamic_dims_neg_step_mix_ind) {
|
||||
PartialShape data_shape{Dimension(-1), Dimension(-1), Dimension(-1), Dimension(-1), Dimension(-1)};
|
||||
PartialShape expected_out_shape{Dimension(0, 6), Dimension(0, 6), Dimension(-1), Dimension(0, -1), Dimension(-1)};
|
||||
PartialShape data_shape{-1, -1, -1, -1, -1, -1};
|
||||
PartialShape expected_out_shape{{0, 6}, {0, 3}, {0, 6}, -1, -1, -1};
|
||||
|
||||
std::vector<int32_t> start_val{5, 5, -10, INT32_MAX, INT32_MAX};
|
||||
std::vector<int32_t> stop_val{-10, INT32_MIN, 5, 5, INT32_MIN};
|
||||
std::vector<int32_t> step_val{-1, -1, -1, -1, -1};
|
||||
std::vector<int32_t> start_val{5, 5, 5, -10, INT32_MAX, INT32_MAX};
|
||||
std::vector<int32_t> stop_val{-10, -10, INT32_MIN, 5, 5, INT32_MIN};
|
||||
std::vector<int32_t> step_val{-1, -2, -1, -1, -1, -1};
|
||||
|
||||
std::vector<int32_t> axes_val(start_val.size());
|
||||
std::iota(axes_val.begin(), axes_val.end(), 0);
|
||||
@ -643,7 +653,7 @@ TEST(type_prop, slice_v8_basic_const_inputs_MAX_MIN_INT_64_dynamic_dimensions_ne
|
||||
Dimension(-1),
|
||||
Dimension(-1)};
|
||||
PartialShape expected_out_shape{8,
|
||||
Dimension(8, 18),
|
||||
Dimension(4, 9),
|
||||
Dimension(5, 15),
|
||||
Dimension(0, 9),
|
||||
Dimension(0, 9),
|
||||
@ -670,7 +680,7 @@ TEST(type_prop, slice_v8_basic_const_inputs_MAX_MIN_INT_64_dynamic_dimensions_ne
|
||||
20,
|
||||
INT64_MAX};
|
||||
std::vector<int64_t> stop_val{1, 1, 4, 10, 10, 15, 25, INT64_MIN, -21, INT64_MIN, INT64_MIN, INT64_MIN, INT64_MIN};
|
||||
std::vector<int64_t> step_val{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1};
|
||||
std::vector<int64_t> step_val{-1, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1};
|
||||
|
||||
std::vector<int64_t> axes_val(start_val.size());
|
||||
std::iota(axes_val.begin(), axes_val.end(), 0);
|
||||
@ -841,51 +851,27 @@ TEST(type_prop, slice_v8_input_wrong_shape_catch) {
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
{
|
||||
try {
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`start` input must be a 1D tensor");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`stop` input must be a 1D tensor");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`step` input must be a 1D tensor");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`axes` input must be a 1D tensor");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`data` input can't be a scalar");
|
||||
}
|
||||
}
|
||||
const auto wrong_shape_in = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, wrong_shape_in, stop, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`start` input must be a 1D tensor"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, wrong_shape_in, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`stop` input must be a 1D tensor"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, wrong_shape_in, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`step` input must be a 1D tensor"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, wrong_shape_in),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`axes` input must be a 1D tensor"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(wrong_shape_in, start, stop, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`data` input can't be a scalar"));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_input_start_stop_step_dif_length_catch) {
|
||||
@ -894,50 +880,30 @@ TEST(type_prop, slice_v8_input_start_stop_step_dif_length_catch) {
|
||||
PartialShape correct_shape{3};
|
||||
PartialShape wrong_shape{2};
|
||||
|
||||
element::Type_t et = element::i32;
|
||||
constexpr auto et = element::i32;
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, data_shape);
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
{
|
||||
try {
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "start`, `stop`, `step` inputs must have compatible shapes");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "start`, `stop`, `step` inputs must have compatible shapes");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "start`, `stop`, `step` inputs must have compatible shapes");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
"`axes` input must have compatible shape with `start`, `stop`, `step` inputs");
|
||||
}
|
||||
}
|
||||
const auto wrong_shape_in = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, wrong_shape_in, stop, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("start`, `stop`, `step` inputs must have compatible shapes"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, wrong_shape_in, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("start`, `stop`, `step` inputs must have compatible shapes"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, wrong_shape_in, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("start`, `stop`, `step` inputs must have compatible shapes"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, wrong_shape_in),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`axes` input must have compatible shape with `start`, `stop`, `step` inputs"));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_input_start_stop_step_out_of_data_rank_length_catch) {
|
||||
@ -946,149 +912,92 @@ TEST(type_prop, slice_v8_input_start_stop_step_out_of_data_rank_length_catch) {
|
||||
PartialShape correct_shape{3};
|
||||
PartialShape wrong_shape{5};
|
||||
|
||||
element::Type_t et = element::i32;
|
||||
constexpr auto et = element::i32;
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, data_shape);
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
{
|
||||
try {
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`start` input dim size can't be bigger than `data` rank");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`stop` input dim size can't be bigger than `data` rank");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`step` input dim size can't be bigger than `data` rank");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`axes` input dim size can't be bigger than `data` rank");
|
||||
}
|
||||
}
|
||||
const auto wrong_shape_in = std::make_shared<op::v0::Parameter>(et, wrong_shape);
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, wrong_shape_in, stop, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`start` input dim size can't be bigger than `data` rank"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, wrong_shape_in, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`stop` input dim size can't be bigger than `data` rank"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, wrong_shape_in, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`step` input dim size can't be bigger than `data` rank"));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, wrong_shape_in),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`axes` input dim size can't be bigger than `data` rank"));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_input_wrong_types_float_catch) {
|
||||
PartialShape data_shape{100, 100, 100, 100};
|
||||
PartialShape correct_shape{3};
|
||||
|
||||
element::Type_t et = element::i32;
|
||||
element::Type_t wrong_et = element::f32;
|
||||
constexpr auto et = element::i32;
|
||||
constexpr auto wrong_et = element::f32;
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, data_shape);
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
{
|
||||
try {
|
||||
const auto start = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`start` input type must be integer.");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`stop` input type must be integer.");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto step = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`step` input type must be integer.");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`axes` input type must be integer.");
|
||||
}
|
||||
}
|
||||
const auto wrong_et_shape = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, wrong_et_shape, stop, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`start` input type must be integer."));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, wrong_et_shape, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`stop` input type must be integer."));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, wrong_et_shape, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`step` input type must be integer."));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, wrong_et_shape),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`axes` input type must be integer."));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_input_wrong_types_bool_catch) {
|
||||
PartialShape data_shape{100, 100, 100, 100};
|
||||
PartialShape correct_shape{3};
|
||||
|
||||
element::Type_t et = element::u64;
|
||||
element::Type_t wrong_et = element::boolean;
|
||||
constexpr auto et = element::u64;
|
||||
constexpr auto wrong_et = element::boolean;
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, data_shape);
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(et, correct_shape);
|
||||
{
|
||||
try {
|
||||
const auto start = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`start` input type must be integer.");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`stop` input type must be integer.");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto step = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`step` input type must be integer.");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto axes = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "`axes` input type must be integer.");
|
||||
}
|
||||
}
|
||||
const auto wrong_et_shape = std::make_shared<op::v0::Parameter>(wrong_et, correct_shape);
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, wrong_et_shape, stop, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`start` input type must be integer."));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, wrong_et_shape, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`stop` input type must be integer."));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, wrong_et_shape, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`step` input type must be integer."));
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, wrong_et_shape),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("`axes` input type must be integer."));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_basic_const_inputs_out_axes_val) {
|
||||
@ -1098,40 +1007,31 @@ TEST(type_prop, slice_v8_basic_const_inputs_out_axes_val) {
|
||||
std::vector<int32_t> stop_val{8, 8, 20, -11, 0, -10, -11, -20};
|
||||
std::vector<int32_t> step_val{1, 2, 1, -1, -1, -1, -2, -1};
|
||||
|
||||
element::Type_t et = element::i32;
|
||||
constexpr auto et = element::i32;
|
||||
{
|
||||
try {
|
||||
std::vector<int32_t> axes_val{2, 0, -20, 7, 1, 20, 6, 4};
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val, axes_val};
|
||||
const auto op = make_slice_op_const_inputs(input_vals, data_shape, et);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const ov::AssertFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "must be in range of the `data` input rank: [-8, 7]. Got: -20");
|
||||
}
|
||||
std::vector<int32_t> axes_val{2, 0, -20, 7, 1, 20, 6, 4};
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val, axes_val};
|
||||
OV_EXPECT_THROW(const auto op = make_slice_op_const_inputs(input_vals, data_shape, et),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("axis -20 out of the tensor rank range [-8, 7]"));
|
||||
}
|
||||
{
|
||||
try {
|
||||
std::vector<int32_t> axes_val{2, 0, 9, 7, 1, 20, 6, 4};
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val, axes_val};
|
||||
const auto op = make_slice_op_const_inputs(input_vals, data_shape, et);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "must be in range of the `data` input rank: [-8, 7]. Got: 9");
|
||||
}
|
||||
}
|
||||
{
|
||||
try {
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, data_shape);
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, PartialShape{2});
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, PartialShape{2});
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, PartialShape{2});
|
||||
const auto axes = std::make_shared<op::v0::Constant>(et, Shape{2}, std::vector<int32_t>{-15, 7});
|
||||
const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "must be in range of the `data` input rank: [-8, 7]. Got: -15");
|
||||
}
|
||||
std::vector<int32_t> axes_val{2, 0, 9, 7, 1, 20, 6, 4};
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val, axes_val};
|
||||
OV_EXPECT_THROW(const auto op = make_slice_op_const_inputs(input_vals, data_shape, et),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("axis 9 out of the tensor rank range [-8, 7]"));
|
||||
}
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(et, data_shape);
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, PartialShape{2});
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, PartialShape{2});
|
||||
const auto step = std::make_shared<op::v0::Parameter>(et, PartialShape{2});
|
||||
const auto axes = std::make_shared<op::v0::Constant>(et, Shape{2}, std::vector<int32_t>{-15, 7});
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v8::Slice>(data, start, stop, step, axes),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("axis -15 out of the tensor rank range [-8, 7]"));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_basic_const_inputs_step_zero) {
|
||||
@ -1140,18 +1040,12 @@ TEST(type_prop, slice_v8_basic_const_inputs_step_zero) {
|
||||
|
||||
std::vector<int32_t> start_val{1, 1, -20, 9, 9, 9, 9, 20};
|
||||
std::vector<int32_t> stop_val{8, 8, 20, -11, 0, -10, -11, -20};
|
||||
std::vector<int32_t> step_val{1, 2, 0, -1, -1, -1, -2, -1};
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val};
|
||||
|
||||
element::Type_t et = element::i32;
|
||||
{
|
||||
std::vector<int32_t> step_val{1, 2, 0, -1, -1, -1, -2, -1};
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val};
|
||||
try {
|
||||
const auto op = make_slice_op_const_inputs(input_vals, data_shape, et);
|
||||
FAIL() << "Slice validation did not work!";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "'step' value can't be zero");
|
||||
}
|
||||
}
|
||||
OV_EXPECT_THROW(const auto op = make_slice_op_const_inputs(input_vals, data_shape, element::i32),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Step must be non-zero"));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_dynamic_rank_inputs) {
|
||||
@ -1168,8 +1062,8 @@ TEST(type_prop, slice_v8_dynamic_rank_inputs) {
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), dyn_rank_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_dynamic_value_and_label_propagation) {
|
||||
Dimension marked_0 = Dimension(3);
|
||||
TEST(type_prop, slice_v8_dynamic_value_and_label_propagation) {
|
||||
Dimension marked_0 = Dimension(3, 7);
|
||||
ov::DimensionTracker::set_label(marked_0, 10);
|
||||
PartialShape target_0 = PartialShape{marked_0, 4};
|
||||
|
||||
@ -1185,8 +1079,115 @@ TEST(type_prop, slice_dynamic_value_and_label_propagation) {
|
||||
const auto slice = std::make_shared<op::v8::Slice>(shape_0, start, stop, step);
|
||||
|
||||
auto bc = std::make_shared<op::v1::Broadcast>(param, slice);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{3}));
|
||||
|
||||
const auto& output_shape = bc->get_output_partial_shape(0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
|
||||
EXPECT_EQ(output_shape, (PartialShape{{3, 7}}));
|
||||
EXPECT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_dynamic_dimension_but_slice_min_is_lt_input_min_size) {
|
||||
PartialShape data_shape{Dimension(20, -1)};
|
||||
|
||||
std::vector<int32_t> start_val{-7};
|
||||
std::vector<int32_t> stop_val{INT32_MAX};
|
||||
std::vector<int32_t> step_val{1};
|
||||
std::vector<int32_t> axes_val{0};
|
||||
|
||||
constexpr auto et = element::i32;
|
||||
|
||||
std::vector<std::vector<int32_t>> input_vals{start_val, stop_val, step_val, axes_val};
|
||||
const auto op = make_slice_op_const_inputs(input_vals, data_shape, et);
|
||||
|
||||
EXPECT_EQ(op->get_element_type(), et);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({{7}}));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_use_default_ctor) {
|
||||
const auto zero_mask = std::vector<int64_t>(3, 0);
|
||||
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{10, 11, 12, 2});
|
||||
auto start = op::Constant::create(element::i64, Shape{4}, {0, 0, 0, 0});
|
||||
auto stop = op::Constant::create(element::i64, Shape{4}, {1, 5, 20, 20});
|
||||
auto step = op::Constant::create(element::i64, Shape{4}, {1, 1, 1, 1});
|
||||
|
||||
auto slice = std::make_shared<op::v8::Slice>();
|
||||
slice->set_arguments(ov::OutputVector{data, start, stop, step});
|
||||
slice->validate_and_infer_types();
|
||||
|
||||
ASSERT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 5, 12, 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_stop_is_shape_of_with_bounds) {
|
||||
auto shape = PartialShape{1, {5, 7}};
|
||||
set_shape_labels(shape, 20);
|
||||
const auto p_stop = std::make_shared<op::Parameter>(element::i64, shape);
|
||||
const auto shape_of_stop = std::make_shared<op::ShapeOf>(p_stop);
|
||||
|
||||
auto data = op::Constant::create(element::i64, Shape{1, 10}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0});
|
||||
auto start = op::Constant::create(element::i64, Shape{2}, {0, 0});
|
||||
auto steps = op::Constant::create(element::i64, Shape{2}, {1, 1});
|
||||
|
||||
auto slice = std::make_shared<op::v8::Slice>(data, start, shape_of_stop, steps);
|
||||
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, {5, 7}}));
|
||||
EXPECT_THAT(get_shape_labels(slice->get_output_partial_shape(0)), Each(ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_start_is_shape_of_with_bounds) {
|
||||
auto shape = PartialShape{0, {3, 5}};
|
||||
set_shape_labels(shape, 20);
|
||||
const auto p_start = std::make_shared<op::Parameter>(element::i64, shape);
|
||||
const auto shape_of_start = std::make_shared<op::ShapeOf>(p_start);
|
||||
|
||||
auto data = op::Constant::create(element::i64, Shape{1, 10}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0});
|
||||
auto stop = op::Constant::create(element::i64, Shape{2}, {1, 7});
|
||||
auto steps = op::Constant::create(element::i64, Shape{2}, {1, 1});
|
||||
|
||||
auto slice = std::make_shared<op::v8::Slice>(data, shape_of_start, stop, steps);
|
||||
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, {2, 4}}));
|
||||
EXPECT_THAT(get_shape_labels(slice->get_output_partial_shape(0)), Each(ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_start_stop_is_shape_of_with_bounds) {
|
||||
auto start_shape = PartialShape{0, {3, 5}};
|
||||
auto stop_shape = PartialShape{2, {6, 7}};
|
||||
set_shape_labels(start_shape, 10);
|
||||
set_shape_labels(stop_shape, 20);
|
||||
const auto p_start = std::make_shared<op::Parameter>(element::i64, start_shape);
|
||||
const auto p_stop = std::make_shared<op::Parameter>(element::i64, stop_shape);
|
||||
const auto shape_of_start = std::make_shared<op::ShapeOf>(p_start);
|
||||
const auto shape_of_stop = std::make_shared<op::ShapeOf>(p_stop);
|
||||
|
||||
auto data = op::Constant::create(element::i64, Shape{1, 10}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0});
|
||||
auto steps = op::Constant::create(element::i64, Shape{2}, {1, 1});
|
||||
|
||||
auto slice = std::make_shared<op::v8::Slice>(data, shape_of_start, shape_of_stop, steps);
|
||||
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, {1, 4}}));
|
||||
EXPECT_THAT(get_shape_labels(slice->get_output_partial_shape(0)), Each(ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_unknowns_axes) {
|
||||
const auto data = std::make_shared<op::Parameter>(element::i64, Shape{5, 10, 15});
|
||||
const auto start = std::make_shared<op::Parameter>(element::i64, PartialShape{-1});
|
||||
const auto stop = std::make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
const auto steps = std::make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
const auto axes = std::make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
|
||||
auto slice = std::make_shared<op::v8::Slice>(data, start, stop, steps, axes);
|
||||
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({{0, 5}, {0, 10}, {0, 15}}));
|
||||
}
|
||||
|
||||
TEST(type_prop, slice_v8_inf_dim_start_from_last_N_to_end) {
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 256, -1});
|
||||
auto start = op::Constant::create(element::i64, Shape{1}, {-7});
|
||||
auto stop = op::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{INT64_MAX});
|
||||
auto step = op::Constant::create(element::i64, Shape{1}, {1});
|
||||
auto axes = op::Constant::create(element::i64, Shape{1}, {2});
|
||||
|
||||
auto slice = std::make_shared<op::v8::Slice>(data, start, stop, step, axes);
|
||||
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 256, {0, 7}}));
|
||||
}
|
||||
|
@ -97,38 +97,28 @@ TEST(type_prop, strided_slice_begin_incorrect_shape) {
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
|
||||
auto begin = make_shared<op::Parameter>(element::i64, Shape{4, 5});
|
||||
auto end = make_shared<op::Parameter>(element::i64, Shape{4});
|
||||
try {
|
||||
auto strided_slice = make_shared<op::v1::StridedSlice>(data,
|
||||
begin,
|
||||
end,
|
||||
vector<int64_t>{1, 0, 1, 0},
|
||||
vector<int64_t>{1, 0, 1, 0});
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect shape of begin exception not thrown.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Begin input must be 1D (begin rank:"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
|
||||
OV_EXPECT_THROW(auto strided_slice = make_shared<op::v1::StridedSlice>(data,
|
||||
begin,
|
||||
end,
|
||||
vector<int64_t>{1, 0, 1, 0},
|
||||
vector<int64_t>{1, 0, 1, 0}),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Begin input must be 1D (has rank:"));
|
||||
}
|
||||
|
||||
TEST(type_prop, strided_slice_end_incorrect_shape) {
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
|
||||
auto begin = make_shared<op::Parameter>(element::i64, Shape{4});
|
||||
auto end = make_shared<op::Parameter>(element::i64, Shape{4, 5});
|
||||
try {
|
||||
auto strided_slice = make_shared<op::v1::StridedSlice>(data,
|
||||
begin,
|
||||
end,
|
||||
vector<int64_t>{1, 0, 1, 0},
|
||||
vector<int64_t>{1, 0, 1, 0});
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect shape of end exception not thrown.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("End input must be 1D (end rank:"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
|
||||
OV_EXPECT_THROW(auto strided_slice = make_shared<op::v1::StridedSlice>(data,
|
||||
begin,
|
||||
end,
|
||||
vector<int64_t>{1, 0, 1, 0},
|
||||
vector<int64_t>{1, 0, 1, 0}),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("End input must be 1D (has rank:"));
|
||||
}
|
||||
|
||||
TEST(type_prop, strided_slice_default_stride_dynamic_shape_input_begin_not_1d) {
|
||||
@ -388,7 +378,7 @@ TEST(type_prop, strided_slice_reverse_end_is_int64_min) {
|
||||
|
||||
auto ss = std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, mask, mask);
|
||||
|
||||
EXPECT_EQ(ss->get_output_partial_shape(0), PartialShape({{0, 20}, -1}));
|
||||
EXPECT_EQ(ss->get_output_partial_shape(0), PartialShape({{0, 20}, {0, 21}}));
|
||||
}
|
||||
|
||||
TEST(type_prop, strided_slice_dynamic_value_and_label_propagation) {
|
||||
@ -440,6 +430,22 @@ TEST(type_prop, strided_slice_use_default_ctor) {
|
||||
ASSERT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 5, 12}));
|
||||
}
|
||||
|
||||
TEST(type_prop, strided_slice_inf_dim_start_from_last_N_to_end) {
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 256, -1});
|
||||
auto start = op::Constant::create(element::i64, Shape{3}, {0, 0, -7});
|
||||
auto stop = op::Constant::create(element::i64, Shape{3}, std::vector<int64_t>{0, 0, INT64_MAX});
|
||||
auto step = op::Constant::create(element::i64, Shape{3}, {1, 1, 1});
|
||||
|
||||
const auto slice = std::make_shared<op::v1::StridedSlice>(data,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
std::vector<int64_t>{1, 1, 0},
|
||||
std::vector<int64_t>{1, 1, 0});
|
||||
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 256, {0, 7}}));
|
||||
}
|
||||
|
||||
struct StridedSliceTestParams {
|
||||
std::string case_name;
|
||||
PartialShape input_shape;
|
||||
@ -562,17 +568,17 @@ INSTANTIATE_TEST_SUITE_P(type_prop,
|
||||
},
|
||||
StridedSliceTestParams{
|
||||
"input_has_dynamic_dimensions_and_shrink_one",
|
||||
{{8, 40}, 200, {100, 200}, 3}, // input_shape
|
||||
{4}, // begin shape
|
||||
{4}, // end shape
|
||||
{4}, // strides shape
|
||||
{0, 0, 0, 0}, // begin mask
|
||||
{0, 0, 0, 0}, // end mask
|
||||
{0, 0, 0, 0}, // new axis mask
|
||||
{1, 0, 0, 0}, // shrink axis mask
|
||||
{0, 0, 0, 0}, // ellipsis mask
|
||||
{{0, 200}, Dimension::dynamic(), {0, 3}}, // reference shape
|
||||
element::f32 // reference type
|
||||
{{8, 40}, 200, {100, 200}, 3}, // input_shape
|
||||
{4}, // begin shape
|
||||
{4}, // end shape
|
||||
{4}, // strides shape
|
||||
{0, 0, 0, 0}, // begin mask
|
||||
{0, 0, 0, 0}, // end mask
|
||||
{0, 0, 0, 0}, // new axis mask
|
||||
{1, 0, 0, 0}, // shrink axis mask
|
||||
{0, 0, 0, 0}, // ellipsis mask
|
||||
{{0, 200}, {0, 200}, {0, 3}}, // reference shape
|
||||
element::f32 // reference type
|
||||
},
|
||||
StridedSliceTestParams{
|
||||
"input_is_dynamic_rank",
|
||||
@ -616,6 +622,20 @@ INSTANTIATE_TEST_SUITE_P(type_prop,
|
||||
{{0, 3}, {0, 5}, {0, 4}}, // reference shape
|
||||
element::f32 // reference type
|
||||
},
|
||||
StridedSliceTestParams{
|
||||
"begin_strides_are_dynamic_rank_and_ellipsis_mask_present",
|
||||
{3, 5, 4}, // input_shape
|
||||
PartialShape::dynamic(), // begin shape
|
||||
{3}, // end shape
|
||||
{3}, // strides shape
|
||||
{0, 0, 1, 0}, // begin mask
|
||||
{0, 0, 0, 0}, // end mask
|
||||
{0, 0, 0, 0}, // new axis mask
|
||||
{0, 0, 0, 0}, // shrink axis mask
|
||||
{0, 1, 0, 0}, // ellipsis mask
|
||||
{{0, 3}, 5, {0, 4}}, // reference shape
|
||||
element::f32 // reference type
|
||||
},
|
||||
StridedSliceTestParams{
|
||||
"begin_end_strides_are_dynamic_rank",
|
||||
{3, 5, 4}, // input_shape
|
||||
|
@ -11,9 +11,6 @@
|
||||
#include <openvino/opsets/opset7.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
|
||||
#include "ov_ops/augru_cell.hpp"
|
||||
#include "ov_ops/augru_sequence.hpp"
|
||||
|
||||
#include "assign_shape_inference.hpp"
|
||||
#include "augru_cell_shape_inference.hpp"
|
||||
#include "augru_sequence_shape_inference.hpp"
|
||||
@ -43,8 +40,8 @@
|
||||
#include "gather_shape_inference.hpp"
|
||||
#include "gather_tree_shape_inference.hpp"
|
||||
#include "grid_sample_shape_inference.hpp"
|
||||
#include "gru_sequence_shape_inference.hpp"
|
||||
#include "gru_cell_shape_inference.hpp"
|
||||
#include "gru_sequence_shape_inference.hpp"
|
||||
#include "interpolate_shape_inference.hpp"
|
||||
#include "lstm_cell_shape_inference.hpp"
|
||||
#include "matmul_shape_inference.hpp"
|
||||
@ -65,6 +62,7 @@
|
||||
#include "shape_inference.hpp"
|
||||
#include "shape_nodes.hpp"
|
||||
#include "shuffle_channels_shape_inference.hpp"
|
||||
#include "slice_shape_inference.hpp"
|
||||
#include "space_to_batch_shape_inference.hpp"
|
||||
#include "space_to_depth_shape_inference.hpp"
|
||||
#include "split_shape_inference.hpp"
|
||||
@ -467,6 +465,8 @@ std::shared_ptr<IShapeInferCommon> make_shape_inference(const std::shared_ptr<ng
|
||||
return make_shared_entryIOC(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset7::Einsum>(op)) {
|
||||
return make_shared_entryIO(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset8::Slice>(op)) {
|
||||
return make_shared_entryIOC(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset1::StridedSlice>(op)) {
|
||||
return make_shared_entryIOC(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset3::Assign>(op)) {
|
||||
|
@ -0,0 +1,108 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "slice_shape_inference.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace testing;
|
||||
|
||||
class SliceStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v8::Slice> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
output_shapes.resize(num_of_outputs);
|
||||
}
|
||||
|
||||
size_t num_of_outputs = 1;
|
||||
StaticDimension::value_type max_d = std::numeric_limits<StaticDimension::value_type>::max();
|
||||
};
|
||||
|
||||
TEST_F(SliceStaticShapeInferenceTest, reverse_steps_start_stop_outside_dimension_default_axes) {
|
||||
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto start = op::v0::Constant::create(element::i64, Shape{5}, std::vector<int64_t>{100, 5, -1, INT64_MAX, 5});
|
||||
const auto stop =
|
||||
op::v0::Constant::create(element::i64, Shape{5}, std::vector<int64_t>{-100, INT64_MIN, -6, 5, -10});
|
||||
const auto steps = op::v0::Constant::create(element::i64, Shape{5}, {-1, -2, -1, -1, -2});
|
||||
|
||||
const auto op = make_op(data, start, stop, steps);
|
||||
|
||||
input_shapes.push_back({3, 4, 5, max_d, max_d});
|
||||
input_shapes.resize(4, start->get_shape());
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
|
||||
EXPECT_EQ(output_shapes.size(), num_of_outputs);
|
||||
EXPECT_EQ(output_shapes.front(), StaticShape({3, 2, 5, max_d, 3}));
|
||||
}
|
||||
|
||||
TEST_F(SliceStaticShapeInferenceTest, reverse_step_on_signle_axis_but_start_stop_steps_in_const_map) {
|
||||
constexpr auto et = element::i64;
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, PartialShape::dynamic());
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, PartialShape::dynamic());
|
||||
const auto steps = std::make_shared<op::v0::Parameter>(et, PartialShape::dynamic());
|
||||
const auto axes = op::v0::Constant::create(element::i64, Shape{1}, {-1});
|
||||
|
||||
auto start_buff = std::vector<int64_t>{100};
|
||||
auto stop_buff = std::vector<int64_t>{2};
|
||||
auto steps_buff = std::vector<int64_t>{-2};
|
||||
|
||||
const auto start_tensor = std::make_shared<HostTensor>(et, Shape{1}, static_cast<void*>(start_buff.data()));
|
||||
const auto stop_tensor = std::make_shared<HostTensor>(et, Shape{1}, static_cast<void*>(stop_buff.data()));
|
||||
const auto steps_tensor = std::make_shared<HostTensor>(et, Shape{1}, static_cast<void*>(steps_buff.data()));
|
||||
|
||||
const auto op = make_op(data, start, stop, steps, axes);
|
||||
|
||||
input_shapes = ShapeVector{{3, 4, 10}, {1}, {1}, {1}, axes->get_shape()};
|
||||
|
||||
const std::map<size_t, std::shared_ptr<HostTensor>>& constant_data = {{1, start_tensor},
|
||||
{2, stop_tensor},
|
||||
{3, steps_tensor}};
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes, constant_data);
|
||||
|
||||
EXPECT_EQ(output_shapes.size(), num_of_outputs);
|
||||
EXPECT_EQ(output_shapes.front(), StaticShape({3, 4, 4}));
|
||||
}
|
||||
|
||||
TEST_F(SliceStaticShapeInferenceTest, forward_step_all_data_in_const_map) {
|
||||
constexpr auto et = element::i64;
|
||||
|
||||
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto start = std::make_shared<op::v0::Parameter>(et, PartialShape::dynamic());
|
||||
const auto stop = std::make_shared<op::v0::Parameter>(et, PartialShape::dynamic());
|
||||
const auto steps = std::make_shared<op::v0::Parameter>(et, PartialShape::dynamic());
|
||||
|
||||
auto start_buff = std::vector<int64_t>{0, 2, 10, 3, 3, INT64_MIN, INT64_MIN};
|
||||
auto stop_buff = std::vector<int64_t>{10, 8, 12, 15, INT64_MAX, -5, -5};
|
||||
auto steps_buff = std::vector<int64_t>{1, 2, 1, 3, 4, 2, 2};
|
||||
auto axes_buff = std::vector<int64_t>{0, 1, 2, 3, 4, 5, 6};
|
||||
|
||||
const auto common_shape = Shape{start_buff.size()};
|
||||
|
||||
const auto start_tensor = std::make_shared<HostTensor>(et, common_shape, static_cast<void*>(start_buff.data()));
|
||||
const auto stop_tensor = std::make_shared<HostTensor>(et, common_shape, static_cast<void*>(stop_buff.data()));
|
||||
const auto steps_tensor = std::make_shared<HostTensor>(et, common_shape, static_cast<void*>(steps_buff.data()));
|
||||
const auto axes_tensor = std::make_shared<HostTensor>(et, common_shape, static_cast<void*>(axes_buff.data()));
|
||||
|
||||
const auto op = make_op(data, start, stop, steps);
|
||||
|
||||
input_shapes.push_back({10, 10, 8, max_d, max_d, max_d, 10});
|
||||
input_shapes.resize(5, common_shape);
|
||||
|
||||
const std::map<size_t, std::shared_ptr<HostTensor>>& constant_data = {{1, start_tensor},
|
||||
{2, stop_tensor},
|
||||
{3, steps_tensor},
|
||||
{4, axes_tensor}};
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes, constant_data);
|
||||
|
||||
EXPECT_EQ(output_shapes.size(), num_of_outputs);
|
||||
EXPECT_EQ(output_shapes.front(), StaticShape({10, 3, 0, 4, max_d, max_d, 3}));
|
||||
}
|
Loading…
Reference in New Issue
Block a user