Strided Slice fix constant creation (#17557)

* Strided Slice fix constant creation

* Apply suggestions from code review

* Final touches
This commit is contained in:
Evgenya Stepyreva
2023-05-16 17:53:57 +04:00
committed by GitHub
parent 0c67b90f47
commit 4c2096ad9c
3 changed files with 66 additions and 26 deletions

View File

@@ -119,7 +119,7 @@ public:
private:
AxisSet convert_mask_to_axis_set(const std::vector<int64_t>& mask) const;
bool indicies_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& masks) const;
bool indices_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& masks) const;
std::vector<int64_t> m_begin_mask;
std::vector<int64_t> m_end_mask;

View File

@@ -68,14 +68,18 @@ shared_ptr<Node> calculate_default_strides(const Output<Node>& begin, const Outp
}
/**
* @brief Check if all indicies in 1-D input shape are ignored by masks.
* @brief Check if all indices in 1-D input shape are ignored by masks.
*
* @param shape Indicies shape (assume compatible 1-D shape).
* @param ignored_mask Axis set of ignored indicies.
* @param shape Indices shape (assume compatible 1-D shape).
* @param ignored_mask Axis set of ignored indices.
* @return True if all ignored other wise false.
*/
bool all_indicies_ignored(const ov::PartialShape& shape, const AxisSet& ignore_mask) {
return shape.rank().is_static() && ov::cmp::le(shape[0].get_interval().get_max_val(), ignore_mask.size());
bool all_indices_ignored(const ov::PartialShape& shape, const std::vector<int64_t>& ignore_mask) {
auto ignored = shape.rank().is_static() && ov::cmp::le(shape[0].get_interval().get_max_val(), ignore_mask.size());
for (size_t i = 0; ignored && i < static_cast<size_t>(shape[0].get_interval().get_max_val()); ++i) {
ignored = static_cast<bool>(ignore_mask[i]);
}
return ignored;
}
} // namespace
@@ -249,13 +253,12 @@ bool op::v1::StridedSlice::has_evaluate() const {
return get_input_size() == 4;
}
bool op::v1::StridedSlice::indicies_input_has_and_set_bounds(const size_t port,
const std::vector<int64_t>& masks) const {
bool op::v1::StridedSlice::indices_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& mask) const {
const auto& lb_t = get_input_tensor(port).get_lower_value();
const auto& ub_t = get_input_tensor(port).get_upper_value();
const auto mask_set = convert_mask_to_axis_set(masks);
bool valid_bounds = all_indicies_ignored(get_input_partial_shape(port), mask_set);
const auto mask_set = convert_mask_to_axis_set(mask);
bool valid_bounds = all_indices_ignored(get_input_partial_shape(port), mask);
if (!valid_bounds && lb_t && ub_t) {
using TCast = int64_t;
@@ -274,38 +277,48 @@ bool op::v1::StridedSlice::indicies_input_has_and_set_bounds(const size_t port,
}
bool op::v1::StridedSlice::evaluate_lower(ov::TensorVector& output_values) const {
return indicies_input_has_and_set_bounds(1, get_begin_mask()) &&
indicies_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
default_lower_bound_evaluator(this, output_values);
}
bool op::v1::StridedSlice::evaluate_upper(ov::TensorVector& output_values) const {
return indicies_input_has_and_set_bounds(1, get_begin_mask()) &&
indicies_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
default_upper_bound_evaluator(this, output_values);
}
bool op::v1::StridedSlice::evaluate_label(TensorLabelVector& output_labels) const {
return indicies_input_has_and_set_bounds(1, get_begin_mask()) &&
indicies_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
default_label_evaluator(this, {0}, output_labels);
}
bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
auto is_folded = Node::constant_fold(output_values, inputs_values);
if (!is_folded) {
// If all ignore mask are set for all begin or end then replace this input by dummy constant
// to avoid return false from `could_propagate` during bound evaluation (value of const will be ignore).
auto get_indicies_input = [&inputs_values](size_t port, AxisSet&& mask) -> Output<Node> {
return all_indicies_ignored(inputs_values[port].get_partial_shape(), mask)
? std::make_shared<op::v0::Constant>(inputs_values[port].get_element_type(),
Shape{mask.size()},
0)
: inputs_values[port];
// If all ignored mask are set for all begin or end then replace this input by dummy constant
// to avoid return false from `could_propagate` during bound evaluation (value of const will be ignored).
auto get_indices_input = [&inputs_values](size_t port, const std::vector<int64_t>& mask) -> Output<Node> {
const auto& port_shape = inputs_values[port].get_partial_shape();
const auto& data_shape = inputs_values[0].get_partial_shape();
size_t size;
if (port_shape.rank().is_static() && port_shape[0].is_static())
size = static_cast<size_t>(port_shape[0].get_length());
else if (data_shape.rank().is_static())
size = data_shape.size();
else
size = mask.size();
const auto& zero_constant =
make_shared<ov::opset1::Constant>(inputs_values[port].get_element_type(), ov::Shape{size}, 0);
return all_indices_ignored(inputs_values[port].get_partial_shape(), mask) ? zero_constant
: inputs_values[port];
};
const auto& begin = get_indicies_input(1, convert_mask_to_axis_set(get_begin_mask()));
const auto& end = get_indicies_input(2, convert_mask_to_axis_set(get_end_mask()));
const auto& begin = get_indices_input(1, get_begin_mask());
const auto& end = get_indices_input(2, get_end_mask());
const auto& output =
((&begin != &inputs_values[1]) || (&end != &inputs_values[2]))

View File

@@ -446,6 +446,33 @@ TEST(type_prop, strided_slice_inf_dim_start_from_last_N_to_end) {
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 256, {0, 7}}));
}
TEST(type_prop, strided_slice_different_ranks) {
auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
auto start = op::Constant::create(element::i64, Shape{1}, {0});
auto stop = op::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{INT64_MAX});
const auto slice = std::make_shared<op::v1::StridedSlice>(data,
start,
stop,
std::vector<int64_t>{1, 1, 1, 1, 1},
std::vector<int64_t>{0, 0, 0, 0, 0});
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 2, 3, 4}));
}
TEST(type_prop, strided_slice_different_ranks_long_masks) {
auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
auto start = op::Constant::create(element::i64, Shape{4}, {0, 0, 0, 0});
auto stop = op::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{2, 2, 2, 2});
const auto slice = std::make_shared<op::v1::StridedSlice>(data,
start,
stop,
std::vector<int64_t>{1, 1, 0, 1, 1},
std::vector<int64_t>{0, 0, 1, 0, 0});
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 2, 3, 2}));
}
struct StridedSliceTestParams {
std::string case_name;
PartialShape input_shape;