Don't execute reference::strided_slice if input/output tensor is empty (#11337)

This commit is contained in:
Maxim Andronov 2022-03-31 15:42:10 +03:00 committed by GitHub
parent 9185f03e77
commit 1d247815be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 0 deletions

View File

@ -368,6 +368,20 @@ std::vector<StridedSliceStrideOptionalParams> generateStrideOptionalParams() {
std::vector<int64_t>{0, 0, 0},
reference_tests::Tensor(IN_ET, {1, 4}, std::vector<T>{20, 21, 22, 23}),
"strided_slice_stride_optional_dynamic"),
// strided_slice_stride_optional_dynamic_empty_output_tensor
StridedSliceStrideOptionalParams(
PartialShape::dynamic(),
reference_tests::Tensor(IN_ET, {2, 3, 4}, std::vector<T>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}),
reference_tests::Tensor(element::i64, {2}, std::vector<int64_t>{0, 0}),
reference_tests::Tensor(element::i64, {2}, std::vector<int64_t>{-1, 0}),
std::vector<int64_t>{1, 0},
std::vector<int64_t>{1, 0},
std::vector<int64_t>{},
std::vector<int64_t>{},
std::vector<int64_t>{},
reference_tests::Tensor(IN_ET, {2, 0, 4}, std::vector<T>{}),
"strided_slice_stride_optional_dynamic_empty_output_tensor"),
};
return params;
}

View File

@ -18,6 +18,15 @@ void runtime::reference::strided_slice(const char* arg,
const Shape& arg_shape,
const SlicePlan& sp,
size_t elem_type) {
auto hasZeroDims = [](const ov::Shape& shape) -> bool {
return std::any_of(shape.begin(), shape.end(), [](const size_t& dim) {
return dim == 0;
});
};
if (hasZeroDims(sp.reshape_in_shape) || hasZeroDims(sp.reshape_out_shape)) {
return;
}
runtime::AlignedBuffer slice_out_buffer(shape_size(sp.reshape_in_shape) * elem_type);
slice(reinterpret_cast<const char*>(arg),
slice_out_buffer.get_ptr<char>(),