Add ScatterUpdate value infer (#12595)

* Add ScatterUpdate value infer

* Add additional test case to ScatterUpdate tests
This commit is contained in:
Maxim Vafin 2022-08-22 17:52:49 +02:00 committed by GitHub
parent 9710bde87c
commit 56808c7aed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 15 deletions

View File

@ -162,3 +162,28 @@ class ScatterSub(Scatter):
class ScatterUpdate(Scatter): class ScatterUpdate(Scatter):
op = op_type = 'ScatterUpdate' op = op_type = 'ScatterUpdate'
version = 'opset3' version = 'opset3'
@staticmethod
def infer(node: Node):
node_name = node.soft_get('name', node.id)
Scatter.infer(node)
input_shape = node.in_port(0).data.get_shape()
input_value = node.in_port(0).data.get_value()
indices_value = node.in_port(1).data.get_value()
updates_value = node.in_port(2).data.get_value()
axis = node.in_port(3).data.get_value()
if input_value is not None and indices_value is not None and updates_value is not None and axis is not None:
assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(
node_name, axis.size)
axis = axis.item()
if axis < 0:
axis = len(input_shape) + axis
out_value = input_value.copy()
for idx in np.ndindex(*input_shape[:axis]):
out_value[idx][indices_value] = updates_value[idx]
node.out_port(0).data.set_value(out_value)

View File

@ -6,7 +6,7 @@ import unittest
import numpy as np import numpy as np
from generator import generator, generate from generator import generator, generate
from openvino.tools.mo.ops.scatter import ScatterElementsUpdate from openvino.tools.mo.ops.scatter import ScatterElementsUpdate, ScatterUpdate
from openvino.tools.mo.front.common.partial_infer.utils import int64_array from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.graph.graph import Node from openvino.tools.mo.graph.graph import Node
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, valued_const_with_data from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, valued_const_with_data
@ -73,7 +73,6 @@ class ScatterElementsInferTest(unittest.TestCase):
[32, 31]] [32, 31]]
]), ]),
]) ])
def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res): def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res):
nodes = { nodes = {
**valued_const_with_data('data', np.array(data)), **valued_const_with_data('data', np.array(data)),
@ -101,3 +100,88 @@ class ScatterElementsInferTest(unittest.TestCase):
res_output_value = scatter_el_node.out_node().value res_output_value = scatter_el_node.out_node().value
self.assertTrue(np.array_equal(ref_res, res_output_value)) self.assertTrue(np.array_equal(ref_res, res_output_value))
@generator
class ScatterUpdateInferTest(unittest.TestCase):
@generate(*[
([[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]],
[[1, 2]],
[[[1.0, 1.1, 1.2],
[2.0, 2.1, 2.2]]],
0,
[[0.0, 0.0, 0.0],
[1.0, 1.1, 1.2],
[2.0, 2.1, 2.2]]),
# negative axis
([[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]],
[[1, 2]],
[[[1.0, 1.1]],
[[1.2, 2.0]],
[[2.1, 2.2]]],
-1,
[[0.0, 1.0, 1.1],
[0.0, 1.2, 2.0],
[0.0, 2.1, 2.2]]),
# one element
([[[0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.]]],
[[1]],
[[[[1., 2.], [3., 4.], [5., 6.]]]],
0,
[[[0., 0.], [0., 0.], [0., 0.]],
[[1., 2.], [3., 4.], [5., 6.]],
[[0., 0.], [0., 0.], [0., 0.]]]),
# shape [2,3,3]
([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
# indices [3,2]
[[1, 2], [0, 1], [1, 2]],
# updates [2,3,2,3]
[[[[1., 2., 3.], [4., 5., 6.]],
[[7., 8., 9.], [9., 8., 7.]],
[[6., 5., 4.], [3., 2., 1.]]],
[[[1., 2., 3.], [4., 5., 6.]],
[[7., 8., 9.], [9., 8., 7.]],
[[6., 5., 4.], [3., 2., 1.]]]],
# axis
1,
# ref
[[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]],
[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]]]),
])
def test_scatter_update_value_infer(self, data, indices, updates, axis, ref_res):
nodes = {
**valued_const_with_data('data', np.array(data)),
**valued_const_with_data('indices', int64_array(indices)),
**valued_const_with_data('updates', np.array(updates)),
**valued_const_with_data('axis', int64_array(axis)),
**regular_op_with_empty_data('scatter_update', {'op': 'ScatterUpdate', 'axis': axis}),
**result()
}
graph = build_graph(nodes_attrs=nodes, edges=[
*connect('data', '0:scatter_update'),
*connect('indices', '1:scatter_update'),
*connect('updates', '2:scatter_update'),
*connect('axis', '3:scatter_update'),
*connect('scatter_update', 'output')
], nodes_with_edges_only=True)
graph.stage = 'middle'
scatter_update_node = Node(graph, 'scatter_update')
ScatterUpdate.infer(scatter_update_node)
res_output_shape = scatter_update_node.out_node().shape
self.assertTrue(np.array_equal(int64_array(ref_res).shape, res_output_shape))
res_output_value = scatter_update_node.out_node().value
self.assertTrue(np.array_equal(ref_res, res_output_value))