[Core] Correct shape inference for StridedSlice with non-const begin (#13581)
* [Core] Correct shape inference for StridedSlice with non-constant begin, end, and strides Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix build issue * Fix build issue Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
4188f1f181
commit
be1b72d1e9
@ -36,14 +36,40 @@ void shape_infer(const StridedSlice* op,
|
||||
end_shape.rank(),
|
||||
").");
|
||||
|
||||
const auto& strides_shape = input_shapes[3];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
strides_shape.rank().compatible(1),
|
||||
"End input must be 1D (end rank: ",
|
||||
strides_shape.rank(),
|
||||
").");
|
||||
|
||||
// it is not possible to define output shape if input data shape rank is undefined
|
||||
// even the lengths of begin, end, or strides are defined
|
||||
if (input_shape.rank().is_dynamic()) {
|
||||
output_shapes[0] = ov::PartialShape::dynamic();
|
||||
return;
|
||||
}
|
||||
auto input_rank = input_shape.size();
|
||||
|
||||
std::vector<int64_t> begin;
|
||||
std::vector<int64_t> end;
|
||||
std::vector<int64_t> strides;
|
||||
|
||||
auto number_elements_in_1d = [](const StridedSlice* op, const T& shape_1d) -> int64_t {
|
||||
auto rank_1d = shape_1d.rank();
|
||||
if (rank_1d.is_static()) {
|
||||
NODE_VALIDATION_CHECK(op, rank_1d.get_length() == 1, "Only 1D tensor is allowed.");
|
||||
if (shape_1d[0].is_static()) {
|
||||
return static_cast<int64_t>(shape_1d[0].get_length());
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
};
|
||||
|
||||
// compute constant values of begin, end, and strides if possible
|
||||
bool got_begin = get_data_as_int64<T>(1, op, begin, constant_data);
|
||||
bool got_end = get_data_as_int64<T>(2, op, end, constant_data);
|
||||
bool got_strides = false;
|
||||
|
||||
if (input_shapes.size() > 3) {
|
||||
got_strides = get_data_as_int64<T>(3, op, strides, constant_data);
|
||||
} else if (got_begin) {
|
||||
@ -52,23 +78,32 @@ void shape_infer(const StridedSlice* op,
|
||||
got_strides = true;
|
||||
}
|
||||
|
||||
if (got_begin && got_end && got_strides) {
|
||||
if (begin.size() && end.size()) {
|
||||
// compute and check a number of axes for which begin, end, and strides are defined
|
||||
auto number_axes = number_elements_in_1d(op, begin_shape);
|
||||
auto end_number_axes = number_elements_in_1d(op, end_shape);
|
||||
if (number_axes != -1 && end_number_axes != -1) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
begin.size() == end.size(),
|
||||
"Lower bounds and Upper bounds needs to have same number of values");
|
||||
number_axes == end_number_axes,
|
||||
"Lower bounds and Upper bounds need to have same number of values");
|
||||
} else if (end_number_axes != -1) {
|
||||
number_axes = end_number_axes;
|
||||
}
|
||||
if (begin.size() && strides.size()) {
|
||||
auto strides_number_axes = number_elements_in_1d(op, strides_shape);
|
||||
if (number_axes != -1 && strides_number_axes != -1) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
begin.size() == strides.size(),
|
||||
"Lower bounds and strides needs to have same number of values");
|
||||
}
|
||||
if (end.size() && strides.size()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
end.size() == strides.size(),
|
||||
"Upper bounds and strides needs to have same number of values");
|
||||
number_axes == strides_number_axes,
|
||||
"Stride needs to have same number of values as Lower and Upper bounds");
|
||||
} else if (strides_number_axes != -1) {
|
||||
number_axes = strides_number_axes;
|
||||
}
|
||||
|
||||
// if number of axes is undefined we cannot say about output rank
|
||||
if (number_axes < 0) {
|
||||
output_shapes[0] = ov::PartialShape::dynamic();
|
||||
return;
|
||||
}
|
||||
|
||||
// collect indices of axes by which the shape needs to be changed
|
||||
auto convert_mask_to_axis_set = [](const std::vector<int64_t>& mask) {
|
||||
AxisSet axis_set{};
|
||||
for (size_t i = 0; i < mask.size(); ++i) {
|
||||
@ -78,32 +113,20 @@ void shape_infer(const StridedSlice* op,
|
||||
}
|
||||
return axis_set;
|
||||
};
|
||||
|
||||
AxisSet ellipsis_mask = convert_mask_to_axis_set(op->get_ellipsis_mask());
|
||||
NODE_VALIDATION_CHECK(op, ellipsis_mask.size() <= 1, "At most one ellipsis is allowed.");
|
||||
|
||||
if (input_shape.rank().is_dynamic()) {
|
||||
output_shapes[0] = ov::PartialShape::dynamic();
|
||||
return;
|
||||
}
|
||||
|
||||
auto input_rank = input_shape.size();
|
||||
|
||||
AxisSet new_axis_mask = convert_mask_to_axis_set(op->get_new_axis_mask());
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_rank + new_axis_mask.size() >= begin.size(),
|
||||
"Input rank plus number of new axis has to be at least the size of Lower "
|
||||
"and Upper bounds vector.");
|
||||
|
||||
AxisSet begin_mask = convert_mask_to_axis_set(op->get_begin_mask());
|
||||
AxisSet end_mask = convert_mask_to_axis_set(op->get_end_mask());
|
||||
AxisSet shrink_axis_mask = convert_mask_to_axis_set(op->get_shrink_axis_mask());
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_rank + new_axis_mask.size() >= static_cast<size_t>(number_axes),
|
||||
"Input rank plus number of new axis has to be at least the size of Lower "
|
||||
"and Upper bounds vector.");
|
||||
|
||||
std::vector<DimType> dim;
|
||||
|
||||
std::vector<DimType> dims;
|
||||
int64_t input_shape_idx = 0;
|
||||
for (size_t axis = 0; axis < begin.size(); ++axis) {
|
||||
for (int64_t axis = 0; axis < number_axes; ++axis) {
|
||||
// add all dimensions hidden under the ellipsis mask if ellipsis mask is set
|
||||
if (ellipsis_mask.count(axis)) {
|
||||
// only one bit in ellipsis mask is allowed
|
||||
@ -121,24 +144,23 @@ void shape_infer(const StridedSlice* op,
|
||||
}
|
||||
|
||||
int64_t num_input_axis_after_ellipses =
|
||||
(begin.size() - axis - num_new_axis_after_ellipses - 1); // -1 because it's a position of ellipses
|
||||
int64_t num_of_hidden_dims =
|
||||
input_rank - num_input_axis_after_ellipses - num_input_axis_before_ellipses;
|
||||
(number_axes - axis - num_new_axis_after_ellipses - 1); // -1 because it's a position of ellipses
|
||||
int64_t num_of_hidden_dims = input_rank - num_input_axis_after_ellipses - num_input_axis_before_ellipses;
|
||||
for (int64_t i = 0; i < num_of_hidden_dims; ++i) {
|
||||
dim.emplace_back(input_shape[input_shape_idx]);
|
||||
dims.emplace_back(input_shape[input_shape_idx]);
|
||||
input_shape_idx++;
|
||||
}
|
||||
} else {
|
||||
// add new single dimension if new_axis_mask is set
|
||||
if (new_axis_mask.count(axis)) {
|
||||
dim.emplace_back(1);
|
||||
dims.emplace_back(1);
|
||||
}
|
||||
// skip this dimension if shrink_axis_mask is set
|
||||
else if (shrink_axis_mask.count(axis)) {
|
||||
input_shape_idx++;
|
||||
}
|
||||
// calculating dimension (begin, end, begin_mask, end_mask, stride)
|
||||
else {
|
||||
else if (got_begin && got_end && got_strides) {
|
||||
const int64_t lb0 = begin[axis];
|
||||
const int64_t ub0 = end[axis];
|
||||
// set default value for stride or use given value
|
||||
@ -211,25 +233,32 @@ void shape_infer(const StridedSlice* op,
|
||||
else
|
||||
odim_max = -1;
|
||||
|
||||
dim.emplace_back(ov::Dimension(odim_min, odim_max));
|
||||
dims.emplace_back(ov::Dimension(odim_min, odim_max));
|
||||
} else {
|
||||
int64_t dimension = get_output_dim(input_shape[input_shape_idx].get_length());
|
||||
dim.emplace_back(dimension);
|
||||
dims.emplace_back(dimension);
|
||||
}
|
||||
|
||||
input_shape_idx++;
|
||||
} else {
|
||||
if (input_shape[input_shape_idx].is_static()) {
|
||||
auto dim_value = input_shape[input_shape_idx].get_length();
|
||||
dims.emplace_back(ov::Dimension(0, dim_value));
|
||||
} else {
|
||||
dims.emplace_back(input_shape[input_shape_idx]);
|
||||
}
|
||||
|
||||
input_shape_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get remaining values
|
||||
for (; input_shape_idx < input_shape.rank().get_length(); ++input_shape_idx) {
|
||||
dim.emplace_back(input_shape[input_shape_idx]);
|
||||
dims.emplace_back(input_shape[input_shape_idx]);
|
||||
}
|
||||
|
||||
output_shapes[0] = T(dim);
|
||||
} else {
|
||||
output_shapes[0] = ov::PartialShape::dynamic(input_shape.rank());
|
||||
}
|
||||
output_shapes[0] = T(dims);
|
||||
}
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -193,14 +194,14 @@ TEST(type_prop, strided_slice_dynamic_value_and_label_propagation) {
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
|
||||
}
|
||||
|
||||
TEST(type_prop, default_strided_slice_shape_inference) {
|
||||
TEST(type_prop, strided_slice_default_shape_inference) {
|
||||
auto slice = new op::v1::StridedSlice;
|
||||
slice->set_begin_mask({0, 0, 0});
|
||||
slice->set_end_mask({0, 0, 0});
|
||||
slice->set_new_axis_mask({1, 0, 0});
|
||||
slice->set_shrink_axis_mask({0, 0, 0, 1});
|
||||
slice->set_ellipsis_mask_mask({0, 0, 0});
|
||||
std::vector<ov::PartialShape> in = {{10, 11, 12}, {3}, {3}, {3}}, out = {PartialShape()};
|
||||
std::vector<ov::PartialShape> in = {{10, 11, 12}, {4}, {4}, {4}}, out = {PartialShape()};
|
||||
int64_t begin_data[] = {0, 0, 0, 0}, end_data[] = {1, 1, 5, 1}, stride_data[] = {1, 1, 1, 1};
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_data = {
|
||||
{1, std::make_shared<ov::HostTensor>(element::i64, Shape{4}, begin_data)},
|
||||
@ -209,3 +210,99 @@ TEST(type_prop, default_strided_slice_shape_inference) {
|
||||
ov::op::v1::shape_infer(slice, in, out, const_data);
|
||||
ASSERT_EQ(out[0], PartialShape({1, 1, 5}));
|
||||
}
|
||||
|
||||
struct StridedSliceTestParams {
|
||||
PartialShape input_shape;
|
||||
PartialShape begin_shape;
|
||||
PartialShape end_shape;
|
||||
PartialShape strides_shape;
|
||||
std::vector<int64_t> begin_mask;
|
||||
std::vector<int64_t> end_mask;
|
||||
std::vector<int64_t> new_axis_mask;
|
||||
std::vector<int64_t> shrink_axis_mask;
|
||||
std::vector<int64_t> ellipsis_mask;
|
||||
|
||||
PartialShape ref_shape;
|
||||
ov::element::Type ref_type;
|
||||
};
|
||||
|
||||
struct StridedSliceShapeInferTest : ::testing::TestWithParam<StridedSliceTestParams> {};
|
||||
|
||||
TEST_P(StridedSliceShapeInferTest, non_const_begin) {
|
||||
auto params = GetParam();
|
||||
|
||||
auto input_data = std::make_shared<ov::opset9::Parameter>(ov::element::f32, params.input_shape);
|
||||
auto begin = std::make_shared<ov::opset9::Parameter>(ov::element::i32, params.begin_shape);
|
||||
auto end = std::make_shared<ov::opset9::Parameter>(ov::element::i32, params.end_shape);
|
||||
auto strides = std::make_shared<ov::opset9::Parameter>(ov::element::i32, params.strides_shape);
|
||||
std::vector<int64_t> begin_mask = params.begin_mask;
|
||||
std::vector<int64_t> end_mask = params.end_mask;
|
||||
std::vector<int64_t> new_axis_mask = params.new_axis_mask;
|
||||
std::vector<int64_t> shrink_axis_mask = params.shrink_axis_mask;
|
||||
std::vector<int64_t> ellipsis_mask = params.ellipsis_mask;
|
||||
|
||||
auto strided_slice = std::make_shared<ov::opset9::StridedSlice>(input_data,
|
||||
begin,
|
||||
end,
|
||||
strides,
|
||||
begin_mask,
|
||||
end_mask,
|
||||
new_axis_mask,
|
||||
shrink_axis_mask,
|
||||
ellipsis_mask);
|
||||
|
||||
EXPECT_EQ(strided_slice->get_element_type(), params.ref_type);
|
||||
auto res_shape = strided_slice->get_output_partial_shape(0);
|
||||
EXPECT_EQ(res_shape, params.ref_shape);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(type_prop,
|
||||
StridedSliceShapeInferTest,
|
||||
::testing::Values(StridedSliceTestParams{{1, 200, 300, 3}, // input_shape
|
||||
{4}, // begin shape
|
||||
{4}, // end shape
|
||||
{4}, // strides shape
|
||||
{0, 0, 0, 0}, // begin mask
|
||||
{0, 0, 0, 0}, // end mask
|
||||
{0, 0, 0, 0}, // new axis mask
|
||||
{0, 1, 0, 0}, // shrink axis mask
|
||||
{0, 0, 0, 0}, // ellipsis mask
|
||||
PartialShape{
|
||||
Dimension(0, 1),
|
||||
Dimension(0, 300),
|
||||
Dimension(0, 3),
|
||||
}, // reference shape
|
||||
element::f32}, // reference type
|
||||
StridedSliceTestParams{{1, 200, 300, 3}, // input_shape
|
||||
{4}, // begin shape
|
||||
{4}, // end shape
|
||||
{4}, // strides shape
|
||||
{1, 0, 0, 0}, // begin mask
|
||||
{1, 0, 0, 0}, // end mask
|
||||
{0, 0, 0, 0}, // new axis mask
|
||||
{0, 1, 0, 1}, // shrink axis mask
|
||||
{0, 0, 0, 0}, // ellipsis mask
|
||||
PartialShape{
|
||||
Dimension(0, 1),
|
||||
Dimension(0, 300),
|
||||
}, // reference shape
|
||||
element::f32}, // reference type
|
||||
StridedSliceTestParams{{1, 200, 200, 3}, // input_shape
|
||||
{4}, // begin shape
|
||||
{4}, // end shape
|
||||
{4}, // strides shape
|
||||
{1, 0, 0, 0}, // begin mask
|
||||
{1, 0, 0, 0}, // end mask
|
||||
{0, 0, 1, 0}, // new axis mask
|
||||
{0, 0, 0, 0}, // shrink axis mask
|
||||
{0, 0, 0, 0}, // ellipsis mask
|
||||
PartialShape{
|
||||
Dimension(0, 1),
|
||||
Dimension(0, 200),
|
||||
1,
|
||||
Dimension(0, 200),
|
||||
3,
|
||||
}, // reference shape
|
||||
element::f32} // reference type
|
||||
),
|
||||
PrintToDummyParamName());
|
||||
|
Loading…
Reference in New Issue
Block a user