Strided Slice fix constant creation (#17557)
* Strided Slice fix constant creation * Apply suggestions from code review * Final touches
This commit is contained in:
committed by
GitHub
parent
0c67b90f47
commit
4c2096ad9c
@@ -119,7 +119,7 @@ public:
|
||||
|
||||
private:
|
||||
AxisSet convert_mask_to_axis_set(const std::vector<int64_t>& mask) const;
|
||||
bool indicies_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& masks) const;
|
||||
bool indices_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& masks) const;
|
||||
|
||||
std::vector<int64_t> m_begin_mask;
|
||||
std::vector<int64_t> m_end_mask;
|
||||
|
||||
@@ -68,14 +68,18 @@ shared_ptr<Node> calculate_default_strides(const Output<Node>& begin, const Outp
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check if all indicies in 1-D input shape are ignored by masks.
|
||||
* @brief Check if all indices in 1-D input shape are ignored by masks.
|
||||
*
|
||||
* @param shape Indicies shape (assume compatible 1-D shape).
|
||||
* @param ignored_mask Axis set of ignored indicies.
|
||||
* @param shape Indices shape (assume compatible 1-D shape).
|
||||
* @param ignored_mask Axis set of ignored indices.
|
||||
* @return True if all ignored other wise false.
|
||||
*/
|
||||
bool all_indicies_ignored(const ov::PartialShape& shape, const AxisSet& ignore_mask) {
|
||||
return shape.rank().is_static() && ov::cmp::le(shape[0].get_interval().get_max_val(), ignore_mask.size());
|
||||
bool all_indices_ignored(const ov::PartialShape& shape, const std::vector<int64_t>& ignore_mask) {
|
||||
auto ignored = shape.rank().is_static() && ov::cmp::le(shape[0].get_interval().get_max_val(), ignore_mask.size());
|
||||
for (size_t i = 0; ignored && i < static_cast<size_t>(shape[0].get_interval().get_max_val()); ++i) {
|
||||
ignored = static_cast<bool>(ignore_mask[i]);
|
||||
}
|
||||
return ignored;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@@ -249,13 +253,12 @@ bool op::v1::StridedSlice::has_evaluate() const {
|
||||
return get_input_size() == 4;
|
||||
}
|
||||
|
||||
bool op::v1::StridedSlice::indicies_input_has_and_set_bounds(const size_t port,
|
||||
const std::vector<int64_t>& masks) const {
|
||||
bool op::v1::StridedSlice::indices_input_has_and_set_bounds(const size_t port, const std::vector<int64_t>& mask) const {
|
||||
const auto& lb_t = get_input_tensor(port).get_lower_value();
|
||||
const auto& ub_t = get_input_tensor(port).get_upper_value();
|
||||
|
||||
const auto mask_set = convert_mask_to_axis_set(masks);
|
||||
bool valid_bounds = all_indicies_ignored(get_input_partial_shape(port), mask_set);
|
||||
const auto mask_set = convert_mask_to_axis_set(mask);
|
||||
bool valid_bounds = all_indices_ignored(get_input_partial_shape(port), mask);
|
||||
|
||||
if (!valid_bounds && lb_t && ub_t) {
|
||||
using TCast = int64_t;
|
||||
@@ -274,38 +277,48 @@ bool op::v1::StridedSlice::indicies_input_has_and_set_bounds(const size_t port,
|
||||
}
|
||||
|
||||
bool op::v1::StridedSlice::evaluate_lower(ov::TensorVector& output_values) const {
|
||||
return indicies_input_has_and_set_bounds(1, get_begin_mask()) &&
|
||||
indicies_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
|
||||
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
|
||||
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
|
||||
default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::StridedSlice::evaluate_upper(ov::TensorVector& output_values) const {
|
||||
return indicies_input_has_and_set_bounds(1, get_begin_mask()) &&
|
||||
indicies_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
|
||||
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
|
||||
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
|
||||
default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::StridedSlice::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
return indicies_input_has_and_set_bounds(1, get_begin_mask()) &&
|
||||
indicies_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
|
||||
return indices_input_has_and_set_bounds(1, get_begin_mask()) &&
|
||||
indices_input_has_and_set_bounds(2, get_end_mask()) && get_input_tensor(3).has_and_set_bound() &&
|
||||
default_label_evaluator(this, {0}, output_labels);
|
||||
}
|
||||
|
||||
bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
|
||||
auto is_folded = Node::constant_fold(output_values, inputs_values);
|
||||
if (!is_folded) {
|
||||
// If all ignore mask are set for all begin or end then replace this input by dummy constant
|
||||
// to avoid return false from `could_propagate` during bound evaluation (value of const will be ignore).
|
||||
auto get_indicies_input = [&inputs_values](size_t port, AxisSet&& mask) -> Output<Node> {
|
||||
return all_indicies_ignored(inputs_values[port].get_partial_shape(), mask)
|
||||
? std::make_shared<op::v0::Constant>(inputs_values[port].get_element_type(),
|
||||
Shape{mask.size()},
|
||||
0)
|
||||
: inputs_values[port];
|
||||
// If all ignored mask are set for all begin or end then replace this input by dummy constant
|
||||
// to avoid return false from `could_propagate` during bound evaluation (value of const will be ignored).
|
||||
auto get_indices_input = [&inputs_values](size_t port, const std::vector<int64_t>& mask) -> Output<Node> {
|
||||
const auto& port_shape = inputs_values[port].get_partial_shape();
|
||||
const auto& data_shape = inputs_values[0].get_partial_shape();
|
||||
|
||||
size_t size;
|
||||
if (port_shape.rank().is_static() && port_shape[0].is_static())
|
||||
size = static_cast<size_t>(port_shape[0].get_length());
|
||||
else if (data_shape.rank().is_static())
|
||||
size = data_shape.size();
|
||||
else
|
||||
size = mask.size();
|
||||
|
||||
const auto& zero_constant =
|
||||
make_shared<ov::opset1::Constant>(inputs_values[port].get_element_type(), ov::Shape{size}, 0);
|
||||
return all_indices_ignored(inputs_values[port].get_partial_shape(), mask) ? zero_constant
|
||||
: inputs_values[port];
|
||||
};
|
||||
|
||||
const auto& begin = get_indicies_input(1, convert_mask_to_axis_set(get_begin_mask()));
|
||||
const auto& end = get_indicies_input(2, convert_mask_to_axis_set(get_end_mask()));
|
||||
const auto& begin = get_indices_input(1, get_begin_mask());
|
||||
const auto& end = get_indices_input(2, get_end_mask());
|
||||
|
||||
const auto& output =
|
||||
((&begin != &inputs_values[1]) || (&end != &inputs_values[2]))
|
||||
|
||||
@@ -446,6 +446,33 @@ TEST(type_prop, strided_slice_inf_dim_start_from_last_N_to_end) {
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 256, {0, 7}}));
|
||||
}
|
||||
|
||||
TEST(type_prop, strided_slice_different_ranks) {
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
|
||||
auto start = op::Constant::create(element::i64, Shape{1}, {0});
|
||||
auto stop = op::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{INT64_MAX});
|
||||
|
||||
const auto slice = std::make_shared<op::v1::StridedSlice>(data,
|
||||
start,
|
||||
stop,
|
||||
std::vector<int64_t>{1, 1, 1, 1, 1},
|
||||
std::vector<int64_t>{0, 0, 0, 0, 0});
|
||||
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(type_prop, strided_slice_different_ranks_long_masks) {
|
||||
auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
|
||||
auto start = op::Constant::create(element::i64, Shape{4}, {0, 0, 0, 0});
|
||||
auto stop = op::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{2, 2, 2, 2});
|
||||
|
||||
const auto slice = std::make_shared<op::v1::StridedSlice>(data,
|
||||
start,
|
||||
stop,
|
||||
std::vector<int64_t>{1, 1, 0, 1, 1},
|
||||
std::vector<int64_t>{0, 0, 1, 0, 0});
|
||||
EXPECT_EQ(slice->get_output_partial_shape(0), PartialShape({1, 2, 3, 2}));
|
||||
}
|
||||
|
||||
struct StridedSliceTestParams {
|
||||
std::string case_name;
|
||||
PartialShape input_shape;
|
||||
|
||||
Reference in New Issue
Block a user