diff --git a/src/bindings/python/src/openvino/runtime/opset10/__init__.py b/src/bindings/python/src/openvino/runtime/opset10/__init__.py index 1d152472fd8..ade2b0dc555 100644 --- a/src/bindings/python/src/openvino/runtime/opset10/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset10/__init__.py @@ -145,6 +145,7 @@ from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset7.ops import roll from openvino.runtime.opset5.ops import round from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset11/__init__.py b/src/bindings/python/src/openvino/runtime/opset11/__init__.py index 5522c5bfaa2..50513a812c0 100644 --- a/src/bindings/python/src/openvino/runtime/opset11/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset11/__init__.py @@ -145,6 +145,7 @@ from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset7.ops import roll from openvino.runtime.opset5.ops import round from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset12/__init__.py b/src/bindings/python/src/openvino/runtime/opset12/__init__.py index 6ef3c50a49a..b864996e044 100644 --- a/src/bindings/python/src/openvino/runtime/opset12/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset12/__init__.py @@ -146,6 +146,7 @@ from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset7.ops import roll from openvino.runtime.opset5.ops import round from openvino.runtime.opset12.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset13/__init__.py b/src/bindings/python/src/openvino/runtime/opset13/__init__.py index 441831dd5f2..9c544b0d7e7 100644 --- a/src/bindings/python/src/openvino/runtime/opset13/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset13/__init__.py @@ -154,6 +154,7 @@ from openvino.runtime.opset7.ops import roll from openvino.runtime.opset5.ops import round from openvino.runtime.opset13.ops import scaled_dot_product_attention from openvino.runtime.opset12.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset4/__init__.py b/src/bindings/python/src/openvino/runtime/opset4/__init__.py index 6ed68e392e5..d84f4ad6e18 100644 --- a/src/bindings/python/src/openvino/runtime/opset4/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset4/__init__.py @@ -114,6 +114,7 @@ from openvino.runtime.opset3.ops import rnn_cell from openvino.runtime.opset3.ops import roi_align from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset4/ops.py b/src/bindings/python/src/openvino/runtime/opset4/ops.py index e5489703784..4056053a692 100644 --- a/src/bindings/python/src/openvino/runtime/opset4/ops.py +++ b/src/bindings/python/src/openvino/runtime/opset4/ops.py @@ -442,3 +442,23 @@ def range( "output_type": output_type, }, ) + + +@nameable_op +def scatter_nd_update( + data: NodeInput, + indices: NodeInput, + updates: NodeInput, + name: Optional[str] = None, +) -> Node: + """Return a node which performs ScatterNDUpdate. + + :param data: Node input representing the tensor to be updated. + :param indices: Node input representing the indices at which updates will be applied. + :param updates: Node input representing the updates to be applied. + :param name: Optional name for the output node. + :return: New node performing the ScatterNDUpdate. + """ + inputs = as_nodes(data, indices, updates) + + return _get_node_factory_opset4().create("ScatterNDUpdate", inputs, {}) diff --git a/src/bindings/python/src/openvino/runtime/opset5/__init__.py b/src/bindings/python/src/openvino/runtime/opset5/__init__.py index 8f829cfced9..0651265b756 100644 --- a/src/bindings/python/src/openvino/runtime/opset5/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset5/__init__.py @@ -121,6 +121,7 @@ from openvino.runtime.opset3.ops import roi_align from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset5.ops import round from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset6/__init__.py b/src/bindings/python/src/openvino/runtime/opset6/__init__.py index d05e5af18c7..d22fe8c4f2d 100644 --- a/src/bindings/python/src/openvino/runtime/opset6/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset6/__init__.py @@ -123,6 +123,7 @@ from openvino.runtime.opset3.ops import roi_align from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset5.ops import round from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset7/__init__.py b/src/bindings/python/src/openvino/runtime/opset7/__init__.py index c391c99bd6a..fce9b001f78 100644 --- a/src/bindings/python/src/openvino/runtime/opset7/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset7/__init__.py @@ -127,6 +127,7 @@ from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset7.ops import roll from openvino.runtime.opset5.ops import round from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset8/__init__.py b/src/bindings/python/src/openvino/runtime/opset8/__init__.py index aa9a2c35915..b30cde97be9 100644 --- a/src/bindings/python/src/openvino/runtime/opset8/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset8/__init__.py @@ -137,6 +137,7 @@ from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset7.ops import roll from openvino.runtime.opset5.ops import round from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/src/openvino/runtime/opset9/__init__.py b/src/bindings/python/src/openvino/runtime/opset9/__init__.py index 1e9103c65f3..d08b873e0ca 100644 --- a/src/bindings/python/src/openvino/runtime/opset9/__init__.py +++ b/src/bindings/python/src/openvino/runtime/opset9/__init__.py @@ -142,6 +142,7 @@ from openvino.runtime.opset2.ops import roi_pooling from openvino.runtime.opset7.ops import roll from openvino.runtime.opset5.ops import round from openvino.runtime.opset3.ops import scatter_elements_update +from openvino.runtime.opset4.ops import scatter_nd_update from openvino.runtime.opset3.ops import scatter_update from openvino.runtime.opset1.ops import select from openvino.runtime.opset1.ops import selu diff --git a/src/bindings/python/tests/test_graph/test_ops_scatter_nd_update.py b/src/bindings/python/tests/test_graph/test_ops_scatter_nd_update.py new file mode 100644 index 00000000000..0c8f6ca3438 --- /dev/null +++ b/src/bindings/python/tests/test_graph/test_ops_scatter_nd_update.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from openvino import PartialShape, Type + +import openvino.runtime.opset13 as ov + + +def test_scatter_nd_update(): + data_shape = [4, 4, 4] + indices_shape = [2, 1] + updates_shape = [2, 4, 4] + + data_param = ov.parameter(shape=data_shape, dtype=Type.f32, name="data") + indices_param = ov.parameter(shape=indices_shape, dtype=Type.i32, name="indices") + updates_param = ov.parameter(shape=updates_shape, dtype=Type.f32, name="updates") + + scatter_nd_node = ov.scatter_nd_update(data_param, indices_param, updates_param) + + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" + assert scatter_nd_node.get_output_size() == 1 + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data_shape)) + assert scatter_nd_node.get_output_element_type(0) == Type.f32 + + +def test_scatter_nd_update_basic(): + data = np.array([1, 2, 3, 4, 5]) + indices = np.array([[0], [2]]) + updates = np.array([9, 10]) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([9, 2, 10, 4, 5]) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_multidimensional(): + data = np.array([[1, 2], [3, 4]]) + indices = np.array([[0, 1], [1, 0]]) + updates = np.array([9, 10]) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([[1, 9], [10, 4]]) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_mismatched_updates_shape(): + data = np.array([1, 2, 3]) + indices = np.array([[0], [1]]) + updates = np.array([4]) + + with pytest.raises(RuntimeError): + ov.scatter_nd_update(data, indices, updates) + + +def test_scatter_nd_update_non_integer_indices(): + data = np.array([1, 2, 3]) + indices = np.array([[0.5]]) + updates = np.array([4]) + + with pytest.raises(RuntimeError): + ov.scatter_nd_update(data, indices, updates) + + +def test_scatter_nd_update_negative_indices(): + data = np.array([1, 2, 3, 4]) + indices = np.array([[-1]]) + updates = np.array([5]) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([1, 2, 3, 5]) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_multi_index_per_update(): + data = np.array([[1, 2], [3, 4]]) + indices = np.array([[0, 0], [0, 1]]) + updates = np.array([5, 6]) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([[5, 6], [3, 4]]) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_non_contiguous_indices(): + data = np.array([10, 20, 30, 40, 50]) + indices = np.array([[0], [3]]) + updates = np.array([100, 400]) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([100, 20, 30, 400, 50]) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_large_updates(): + data = np.zeros(1000, dtype=np.float64) + indices = np.reshape(np.arange(1000), (-1, 1)) + updates = np.arange(1000, dtype=np.float64) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.arange(1000, dtype=np.float64) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_overlapping_indices(): + data = np.array([1, 2, 3, 4, 5]) + indices = np.array([[1], [1], [3]]) + updates = np.array([10, 20, 30]) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([1, 20, 3, 30, 5]) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_3d_data(): + data = np.zeros((2, 2, 2), dtype=np.float64) + indices = np.array([[0, 0, 1], [1, 1, 0]]) + updates = np.array([1, 2], dtype=np.float64) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([[[0, 1], [0, 0]], [[0, 0], [2, 0]]], dtype=np.float64) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_all_indices(): + data = np.ones((2, 3), dtype=np.float64) + indices = np.array([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]) + updates = np.array([10, 20, 30, 40, 50, 60], dtype=np.float64) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.float64) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_invalid_updates_shape(): + data = np.array([1, 2, 3, 4]) + indices = np.array([[1], [2]]) + updates = np.array([5]) + + with pytest.raises(RuntimeError): + ov.scatter_nd_update(data, indices, updates) + + +def test_scatter_nd_update_negative_updates(): + data = np.array([1, 2, 3, 4, 5]) + indices = np.array([[1], [3]]) + updates = np.array([-1, -2]) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([1, -1, 3, -2, 5]) + np.testing.assert_array_equal(result, expected) + + +def test_scatter_nd_update_empty_indices_and_updates(): + data = np.array([1, 2, 3], dtype=np.float64) + indices = np.array([], dtype=np.int64).reshape(0, 1) + updates = np.array([], dtype=np.float64) + + result = ov.scatter_nd_update(data, indices, updates) + expected = np.array([1, 2, 3], dtype=np.float64) + np.testing.assert_array_equal(result, expected)