[nGraph] Fix bound check for reference GatherElements (#3981)

* fix bound check for reference GatherElements

* apply review comments
This commit is contained in:
Pavel Esir 2021-01-25 15:10:52 +03:00 committed by GitHub
parent 647104f602
commit efa5b6063a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -49,7 +49,7 @@ namespace ngraph
{ {
for (int64_t i = 0; i < indices_shape[0]; i++) for (int64_t i = 0; i < indices_shape[0]; i++)
{ {
if (indices[i] > data_shape[0]) if (indices[i] >= data_shape[0])
{ {
throw std::domain_error{ throw std::domain_error{
"indices values of GatherElement exceed data size"}; "indices values of GatherElement exceed data size"};
@ -73,7 +73,7 @@ namespace ngraph
for (int64_t j = 0; j < num_columns; j++) for (int64_t j = 0; j < num_columns; j++)
{ {
idx = indices[num_columns * i + j]; idx = indices[num_columns * i + j];
if (idx < 0 || idx > data_shape[0] - 1) if (idx < 0 || idx >= data_shape[0])
{ {
throw std::domain_error{ throw std::domain_error{
"indices values of GatherElement exceed data size"}; "indices values of GatherElement exceed data size"};
@ -88,7 +88,7 @@ namespace ngraph
for (int64_t j = 0; j < num_columns; j++) for (int64_t j = 0; j < num_columns; j++)
{ {
idx = indices[num_columns * i + j]; idx = indices[num_columns * i + j];
if (idx < 0 || idx > data_shape[1] - 1) if (idx < 0 || idx >= data_shape[1])
{ {
throw std::domain_error{ throw std::domain_error{
"indices values of GatherElement exceed data size"}; "indices values of GatherElement exceed data size"};
@ -135,7 +135,7 @@ namespace ngraph
for (size_t k = 0; k < indices_shape[axis]; k++) for (size_t k = 0; k < indices_shape[axis]; k++)
for (size_t inner_sum = 0; inner_sum < max_inner_sum; inner_sum++) for (size_t inner_sum = 0; inner_sum < max_inner_sum; inner_sum++)
{ {
if (indices[i] < 0 || indices[i] > data_shape[axis] - 1) if (indices[i] < 0 || indices[i] >= data_shape[axis])
{ {
throw std::domain_error{ throw std::domain_error{
"indices values of GatherElement exceed data size"}; "indices values of GatherElement exceed data size"};