Add ScatterUpdate value infer (#12595)
* Add ScatterUpdate value infer * Add additional test case to ScatterUpdate tests
This commit is contained in:
parent
9710bde87c
commit
56808c7aed
@ -162,3 +162,28 @@ class ScatterSub(Scatter):
|
|||||||
class ScatterUpdate(Scatter):
|
class ScatterUpdate(Scatter):
|
||||||
op = op_type = 'ScatterUpdate'
|
op = op_type = 'ScatterUpdate'
|
||||||
version = 'opset3'
|
version = 'opset3'
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def infer(node: Node):
|
||||||
|
node_name = node.soft_get('name', node.id)
|
||||||
|
Scatter.infer(node)
|
||||||
|
|
||||||
|
input_shape = node.in_port(0).data.get_shape()
|
||||||
|
|
||||||
|
input_value = node.in_port(0).data.get_value()
|
||||||
|
indices_value = node.in_port(1).data.get_value()
|
||||||
|
updates_value = node.in_port(2).data.get_value()
|
||||||
|
|
||||||
|
axis = node.in_port(3).data.get_value()
|
||||||
|
|
||||||
|
if input_value is not None and indices_value is not None and updates_value is not None and axis is not None:
|
||||||
|
assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(
|
||||||
|
node_name, axis.size)
|
||||||
|
axis = axis.item()
|
||||||
|
if axis < 0:
|
||||||
|
axis = len(input_shape) + axis
|
||||||
|
|
||||||
|
out_value = input_value.copy()
|
||||||
|
for idx in np.ndindex(*input_shape[:axis]):
|
||||||
|
out_value[idx][indices_value] = updates_value[idx]
|
||||||
|
node.out_port(0).data.set_value(out_value)
|
||||||
|
@ -6,7 +6,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from generator import generator, generate
|
from generator import generator, generate
|
||||||
|
|
||||||
from openvino.tools.mo.ops.scatter import ScatterElementsUpdate
|
from openvino.tools.mo.ops.scatter import ScatterElementsUpdate, ScatterUpdate
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||||
from openvino.tools.mo.graph.graph import Node
|
from openvino.tools.mo.graph.graph import Node
|
||||||
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, valued_const_with_data
|
from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, valued_const_with_data
|
||||||
@ -73,7 +73,6 @@ class ScatterElementsInferTest(unittest.TestCase):
|
|||||||
[32, 31]]
|
[32, 31]]
|
||||||
]),
|
]),
|
||||||
])
|
])
|
||||||
|
|
||||||
def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res):
|
def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res):
|
||||||
nodes = {
|
nodes = {
|
||||||
**valued_const_with_data('data', np.array(data)),
|
**valued_const_with_data('data', np.array(data)),
|
||||||
@ -101,3 +100,88 @@ class ScatterElementsInferTest(unittest.TestCase):
|
|||||||
|
|
||||||
res_output_value = scatter_el_node.out_node().value
|
res_output_value = scatter_el_node.out_node().value
|
||||||
self.assertTrue(np.array_equal(ref_res, res_output_value))
|
self.assertTrue(np.array_equal(ref_res, res_output_value))
|
||||||
|
|
||||||
|
|
||||||
|
@generator
|
||||||
|
class ScatterUpdateInferTest(unittest.TestCase):
|
||||||
|
@generate(*[
|
||||||
|
([[0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0]],
|
||||||
|
[[1, 2]],
|
||||||
|
[[[1.0, 1.1, 1.2],
|
||||||
|
[2.0, 2.1, 2.2]]],
|
||||||
|
0,
|
||||||
|
[[0.0, 0.0, 0.0],
|
||||||
|
[1.0, 1.1, 1.2],
|
||||||
|
[2.0, 2.1, 2.2]]),
|
||||||
|
|
||||||
|
# negative axis
|
||||||
|
([[0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0]],
|
||||||
|
[[1, 2]],
|
||||||
|
[[[1.0, 1.1]],
|
||||||
|
[[1.2, 2.0]],
|
||||||
|
[[2.1, 2.2]]],
|
||||||
|
-1,
|
||||||
|
[[0.0, 1.0, 1.1],
|
||||||
|
[0.0, 1.2, 2.0],
|
||||||
|
[0.0, 2.1, 2.2]]),
|
||||||
|
|
||||||
|
# one element
|
||||||
|
([[[0., 0.], [0., 0.], [0., 0.]],
|
||||||
|
[[0., 0.], [0., 0.], [0., 0.]],
|
||||||
|
[[0., 0.], [0., 0.], [0., 0.]]],
|
||||||
|
[[1]],
|
||||||
|
[[[[1., 2.], [3., 4.], [5., 6.]]]],
|
||||||
|
0,
|
||||||
|
[[[0., 0.], [0., 0.], [0., 0.]],
|
||||||
|
[[1., 2.], [3., 4.], [5., 6.]],
|
||||||
|
[[0., 0.], [0., 0.], [0., 0.]]]),
|
||||||
|
|
||||||
|
# shape [2,3,3]
|
||||||
|
([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
|
||||||
|
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
|
||||||
|
# indices [3,2]
|
||||||
|
[[1, 2], [0, 1], [1, 2]],
|
||||||
|
# updates [2,3,2,3]
|
||||||
|
[[[[1., 2., 3.], [4., 5., 6.]],
|
||||||
|
[[7., 8., 9.], [9., 8., 7.]],
|
||||||
|
[[6., 5., 4.], [3., 2., 1.]]],
|
||||||
|
[[[1., 2., 3.], [4., 5., 6.]],
|
||||||
|
[[7., 8., 9.], [9., 8., 7.]],
|
||||||
|
[[6., 5., 4.], [3., 2., 1.]]]],
|
||||||
|
# axis
|
||||||
|
1,
|
||||||
|
# ref
|
||||||
|
[[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]],
|
||||||
|
[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]]]),
|
||||||
|
])
|
||||||
|
def test_scatter_update_value_infer(self, data, indices, updates, axis, ref_res):
|
||||||
|
nodes = {
|
||||||
|
**valued_const_with_data('data', np.array(data)),
|
||||||
|
**valued_const_with_data('indices', int64_array(indices)),
|
||||||
|
**valued_const_with_data('updates', np.array(updates)),
|
||||||
|
**valued_const_with_data('axis', int64_array(axis)),
|
||||||
|
**regular_op_with_empty_data('scatter_update', {'op': 'ScatterUpdate', 'axis': axis}),
|
||||||
|
**result()
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = build_graph(nodes_attrs=nodes, edges=[
|
||||||
|
*connect('data', '0:scatter_update'),
|
||||||
|
*connect('indices', '1:scatter_update'),
|
||||||
|
*connect('updates', '2:scatter_update'),
|
||||||
|
*connect('axis', '3:scatter_update'),
|
||||||
|
*connect('scatter_update', 'output')
|
||||||
|
], nodes_with_edges_only=True)
|
||||||
|
graph.stage = 'middle'
|
||||||
|
|
||||||
|
scatter_update_node = Node(graph, 'scatter_update')
|
||||||
|
ScatterUpdate.infer(scatter_update_node)
|
||||||
|
|
||||||
|
res_output_shape = scatter_update_node.out_node().shape
|
||||||
|
self.assertTrue(np.array_equal(int64_array(ref_res).shape, res_output_shape))
|
||||||
|
|
||||||
|
res_output_value = scatter_update_node.out_node().value
|
||||||
|
self.assertTrue(np.array_equal(ref_res, res_output_value))
|
||||||
|
Loading…
Reference in New Issue
Block a user