From 56808c7aeddb7063a846a986d0b0b1f046f0a87b Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 22 Aug 2022 17:52:49 +0200 Subject: [PATCH] Add ScatterUpdate value infer (#12595) * Add ScatterUpdate value infer * Add additional test case to ScatterUpdate tests --- tools/mo/openvino/tools/mo/ops/scatter.py | 25 +++++ tools/mo/unit_tests/mo/ops/scatter_test.py | 114 ++++++++++++++++++--- 2 files changed, 124 insertions(+), 15 deletions(-) diff --git a/tools/mo/openvino/tools/mo/ops/scatter.py b/tools/mo/openvino/tools/mo/ops/scatter.py index 785aaecdd2e..8e626519565 100644 --- a/tools/mo/openvino/tools/mo/ops/scatter.py +++ b/tools/mo/openvino/tools/mo/ops/scatter.py @@ -162,3 +162,28 @@ class ScatterSub(Scatter): class ScatterUpdate(Scatter): op = op_type = 'ScatterUpdate' version = 'opset3' + + @staticmethod + def infer(node: Node): + node_name = node.soft_get('name', node.id) + Scatter.infer(node) + + input_shape = node.in_port(0).data.get_shape() + + input_value = node.in_port(0).data.get_value() + indices_value = node.in_port(1).data.get_value() + updates_value = node.in_port(2).data.get_value() + + axis = node.in_port(3).data.get_value() + + if input_value is not None and indices_value is not None and updates_value is not None and axis is not None: + assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format( + node_name, axis.size) + axis = axis.item() + if axis < 0: + axis = len(input_shape) + axis + + out_value = input_value.copy() + for idx in np.ndindex(*input_shape[:axis]): + out_value[idx][indices_value] = updates_value[idx] + node.out_port(0).data.set_value(out_value) diff --git a/tools/mo/unit_tests/mo/ops/scatter_test.py b/tools/mo/unit_tests/mo/ops/scatter_test.py index b679af0ecb8..9dfbf1b5424 100644 --- a/tools/mo/unit_tests/mo/ops/scatter_test.py +++ b/tools/mo/unit_tests/mo/ops/scatter_test.py @@ -6,7 +6,7 @@ import unittest import numpy as np from generator import generator, generate -from openvino.tools.mo.ops.scatter import ScatterElementsUpdate +from openvino.tools.mo.ops.scatter import ScatterElementsUpdate, ScatterUpdate from openvino.tools.mo.front.common.partial_infer.utils import int64_array from openvino.tools.mo.graph.graph import Node from unit_tests.utils.graph import build_graph, regular_op_with_empty_data, result, connect, valued_const_with_data @@ -40,20 +40,20 @@ class ScatterElementsInferTest(unittest.TestCase): [[1.0, 1.1, 3.0, 2.1, 5.0]]), ([ # 3D case - [[1, 2], - [3, 4]], - [[5, 6], - [7, 8]], - [[9, 10], - [11, 12]] - ], + [[1, 2], + [3, 4]], + [[5, 6], + [7, 8]], + [[9, 10], + [11, 12]] + ], [ - [[1, 0], - [0, 1]], - [[1, 0], - [1, 0]], - [[0, 1], - [1, 0]] + [[1, 0], + [0, 1]], + [[1, 0], + [1, 0]], + [[0, 1], + [1, 0]] ], [ [[21, 22], @@ -73,7 +73,6 @@ class ScatterElementsInferTest(unittest.TestCase): [32, 31]] ]), ]) - def test_scatterelements_value_infer(self, data, indices, updates, axis, ref_res): nodes = { **valued_const_with_data('data', np.array(data)), @@ -101,3 +100,88 @@ class ScatterElementsInferTest(unittest.TestCase): res_output_value = scatter_el_node.out_node().value self.assertTrue(np.array_equal(ref_res, res_output_value)) + + +@generator +class ScatterUpdateInferTest(unittest.TestCase): + @generate(*[ + ([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]], + [[1, 2]], + [[[1.0, 1.1, 1.2], + [2.0, 2.1, 2.2]]], + 0, + [[0.0, 0.0, 0.0], + [1.0, 1.1, 1.2], + [2.0, 2.1, 2.2]]), + + # negative axis + ([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]], + [[1, 2]], + [[[1.0, 1.1]], + [[1.2, 2.0]], + [[2.1, 2.2]]], + -1, + [[0.0, 1.0, 1.1], + [0.0, 1.2, 2.0], + [0.0, 2.1, 2.2]]), + + # one element + ([[[0., 0.], [0., 0.], [0., 0.]], + [[0., 0.], [0., 0.], [0., 0.]], + [[0., 0.], [0., 0.], [0., 0.]]], + [[1]], + [[[[1., 2.], [3., 4.], [5., 6.]]]], + 0, + [[[0., 0.], [0., 0.], [0., 0.]], + [[1., 2.], [3., 4.], [5., 6.]], + [[0., 0.], [0., 0.], [0., 0.]]]), + + # shape [2,3,3] + ([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + # indices [3,2] + [[1, 2], [0, 1], [1, 2]], + # updates [2,3,2,3] + [[[[1., 2., 3.], [4., 5., 6.]], + [[7., 8., 9.], [9., 8., 7.]], + [[6., 5., 4.], [3., 2., 1.]]], + [[[1., 2., 3.], [4., 5., 6.]], + [[7., 8., 9.], [9., 8., 7.]], + [[6., 5., 4.], [3., 2., 1.]]]], + # axis + 1, + # ref + [[[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]], + [[7., 8., 9.], [6., 5., 4.], [3., 2., 1.]]]), + ]) + def test_scatter_update_value_infer(self, data, indices, updates, axis, ref_res): + nodes = { + **valued_const_with_data('data', np.array(data)), + **valued_const_with_data('indices', int64_array(indices)), + **valued_const_with_data('updates', np.array(updates)), + **valued_const_with_data('axis', int64_array(axis)), + **regular_op_with_empty_data('scatter_update', {'op': 'ScatterUpdate', 'axis': axis}), + **result() + } + + graph = build_graph(nodes_attrs=nodes, edges=[ + *connect('data', '0:scatter_update'), + *connect('indices', '1:scatter_update'), + *connect('updates', '2:scatter_update'), + *connect('axis', '3:scatter_update'), + *connect('scatter_update', 'output') + ], nodes_with_edges_only=True) + graph.stage = 'middle' + + scatter_update_node = Node(graph, 'scatter_update') + ScatterUpdate.infer(scatter_update_node) + + res_output_shape = scatter_update_node.out_node().shape + self.assertTrue(np.array_equal(int64_array(ref_res).shape, res_output_shape)) + + res_output_value = scatter_update_node.out_node().value + self.assertTrue(np.array_equal(ref_res, res_output_value))