[Core] Correct shape inference for StridedSlice with non-const begin (#13581)

* [Core] Correct shape inference for StridedSlice with non-constant begin, end, and strides

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix build issue

* Fix build issue

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-10-24 14:18:02 +03:00 committed by GitHub
parent 4188f1f181
commit be1b72d1e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 301 additions and 175 deletions

View File

@ -36,14 +36,40 @@ void shape_infer(const StridedSlice* op,
end_shape.rank(),
").");
const auto& strides_shape = input_shapes[3];
NODE_VALIDATION_CHECK(op,
strides_shape.rank().compatible(1),
"End input must be 1D (end rank: ",
strides_shape.rank(),
").");
// it is not possible to define output shape if input data shape rank is undefined
// even the lengths of begin, end, or strides are defined
if (input_shape.rank().is_dynamic()) {
output_shapes[0] = ov::PartialShape::dynamic();
return;
}
auto input_rank = input_shape.size();
std::vector<int64_t> begin;
std::vector<int64_t> end;
std::vector<int64_t> strides;
auto number_elements_in_1d = [](const StridedSlice* op, const T& shape_1d) -> int64_t {
auto rank_1d = shape_1d.rank();
if (rank_1d.is_static()) {
NODE_VALIDATION_CHECK(op, rank_1d.get_length() == 1, "Only 1D tensor is allowed.");
if (shape_1d[0].is_static()) {
return static_cast<int64_t>(shape_1d[0].get_length());
}
}
return -1;
};
// compute constant values of begin, end, and strides if possible
bool got_begin = get_data_as_int64<T>(1, op, begin, constant_data);
bool got_end = get_data_as_int64<T>(2, op, end, constant_data);
bool got_strides = false;
if (input_shapes.size() > 3) {
got_strides = get_data_as_int64<T>(3, op, strides, constant_data);
} else if (got_begin) {
@ -52,23 +78,32 @@ void shape_infer(const StridedSlice* op,
got_strides = true;
}
if (got_begin && got_end && got_strides) {
if (begin.size() && end.size()) {
// compute and check a number of axes for which begin, end, and strides are defined
auto number_axes = number_elements_in_1d(op, begin_shape);
auto end_number_axes = number_elements_in_1d(op, end_shape);
if (number_axes != -1 && end_number_axes != -1) {
NODE_VALIDATION_CHECK(op,
begin.size() == end.size(),
"Lower bounds and Upper bounds needs to have same number of values");
number_axes == end_number_axes,
"Lower bounds and Upper bounds need to have same number of values");
} else if (end_number_axes != -1) {
number_axes = end_number_axes;
}
if (begin.size() && strides.size()) {
auto strides_number_axes = number_elements_in_1d(op, strides_shape);
if (number_axes != -1 && strides_number_axes != -1) {
NODE_VALIDATION_CHECK(op,
begin.size() == strides.size(),
"Lower bounds and strides needs to have same number of values");
}
if (end.size() && strides.size()) {
NODE_VALIDATION_CHECK(op,
end.size() == strides.size(),
"Upper bounds and strides needs to have same number of values");
number_axes == strides_number_axes,
"Stride needs to have same number of values as Lower and Upper bounds");
} else if (strides_number_axes != -1) {
number_axes = strides_number_axes;
}
// if number of axes is undefined we cannot say about output rank
if (number_axes < 0) {
output_shapes[0] = ov::PartialShape::dynamic();
return;
}
// collect indices of axes by which the shape needs to be changed
auto convert_mask_to_axis_set = [](const std::vector<int64_t>& mask) {
AxisSet axis_set{};
for (size_t i = 0; i < mask.size(); ++i) {
@ -78,32 +113,20 @@ void shape_infer(const StridedSlice* op,
}
return axis_set;
};
AxisSet ellipsis_mask = convert_mask_to_axis_set(op->get_ellipsis_mask());
NODE_VALIDATION_CHECK(op, ellipsis_mask.size() <= 1, "At most one ellipsis is allowed.");
if (input_shape.rank().is_dynamic()) {
output_shapes[0] = ov::PartialShape::dynamic();
return;
}
auto input_rank = input_shape.size();
AxisSet new_axis_mask = convert_mask_to_axis_set(op->get_new_axis_mask());
NODE_VALIDATION_CHECK(op,
input_rank + new_axis_mask.size() >= begin.size(),
"Input rank plus number of new axis has to be at least the size of Lower "
"and Upper bounds vector.");
AxisSet begin_mask = convert_mask_to_axis_set(op->get_begin_mask());
AxisSet end_mask = convert_mask_to_axis_set(op->get_end_mask());
AxisSet shrink_axis_mask = convert_mask_to_axis_set(op->get_shrink_axis_mask());
NODE_VALIDATION_CHECK(op,
input_rank + new_axis_mask.size() >= static_cast<size_t>(number_axes),
"Input rank plus number of new axis has to be at least the size of Lower "
"and Upper bounds vector.");
std::vector<DimType> dim;
std::vector<DimType> dims;
int64_t input_shape_idx = 0;
for (size_t axis = 0; axis < begin.size(); ++axis) {
for (int64_t axis = 0; axis < number_axes; ++axis) {
// add all dimensions hidden under the ellipsis mask if ellipsis mask is set
if (ellipsis_mask.count(axis)) {
// only one bit in ellipsis mask is allowed
@ -121,24 +144,23 @@ void shape_infer(const StridedSlice* op,
}
int64_t num_input_axis_after_ellipses =
(begin.size() - axis - num_new_axis_after_ellipses - 1); // -1 because it's a position of ellipses
int64_t num_of_hidden_dims =
input_rank - num_input_axis_after_ellipses - num_input_axis_before_ellipses;
(number_axes - axis - num_new_axis_after_ellipses - 1); // -1 because it's a position of ellipses
int64_t num_of_hidden_dims = input_rank - num_input_axis_after_ellipses - num_input_axis_before_ellipses;
for (int64_t i = 0; i < num_of_hidden_dims; ++i) {
dim.emplace_back(input_shape[input_shape_idx]);
dims.emplace_back(input_shape[input_shape_idx]);
input_shape_idx++;
}
} else {
// add new single dimension if new_axis_mask is set
if (new_axis_mask.count(axis)) {
dim.emplace_back(1);
dims.emplace_back(1);
}
// skip this dimension if shrink_axis_mask is set
else if (shrink_axis_mask.count(axis)) {
input_shape_idx++;
}
// calculating dimension (begin, end, begin_mask, end_mask, stride)
else {
else if (got_begin && got_end && got_strides) {
const int64_t lb0 = begin[axis];
const int64_t ub0 = end[axis];
// set default value for stride or use given value
@ -211,25 +233,32 @@ void shape_infer(const StridedSlice* op,
else
odim_max = -1;
dim.emplace_back(ov::Dimension(odim_min, odim_max));
dims.emplace_back(ov::Dimension(odim_min, odim_max));
} else {
int64_t dimension = get_output_dim(input_shape[input_shape_idx].get_length());
dim.emplace_back(dimension);
dims.emplace_back(dimension);
}
input_shape_idx++;
} else {
if (input_shape[input_shape_idx].is_static()) {
auto dim_value = input_shape[input_shape_idx].get_length();
dims.emplace_back(ov::Dimension(0, dim_value));
} else {
dims.emplace_back(input_shape[input_shape_idx]);
}
input_shape_idx++;
}
}
}
// get remaining values
for (; input_shape_idx < input_shape.rank().get_length(); ++input_shape_idx) {
dim.emplace_back(input_shape[input_shape_idx]);
dims.emplace_back(input_shape[input_shape_idx]);
}
output_shapes[0] = T(dim);
} else {
output_shapes[0] = ov::PartialShape::dynamic(input_shape.rank());
}
output_shapes[0] = T(dims);
}
} // namespace v1
} // namespace op

View File

@ -8,6 +8,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "openvino/opsets/opset9.hpp"
#include "util/type_prop.hpp"
using namespace std;
@ -193,14 +194,14 @@ TEST(type_prop, strided_slice_dynamic_value_and_label_propagation) {
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
}
TEST(type_prop, default_strided_slice_shape_inference) {
TEST(type_prop, strided_slice_default_shape_inference) {
auto slice = new op::v1::StridedSlice;
slice->set_begin_mask({0, 0, 0});
slice->set_end_mask({0, 0, 0});
slice->set_new_axis_mask({1, 0, 0});
slice->set_shrink_axis_mask({0, 0, 0, 1});
slice->set_ellipsis_mask_mask({0, 0, 0});
std::vector<ov::PartialShape> in = {{10, 11, 12}, {3}, {3}, {3}}, out = {PartialShape()};
std::vector<ov::PartialShape> in = {{10, 11, 12}, {4}, {4}, {4}}, out = {PartialShape()};
int64_t begin_data[] = {0, 0, 0, 0}, end_data[] = {1, 1, 5, 1}, stride_data[] = {1, 1, 1, 1};
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_data = {
{1, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, begin_data)},
@ -209,3 +210,99 @@ TEST(type_prop, default_strided_slice_shape_inference) {
ov::op::v1::shape_infer(slice, in, out, const_data);
ASSERT_EQ(out[0], PartialShape({1, 1, 5}));
}
struct StridedSliceTestParams {
PartialShape input_shape;
PartialShape begin_shape;
PartialShape end_shape;
PartialShape strides_shape;
std::vector<int64_t> begin_mask;
std::vector<int64_t> end_mask;
std::vector<int64_t> new_axis_mask;
std::vector<int64_t> shrink_axis_mask;
std::vector<int64_t> ellipsis_mask;
PartialShape ref_shape;
ov::element::Type ref_type;
};
struct StridedSliceShapeInferTest : ::testing::TestWithParam<StridedSliceTestParams> {};
TEST_P(StridedSliceShapeInferTest, non_const_begin) {
auto params = GetParam();
auto input_data = std::make_shared<ov::opset9::Parameter>(ov::element::f32, params.input_shape);
auto begin = std::make_shared<ov::opset9::Parameter>(ov::element::i32, params.begin_shape);
auto end = std::make_shared<ov::opset9::Parameter>(ov::element::i32, params.end_shape);
auto strides = std::make_shared<ov::opset9::Parameter>(ov::element::i32, params.strides_shape);
std::vector<int64_t> begin_mask = params.begin_mask;
std::vector<int64_t> end_mask = params.end_mask;
std::vector<int64_t> new_axis_mask = params.new_axis_mask;
std::vector<int64_t> shrink_axis_mask = params.shrink_axis_mask;
std::vector<int64_t> ellipsis_mask = params.ellipsis_mask;
auto strided_slice = std::make_shared<ov::opset9::StridedSlice>(input_data,
begin,
end,
strides,
begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
ellipsis_mask);
EXPECT_EQ(strided_slice->get_element_type(), params.ref_type);
auto res_shape = strided_slice->get_output_partial_shape(0);
EXPECT_EQ(res_shape, params.ref_shape);
}
INSTANTIATE_TEST_SUITE_P(type_prop,
StridedSliceShapeInferTest,
::testing::Values(StridedSliceTestParams{{1, 200, 300, 3}, // input_shape
{4}, // begin shape
{4}, // end shape
{4}, // strides shape
{0, 0, 0, 0}, // begin mask
{0, 0, 0, 0}, // end mask
{0, 0, 0, 0}, // new axis mask
{0, 1, 0, 0}, // shrink axis mask
{0, 0, 0, 0}, // ellipsis mask
PartialShape{
Dimension(0, 1),
Dimension(0, 300),
Dimension(0, 3),
}, // reference shape
element::f32}, // reference type
StridedSliceTestParams{{1, 200, 300, 3}, // input_shape
{4}, // begin shape
{4}, // end shape
{4}, // strides shape
{1, 0, 0, 0}, // begin mask
{1, 0, 0, 0}, // end mask
{0, 0, 0, 0}, // new axis mask
{0, 1, 0, 1}, // shrink axis mask
{0, 0, 0, 0}, // ellipsis mask
PartialShape{
Dimension(0, 1),
Dimension(0, 300),
}, // reference shape
element::f32}, // reference type
StridedSliceTestParams{{1, 200, 200, 3}, // input_shape
{4}, // begin shape
{4}, // end shape
{4}, // strides shape
{1, 0, 0, 0}, // begin mask
{1, 0, 0, 0}, // end mask
{0, 0, 1, 0}, // new axis mask
{0, 0, 0, 0}, // shrink axis mask
{0, 0, 0, 0}, // ellipsis mask
PartialShape{
Dimension(0, 1),
Dimension(0, 200),
1,
Dimension(0, 200),
3,
}, // reference shape
element::f32} // reference type
),
PrintToDummyParamName());