Extend Python API with ScatterNDUpdate-3 (#21325)

Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>
This commit is contained in:
Kristián Körmöndi 2023-11-30 16:32:45 +01:00 committed by GitHub
parent d5d9fd11b3
commit 74bf3d4e38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 194 additions and 0 deletions

View File

@ -145,6 +145,7 @@ from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -145,6 +145,7 @@ from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -146,6 +146,7 @@ from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset12.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -154,6 +154,7 @@ from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset13.ops import scaled_dot_product_attention
from openvino.runtime.opset12.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -114,6 +114,7 @@ from openvino.runtime.opset3.ops import rnn_cell
from openvino.runtime.opset3.ops import roi_align
from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -442,3 +442,23 @@ def range(
"output_type": output_type,
},
)
@nameable_op
def scatter_nd_update(
data: NodeInput,
indices: NodeInput,
updates: NodeInput,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs ScatterNDUpdate.
:param data: Node input representing the tensor to be updated.
:param indices: Node input representing the indices at which updates will be applied.
:param updates: Node input representing the updates to be applied.
:param name: Optional name for the output node.
:return: New node performing the ScatterNDUpdate.
"""
inputs = as_nodes(data, indices, updates)
return _get_node_factory_opset4().create("ScatterNDUpdate", inputs, {})

View File

@ -121,6 +121,7 @@ from openvino.runtime.opset3.ops import roi_align
from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -123,6 +123,7 @@ from openvino.runtime.opset3.ops import roi_align
from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -127,6 +127,7 @@ from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -137,6 +137,7 @@ from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -142,6 +142,7 @@ from openvino.runtime.opset2.ops import roi_pooling
from openvino.runtime.opset7.ops import roll
from openvino.runtime.opset5.ops import round
from openvino.runtime.opset3.ops import scatter_elements_update
from openvino.runtime.opset4.ops import scatter_nd_update
from openvino.runtime.opset3.ops import scatter_update
from openvino.runtime.opset1.ops import select
from openvino.runtime.opset1.ops import selu

View File

@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from openvino import PartialShape, Type
import openvino.runtime.opset13 as ov
def test_scatter_nd_update():
data_shape = [4, 4, 4]
indices_shape = [2, 1]
updates_shape = [2, 4, 4]
data_param = ov.parameter(shape=data_shape, dtype=Type.f32, name="data")
indices_param = ov.parameter(shape=indices_shape, dtype=Type.i32, name="indices")
updates_param = ov.parameter(shape=updates_shape, dtype=Type.f32, name="updates")
scatter_nd_node = ov.scatter_nd_update(data_param, indices_param, updates_param)
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
assert scatter_nd_node.get_output_size() == 1
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data_shape))
assert scatter_nd_node.get_output_element_type(0) == Type.f32
def test_scatter_nd_update_basic():
data = np.array([1, 2, 3, 4, 5])
indices = np.array([[0], [2]])
updates = np.array([9, 10])
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([9, 2, 10, 4, 5])
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_multidimensional():
data = np.array([[1, 2], [3, 4]])
indices = np.array([[0, 1], [1, 0]])
updates = np.array([9, 10])
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([[1, 9], [10, 4]])
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_mismatched_updates_shape():
data = np.array([1, 2, 3])
indices = np.array([[0], [1]])
updates = np.array([4])
with pytest.raises(RuntimeError):
ov.scatter_nd_update(data, indices, updates)
def test_scatter_nd_update_non_integer_indices():
data = np.array([1, 2, 3])
indices = np.array([[0.5]])
updates = np.array([4])
with pytest.raises(RuntimeError):
ov.scatter_nd_update(data, indices, updates)
def test_scatter_nd_update_negative_indices():
data = np.array([1, 2, 3, 4])
indices = np.array([[-1]])
updates = np.array([5])
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([1, 2, 3, 5])
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_multi_index_per_update():
data = np.array([[1, 2], [3, 4]])
indices = np.array([[0, 0], [0, 1]])
updates = np.array([5, 6])
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([[5, 6], [3, 4]])
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_non_contiguous_indices():
data = np.array([10, 20, 30, 40, 50])
indices = np.array([[0], [3]])
updates = np.array([100, 400])
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([100, 20, 30, 400, 50])
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_large_updates():
data = np.zeros(1000, dtype=np.float64)
indices = np.reshape(np.arange(1000), (-1, 1))
updates = np.arange(1000, dtype=np.float64)
result = ov.scatter_nd_update(data, indices, updates)
expected = np.arange(1000, dtype=np.float64)
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_overlapping_indices():
data = np.array([1, 2, 3, 4, 5])
indices = np.array([[1], [1], [3]])
updates = np.array([10, 20, 30])
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([1, 20, 3, 30, 5])
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_3d_data():
data = np.zeros((2, 2, 2), dtype=np.float64)
indices = np.array([[0, 0, 1], [1, 1, 0]])
updates = np.array([1, 2], dtype=np.float64)
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([[[0, 1], [0, 0]], [[0, 0], [2, 0]]], dtype=np.float64)
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_all_indices():
data = np.ones((2, 3), dtype=np.float64)
indices = np.array([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]])
updates = np.array([10, 20, 30, 40, 50, 60], dtype=np.float64)
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.float64)
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_invalid_updates_shape():
data = np.array([1, 2, 3, 4])
indices = np.array([[1], [2]])
updates = np.array([5])
with pytest.raises(RuntimeError):
ov.scatter_nd_update(data, indices, updates)
def test_scatter_nd_update_negative_updates():
data = np.array([1, 2, 3, 4, 5])
indices = np.array([[1], [3]])
updates = np.array([-1, -2])
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([1, -1, 3, -2, 5])
np.testing.assert_array_equal(result, expected)
def test_scatter_nd_update_empty_indices_and_updates():
data = np.array([1, 2, 3], dtype=np.float64)
indices = np.array([], dtype=np.int64).reshape(0, 1)
updates = np.array([], dtype=np.float64)
result = ov.scatter_nd_update(data, indices, updates)
expected = np.array([1, 2, 3], dtype=np.float64)
np.testing.assert_array_equal(result, expected)