[nGraph] Fix bound check for reference GatherElements (#3981)
* fix bound check for reference GatherElements * apply review comments
This commit is contained in:
parent
647104f602
commit
efa5b6063a
@ -49,7 +49,7 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
for (int64_t i = 0; i < indices_shape[0]; i++)
|
for (int64_t i = 0; i < indices_shape[0]; i++)
|
||||||
{
|
{
|
||||||
if (indices[i] > data_shape[0])
|
if (indices[i] >= data_shape[0])
|
||||||
{
|
{
|
||||||
throw std::domain_error{
|
throw std::domain_error{
|
||||||
"indices values of GatherElement exceed data size"};
|
"indices values of GatherElement exceed data size"};
|
||||||
@ -73,7 +73,7 @@ namespace ngraph
|
|||||||
for (int64_t j = 0; j < num_columns; j++)
|
for (int64_t j = 0; j < num_columns; j++)
|
||||||
{
|
{
|
||||||
idx = indices[num_columns * i + j];
|
idx = indices[num_columns * i + j];
|
||||||
if (idx < 0 || idx > data_shape[0] - 1)
|
if (idx < 0 || idx >= data_shape[0])
|
||||||
{
|
{
|
||||||
throw std::domain_error{
|
throw std::domain_error{
|
||||||
"indices values of GatherElement exceed data size"};
|
"indices values of GatherElement exceed data size"};
|
||||||
@ -88,7 +88,7 @@ namespace ngraph
|
|||||||
for (int64_t j = 0; j < num_columns; j++)
|
for (int64_t j = 0; j < num_columns; j++)
|
||||||
{
|
{
|
||||||
idx = indices[num_columns * i + j];
|
idx = indices[num_columns * i + j];
|
||||||
if (idx < 0 || idx > data_shape[1] - 1)
|
if (idx < 0 || idx >= data_shape[1])
|
||||||
{
|
{
|
||||||
throw std::domain_error{
|
throw std::domain_error{
|
||||||
"indices values of GatherElement exceed data size"};
|
"indices values of GatherElement exceed data size"};
|
||||||
@ -135,7 +135,7 @@ namespace ngraph
|
|||||||
for (size_t k = 0; k < indices_shape[axis]; k++)
|
for (size_t k = 0; k < indices_shape[axis]; k++)
|
||||||
for (size_t inner_sum = 0; inner_sum < max_inner_sum; inner_sum++)
|
for (size_t inner_sum = 0; inner_sum < max_inner_sum; inner_sum++)
|
||||||
{
|
{
|
||||||
if (indices[i] < 0 || indices[i] > data_shape[axis] - 1)
|
if (indices[i] < 0 || indices[i] >= data_shape[axis])
|
||||||
{
|
{
|
||||||
throw std::domain_error{
|
throw std::domain_error{
|
||||||
"indices values of GatherElement exceed data size"};
|
"indices values of GatherElement exceed data size"};
|
||||||
|
Loading…
Reference in New Issue
Block a user