Fix check repeats in values in Tile (#20654)

- no action if any of repeats is zero

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Pawel Raasz 2023-10-25 05:27:21 +02:00 committed by GitHub
parent 6fa4f9fd78
commit 84732515b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 1 deletions

View File

@ -26,7 +26,7 @@ void tile(const char* arg,
const Shape& out_shape,
const size_t elem_size,
const std::vector<int64_t>& repeats) {
if (std::all_of(repeats.begin(), repeats.end(), [](int64_t repeat) {
if (std::any_of(repeats.begin(), repeats.end(), [](int64_t repeat) {
return repeat == 0;
})) {
return;

View File

@ -102,6 +102,14 @@ std::vector<TileParams> generateParams() {
reference_tests::Tensor(ET_INT, {2}, std::vector<T_INT>{2, 1}),
reference_tests::Tensor(ET, {2, 2, 3}, std::vector<T>{1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}),
"tile_3d_to_3d_repeats_broadcast"),
TileParams(reference_tests::Tensor(ET, {1}, std::vector<T>{1}),
reference_tests::Tensor(ET_INT, {3}, std::vector<T_INT>{0, 2, 3}),
reference_tests::Tensor(ET, {0}, std::vector<T>{}),
"tile_1d_to_3d_with_zero_on_axis_0"),
TileParams(reference_tests::Tensor(ET, {3}, std::vector<T>{1, 2, 3}),
reference_tests::Tensor(ET_INT, {3}, std::vector<T_INT>{2, 0, 3}),
reference_tests::Tensor(ET, {0}, std::vector<T>{}),
"tile_1d_to_3d_with_zero_on_axis_1"),
};
return params;
}