diff --git a/ngraph/python/src/ngraph/__init__.py b/ngraph/python/src/ngraph/__init__.py index c9343497b24..441392d4f0a 100644 --- a/ngraph/python/src/ngraph/__init__.py +++ b/ngraph/python/src/ngraph/__init__.py @@ -135,6 +135,7 @@ from ngraph.opset7 import rnn_cell from ngraph.opset7 import rnn_sequence from ngraph.opset7 import roi_align from ngraph.opset7 import roi_pooling +from ngraph.opset7 import roll from ngraph.opset7 import round from ngraph.opset7 import scatter_elements_update from ngraph.opset7 import scatter_update diff --git a/ngraph/python/src/ngraph/opset7/__init__.py b/ngraph/python/src/ngraph/opset7/__init__.py index c665b1e68ef..a7d12fb6f02 100644 --- a/ngraph/python/src/ngraph/opset7/__init__.py +++ b/ngraph/python/src/ngraph/opset7/__init__.py @@ -120,6 +120,7 @@ from ngraph.opset3.ops import rnn_cell from ngraph.opset5.ops import rnn_sequence from ngraph.opset3.ops import roi_align from ngraph.opset2.ops import roi_pooling +from ngraph.opset7.ops import roll from ngraph.opset5.ops import round from ngraph.opset3.ops import scatter_elements_update from ngraph.opset3.ops import scatter_update diff --git a/ngraph/python/src/ngraph/opset7/ops.py b/ngraph/python/src/ngraph/opset7/ops.py index ef2dfc9af0b..dee2c5d3192 100644 --- a/ngraph/python/src/ngraph/opset7/ops.py +++ b/ngraph/python/src/ngraph/opset7/ops.py @@ -65,3 +65,21 @@ def gelu( } return _get_node_factory_opset7().create("Gelu", inputs, attributes) + + +@nameable_op +def roll( + data: NodeInput, + shift: NodeInput, + axes: NodeInput, +) -> Node: + """Return a node which performs Roll operation. + + @param data: The node with data tensor. + @param shift: The node with the tensor with numbers of places by which elements are shifted. + @param axes: The node with the tensor with axes along which elements are shifted. + @return The new node performing a Roll operation on the input tensor. + """ + inputs = as_nodes(data, shift, axes) + + return _get_node_factory_opset7().create("Roll", inputs) diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py index 65b7040f679..1ccfdb7cf61 100644 --- a/ngraph/python/tests/__init__.py +++ b/ngraph/python/tests/__init__.py @@ -163,3 +163,5 @@ xfail_issue_49753 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1 xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE") xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadcast_cpu - " "Not equal to tolerance") + +xfail_issue_49391 = xfail_test(reason="Roll is not implemented in CPU plugin.") diff --git a/ngraph/python/tests/test_ngraph/test_roll.py b/ngraph/python/tests/test_ngraph/test_roll.py new file mode 100644 index 00000000000..07426df0816 --- /dev/null +++ b/ngraph/python/tests/test_ngraph/test_roll.py @@ -0,0 +1,20 @@ +import ngraph as ng +import numpy as np +from tests import xfail_issue_49391 +from tests.runtime import get_runtime + + +@xfail_issue_49391 +def test_roll(): + runtime = get_runtime() + input = np.reshape(np.arange(10), (2, 5)) + input_tensor = ng.constant(input) + input_shift = ng.constant(np.array([-10, 7], dtype=np.int32)) + input_axes = ng.constant(np.array([-1, 0], dtype=np.int32)) + + roll_node = ng.roll(input_tensor, input_shift, input_axes) + computation = runtime.computation(roll_node) + roll_results = computation() + expected_results = np.roll(input, shift=(-10, 7), axis=(-1, 0)) + + assert np.allclose(roll_results, expected_results)