[MO] Fix IndexError inside ScatterNDUpdate shape inference function (#11220)
* Restore inputs order in IR Reader * Add WA to numpy ndarrays IndexError * Add comments to code * Add unit test
This commit is contained in:
committed by
GitHub
parent
c75bc65b83
commit
c1dc71ce28
@@ -91,7 +91,11 @@ class ScatterNDUpdate(ScatterNDBase):
|
||||
# a case when updates is a scalar
|
||||
indx = 0
|
||||
updates_value = [updates_value]
|
||||
output_value[indices_value[indx]] = updates_value[indx]
|
||||
insert_index = indices_value[indx]
|
||||
# we check and change index type explicitly to avoid error in indexing ndarray by another ndarray
|
||||
if isinstance(insert_index, np.ndarray):
|
||||
insert_index = tuple(insert_index)
|
||||
output_value[insert_index] = updates_value[indx]
|
||||
|
||||
node.out_port(0).data.set_value(output_value)
|
||||
|
||||
|
||||
@@ -67,6 +67,17 @@ inputs8 = {'input': {'shape': int64_array([3]), 'value': int64_array([1, 2, 3])}
|
||||
'updates': {'shape': int64_array([1]), 'value': int64_array([9])}}
|
||||
output8 = int64_array([1, 2, 9])
|
||||
|
||||
inputs9 = {'input': {'shape': int64_array([1, 5, 5, 1]), 'value': np.zeros([1, 5, 5, 1],dtype=np.int32)},
|
||||
'indices': {'shape': int64_array([1, 2, 2, 1, 4]),
|
||||
'value': np.array([[[[[0, 0, 0, 0]], [[0, 0, 1, 0]]], [[[0, 2, 1, 0]], [[0, 3, 4, 0]]]]])},
|
||||
'updates': {'shape': int64_array([1, 2, 2, 1]), 'value': np.ones([1, 2, 2, 1])}}
|
||||
|
||||
output9 = np.array([[[[1], [1], [0], [0], [0]], # shape [1, 5, 5, 1]
|
||||
[[0], [0], [0], [0], [0]],
|
||||
[[0], [1], [0], [0], [0]],
|
||||
[[0], [0], [0], [0], [1]],
|
||||
[[0], [0], [0], [0], [0]]]])
|
||||
|
||||
class TestScatterNDUpdate(unittest.TestCase):
|
||||
def test_partial_infer1(self):
|
||||
graph = build_graph(nodes_attributes, edges, inputs1)
|
||||
@@ -167,3 +178,14 @@ class TestScatterNDUpdate(unittest.TestCase):
|
||||
|
||||
self.assertTrue(np.array_equal(output8, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output8, res_output_value))
|
||||
|
||||
def test_infer9(self):
|
||||
graph = build_graph(nodes_attributes, edges, inputs9)
|
||||
scatternd_node = Node(graph, 'scatternd_node')
|
||||
ScatterNDUpdate.infer(scatternd_node)
|
||||
|
||||
# get the result
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
self.assertTrue(np.array_equal(output9, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output8, res_output_value))
|
||||
|
||||
Reference in New Issue
Block a user