Add ScatterUpdate value infer (#12595)
* Add ScatterUpdate value infer * Add additional test case to ScatterUpdate tests
This commit is contained in:
parent
9710bde87c
commit
56808c7aed
@ -162,3 +162,28 @@ class ScatterSub(Scatter):
|
||||
class ScatterUpdate(Scatter):
|
||||
op = op_type = 'ScatterUpdate'
|
||||
version = 'opset3'
|
||||
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
Scatter.infer(node)
|
||||
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
|
||||
input_value = node.in_port(0).data.get_value()
|
||||
indices_value = node.in_port(1).data.get_value()
|
||||
updates_value = node.in_port(2).data.get_value()
|
||||
|
||||
axis = node.in_port(3).data.get_value()
|
||||
|
||||
if input_value is not None and indices_value is not None and updates_value is not None and axis is not None:
|
||||
assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(
|
||||
node_name, axis.size)
|
||||
axis = axis.item()
|
||||
if axis < 0:
|
||||
axis = len(input_shape) + axis
|
||||
|
||||
out_value = input_value.copy()
|
||||
for idx in np.ndindex(*input_shape[:axis]):
|
||||
out_value[idx][indices_value] = updates_value[idx]
|
||||
node.out_port(0).data.set_value(out_value)
|
||||
|
@ -6,7 +6,7 @@ import unittest
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from openvino.tools.mo.ops.scatter import ScatterElementsUpdate
|
||||
from openvino.tools.mo.ops.scatter import ScatterElementsUpdate, ScatterUpdate
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, valued_const_with_data
|
||||
@ -73,7 +73,6 @@ class ScatterElementsInferTest(unittest.TestCase):
|
||||
[32, 31]]
|
||||
]),
|
||||
])
|
||||
|
||||
def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res):
|
||||
nodes = {
|
||||
**valued_const_with_data('data', np.array(data)),
|
||||
@ -101,3 +100,88 @@ class ScatterElementsInferTest(unittest.TestCase):
|
||||
|
||||
res_output_value = scatter_el_node.out_node().value
|
||||
self.assertTrue(np.array_equal(ref_res, res_output_value))
|
||||
|
||||
|
||||
@generator
|
||||
class ScatterUpdateInferTest(unittest.TestCase):
|
||||
@generate(*[
|
||||
([[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0]],
|
||||
[[1, 2]],
|
||||
[[[1.0, 1.1, 1.2],
|
||||
[2.0, 2.1, 2.2]]],
|
||||
0,
|
||||
[[0.0, 0.0, 0.0],
|
||||
[1.0, 1.1, 1.2],
|
||||
[2.0, 2.1, 2.2]]),
|
||||
|
||||
# negative axis
|
||||
([[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0]],
|
||||
[[1, 2]],
|
||||
[[[1.0, 1.1]],
|
||||
[[1.2, 2.0]],
|
||||
[[2.1, 2.2]]],
|
||||
-1,
|
||||
[[0.0, 1.0, 1.1],
|
||||
[0.0, 1.2, 2.0],
|
||||
[0.0, 2.1, 2.2]]),
|
||||
|
||||
# one element
|
||||
([[[0., 0.], [0., 0.], [0., 0.]],
|
||||
[[0., 0.], [0., 0.], [0., 0.]],
|
||||
[[0., 0.], [0., 0.], [0., 0.]]],
|
||||
[[1]],
|
||||
[[[[1., 2.], [3., 4.], [5., 6.]]]],
|
||||
0,
|
||||
[[[0., 0.], [0., 0.], [0., 0.]],
|
||||
[[1., 2.], [3., 4.], [5., 6.]],
|
||||
[[0., 0.], [0., 0.], [0., 0.]]]),
|
||||
|
||||
# shape [2,3,3]
|
||||
([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
|
||||
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
|
||||
# indices [3,2]
|
||||
[[1, 2], [0, 1], [1, 2]],
|
||||
# updates [2,3,2,3]
|
||||
[[[[1., 2., 3.], [4., 5., 6.]],
|
||||
[[7., 8., 9.], [9., 8., 7.]],
|
||||
[[6., 5., 4.], [3., 2., 1.]]],
|
||||
[[[1., 2., 3.], [4., 5., 6.]],
|
||||
[[7., 8., 9.], [9., 8., 7.]],
|
||||
[[6., 5., 4.], [3., 2., 1.]]]],
|
||||
# axis
|
||||
1,
|
||||
# ref
|
||||
[[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]],
|
||||
[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]]]),
|
||||
])
|
||||
def test_scatter_update_value_infer(self, data, indices, updates, axis, ref_res):
|
||||
nodes = {
|
||||
**valued_const_with_data('data', np.array(data)),
|
||||
**valued_const_with_data('indices', int64_array(indices)),
|
||||
**valued_const_with_data('updates', np.array(updates)),
|
||||
**valued_const_with_data('axis', int64_array(axis)),
|
||||
**regular_op_with_empty_data('scatter_update', {'op': 'ScatterUpdate', 'axis': axis}),
|
||||
**result()
|
||||
}
|
||||
|
||||
graph = build_graph(nodes_attrs=nodes, edges=[
|
||||
*connect('data', '0:scatter_update'),
|
||||
*connect('indices', '1:scatter_update'),
|
||||
*connect('updates', '2:scatter_update'),
|
||||
*connect('axis', '3:scatter_update'),
|
||||
*connect('scatter_update', 'output')
|
||||
], nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
|
||||
scatter_update_node = Node(graph, 'scatter_update')
|
||||
ScatterUpdate.infer(scatter_update_node)
|
||||
|
||||
res_output_shape = scatter_update_node.out_node().shape
|
||||
self.assertTrue(np.array_equal(int64_array(ref_res).shape, res_output_shape))
|
||||
|
||||
res_output_value = scatter_update_node.out_node().value
|
||||
self.assertTrue(np.array_equal(ref_res, res_output_value))
|
||||
|
Loading…
Reference in New Issue
Block a user