[MO] Fix IndexError inside ScatterNDUpdate shape inference function (#11220)

* Restore inputs order in IR Reader

* Add WA to numpy ndarrays IndexError

* Add comments to code

* Add unit test
This commit is contained in:
Anton Chetverikov
2022-04-04 23:59:24 +03:00
committed by GitHub
parent c75bc65b83
commit c1dc71ce28
2 changed files with 27 additions and 1 deletions

View File

@@ -91,7 +91,11 @@ class ScatterNDUpdate(ScatterNDBase):
# a case when updates is a scalar
indx = 0
updates_value = [updates_value]
output_value[indices_value[indx]] = updates_value[indx]
insert_index = indices_value[indx]
# we check and change index type explicitly to avoid error in indexing ndarray by another ndarray
if isinstance(insert_index, np.ndarray):
insert_index = tuple(insert_index)
output_value[insert_index] = updates_value[indx]
node.out_port(0).data.set_value(output_value)

View File

@@ -67,6 +67,17 @@ inputs8 = {'input': {'shape': int64_array([3]), 'value': int64_array([1, 2, 3])}
'updates': {'shape': int64_array([1]), 'value': int64_array([9])}}
output8 = int64_array([1, 2, 9])
inputs9 = {'input': {'shape': int64_array([1, 5, 5, 1]), 'value': np.zeros([1, 5, 5, 1],dtype=np.int32)},
'indices': {'shape': int64_array([1, 2, 2, 1, 4]),
'value': np.array([[[[[0, 0, 0, 0]], [[0, 0, 1, 0]]], [[[0, 2, 1, 0]], [[0, 3, 4, 0]]]]])},
'updates': {'shape': int64_array([1, 2, 2, 1]), 'value': np.ones([1, 2, 2, 1])}}
output9 = np.array([[[[1], [1], [0], [0], [0]], # shape [1, 5, 5, 1]
[[0], [0], [0], [0], [0]],
[[0], [1], [0], [0], [0]],
[[0], [0], [0], [0], [1]],
[[0], [0], [0], [0], [0]]]])
class TestScatterNDUpdate(unittest.TestCase):
def test_partial_infer1(self):
graph = build_graph(nodes_attributes, edges, inputs1)
@@ -167,3 +178,14 @@ class TestScatterNDUpdate(unittest.TestCase):
self.assertTrue(np.array_equal(output8, res_output_value),
'values do not match expected: {} and given: {}'.format(output8, res_output_value))
def test_infer9(self):
graph = build_graph(nodes_attributes, edges, inputs9)
scatternd_node = Node(graph, 'scatternd_node')
ScatterNDUpdate.infer(scatternd_node)
# get the result
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output9, res_output_value),
'values do not match expected: {} and given: {}'.format(output8, res_output_value))