Extend Python API with ScatterNDUpdate-3 (#21325)
Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>
This commit is contained in:
parent
d5d9fd11b3
commit
74bf3d4e38
@ -145,6 +145,7 @@ from openvino.runtime.opset2.ops import roi_pooling
|
|||||||
from openvino.runtime.opset7.ops import roll
|
from openvino.runtime.opset7.ops import roll
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -145,6 +145,7 @@ from openvino.runtime.opset2.ops import roi_pooling
|
|||||||
from openvino.runtime.opset7.ops import roll
|
from openvino.runtime.opset7.ops import roll
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -146,6 +146,7 @@ from openvino.runtime.opset2.ops import roi_pooling
|
|||||||
from openvino.runtime.opset7.ops import roll
|
from openvino.runtime.opset7.ops import roll
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset12.ops import scatter_elements_update
|
from openvino.runtime.opset12.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -154,6 +154,7 @@ from openvino.runtime.opset7.ops import roll
|
|||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset13.ops import scaled_dot_product_attention
|
from openvino.runtime.opset13.ops import scaled_dot_product_attention
|
||||||
from openvino.runtime.opset12.ops import scatter_elements_update
|
from openvino.runtime.opset12.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -114,6 +114,7 @@ from openvino.runtime.opset3.ops import rnn_cell
|
|||||||
from openvino.runtime.opset3.ops import roi_align
|
from openvino.runtime.opset3.ops import roi_align
|
||||||
from openvino.runtime.opset2.ops import roi_pooling
|
from openvino.runtime.opset2.ops import roi_pooling
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -442,3 +442,23 @@ def range(
|
|||||||
"output_type": output_type,
|
"output_type": output_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@nameable_op
|
||||||
|
def scatter_nd_update(
|
||||||
|
data: NodeInput,
|
||||||
|
indices: NodeInput,
|
||||||
|
updates: NodeInput,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> Node:
|
||||||
|
"""Return a node which performs ScatterNDUpdate.
|
||||||
|
|
||||||
|
:param data: Node input representing the tensor to be updated.
|
||||||
|
:param indices: Node input representing the indices at which updates will be applied.
|
||||||
|
:param updates: Node input representing the updates to be applied.
|
||||||
|
:param name: Optional name for the output node.
|
||||||
|
:return: New node performing the ScatterNDUpdate.
|
||||||
|
"""
|
||||||
|
inputs = as_nodes(data, indices, updates)
|
||||||
|
|
||||||
|
return _get_node_factory_opset4().create("ScatterNDUpdate", inputs, {})
|
||||||
|
@ -121,6 +121,7 @@ from openvino.runtime.opset3.ops import roi_align
|
|||||||
from openvino.runtime.opset2.ops import roi_pooling
|
from openvino.runtime.opset2.ops import roi_pooling
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -123,6 +123,7 @@ from openvino.runtime.opset3.ops import roi_align
|
|||||||
from openvino.runtime.opset2.ops import roi_pooling
|
from openvino.runtime.opset2.ops import roi_pooling
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -127,6 +127,7 @@ from openvino.runtime.opset2.ops import roi_pooling
|
|||||||
from openvino.runtime.opset7.ops import roll
|
from openvino.runtime.opset7.ops import roll
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -137,6 +137,7 @@ from openvino.runtime.opset2.ops import roi_pooling
|
|||||||
from openvino.runtime.opset7.ops import roll
|
from openvino.runtime.opset7.ops import roll
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -142,6 +142,7 @@ from openvino.runtime.opset2.ops import roi_pooling
|
|||||||
from openvino.runtime.opset7.ops import roll
|
from openvino.runtime.opset7.ops import roll
|
||||||
from openvino.runtime.opset5.ops import round
|
from openvino.runtime.opset5.ops import round
|
||||||
from openvino.runtime.opset3.ops import scatter_elements_update
|
from openvino.runtime.opset3.ops import scatter_elements_update
|
||||||
|
from openvino.runtime.opset4.ops import scatter_nd_update
|
||||||
from openvino.runtime.opset3.ops import scatter_update
|
from openvino.runtime.opset3.ops import scatter_update
|
||||||
from openvino.runtime.opset1.ops import select
|
from openvino.runtime.opset1.ops import select
|
||||||
from openvino.runtime.opset1.ops import selu
|
from openvino.runtime.opset1.ops import selu
|
||||||
|
@ -0,0 +1,164 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from openvino import PartialShape, Type
|
||||||
|
|
||||||
|
import openvino.runtime.opset13 as ov
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update():
|
||||||
|
data_shape = [4, 4, 4]
|
||||||
|
indices_shape = [2, 1]
|
||||||
|
updates_shape = [2, 4, 4]
|
||||||
|
|
||||||
|
data_param = ov.parameter(shape=data_shape, dtype=Type.f32, name="data")
|
||||||
|
indices_param = ov.parameter(shape=indices_shape, dtype=Type.i32, name="indices")
|
||||||
|
updates_param = ov.parameter(shape=updates_shape, dtype=Type.f32, name="updates")
|
||||||
|
|
||||||
|
scatter_nd_node = ov.scatter_nd_update(data_param, indices_param, updates_param)
|
||||||
|
|
||||||
|
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
|
||||||
|
assert scatter_nd_node.get_output_size() == 1
|
||||||
|
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data_shape))
|
||||||
|
assert scatter_nd_node.get_output_element_type(0) == Type.f32
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_basic():
|
||||||
|
data = np.array([1, 2, 3, 4, 5])
|
||||||
|
indices = np.array([[0], [2]])
|
||||||
|
updates = np.array([9, 10])
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([9, 2, 10, 4, 5])
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_multidimensional():
|
||||||
|
data = np.array([[1, 2], [3, 4]])
|
||||||
|
indices = np.array([[0, 1], [1, 0]])
|
||||||
|
updates = np.array([9, 10])
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([[1, 9], [10, 4]])
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_mismatched_updates_shape():
|
||||||
|
data = np.array([1, 2, 3])
|
||||||
|
indices = np.array([[0], [1]])
|
||||||
|
updates = np.array([4])
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
ov.scatter_nd_update(data, indices, updates)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_non_integer_indices():
|
||||||
|
data = np.array([1, 2, 3])
|
||||||
|
indices = np.array([[0.5]])
|
||||||
|
updates = np.array([4])
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
ov.scatter_nd_update(data, indices, updates)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_negative_indices():
|
||||||
|
data = np.array([1, 2, 3, 4])
|
||||||
|
indices = np.array([[-1]])
|
||||||
|
updates = np.array([5])
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([1, 2, 3, 5])
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_multi_index_per_update():
|
||||||
|
data = np.array([[1, 2], [3, 4]])
|
||||||
|
indices = np.array([[0, 0], [0, 1]])
|
||||||
|
updates = np.array([5, 6])
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([[5, 6], [3, 4]])
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_non_contiguous_indices():
|
||||||
|
data = np.array([10, 20, 30, 40, 50])
|
||||||
|
indices = np.array([[0], [3]])
|
||||||
|
updates = np.array([100, 400])
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([100, 20, 30, 400, 50])
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_large_updates():
|
||||||
|
data = np.zeros(1000, dtype=np.float64)
|
||||||
|
indices = np.reshape(np.arange(1000), (-1, 1))
|
||||||
|
updates = np.arange(1000, dtype=np.float64)
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.arange(1000, dtype=np.float64)
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_overlapping_indices():
|
||||||
|
data = np.array([1, 2, 3, 4, 5])
|
||||||
|
indices = np.array([[1], [1], [3]])
|
||||||
|
updates = np.array([10, 20, 30])
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([1, 20, 3, 30, 5])
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_3d_data():
|
||||||
|
data = np.zeros((2, 2, 2), dtype=np.float64)
|
||||||
|
indices = np.array([[0, 0, 1], [1, 1, 0]])
|
||||||
|
updates = np.array([1, 2], dtype=np.float64)
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([[[0, 1], [0, 0]], [[0, 0], [2, 0]]], dtype=np.float64)
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_all_indices():
|
||||||
|
data = np.ones((2, 3), dtype=np.float64)
|
||||||
|
indices = np.array([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]])
|
||||||
|
updates = np.array([10, 20, 30, 40, 50, 60], dtype=np.float64)
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.float64)
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_invalid_updates_shape():
|
||||||
|
data = np.array([1, 2, 3, 4])
|
||||||
|
indices = np.array([[1], [2]])
|
||||||
|
updates = np.array([5])
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
ov.scatter_nd_update(data, indices, updates)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_negative_updates():
|
||||||
|
data = np.array([1, 2, 3, 4, 5])
|
||||||
|
indices = np.array([[1], [3]])
|
||||||
|
updates = np.array([-1, -2])
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([1, -1, 3, -2, 5])
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update_empty_indices_and_updates():
|
||||||
|
data = np.array([1, 2, 3], dtype=np.float64)
|
||||||
|
indices = np.array([], dtype=np.int64).reshape(0, 1)
|
||||||
|
updates = np.array([], dtype=np.float64)
|
||||||
|
|
||||||
|
result = ov.scatter_nd_update(data, indices, updates)
|
||||||
|
expected = np.array([1, 2, 3], dtype=np.float64)
|
||||||
|
np.testing.assert_array_equal(result, expected)
|
Loading…
Reference in New Issue
Block a user