Roll nGraph Python API (#5237)

* Added Roll to nGraph Python API.

* Added empty line at the end of file.

* Corrected a typo.

* Added Roll test.

* Reformat code.

* Removed empty line.
This commit is contained in:
Anastasia Popova 2021-04-15 07:19:15 +03:00 committed by GitHub
parent a70d13f9e2
commit f681907fdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 0 deletions

View File

@ -135,6 +135,7 @@ from ngraph.opset7 import rnn_cell
from ngraph.opset7 import rnn_sequence
from ngraph.opset7 import roi_align
from ngraph.opset7 import roi_pooling
from ngraph.opset7 import roll
from ngraph.opset7 import round
from ngraph.opset7 import scatter_elements_update
from ngraph.opset7 import scatter_update

View File

@ -120,6 +120,7 @@ from ngraph.opset3.ops import rnn_cell
from ngraph.opset5.ops import rnn_sequence
from ngraph.opset3.ops import roi_align
from ngraph.opset2.ops import roi_pooling
from ngraph.opset7.ops import roll
from ngraph.opset5.ops import round
from ngraph.opset3.ops import scatter_elements_update
from ngraph.opset3.ops import scatter_update

View File

@ -65,3 +65,21 @@ def gelu(
}
return _get_node_factory_opset7().create("Gelu", inputs, attributes)
@nameable_op
def roll(
data: NodeInput,
shift: NodeInput,
axes: NodeInput,
) -> Node:
"""Return a node which performs Roll operation.
@param data: The node with data tensor.
@param shift: The node with the tensor with numbers of places by which elements are shifted.
@param axes: The node with the tensor with axes along which elements are shifted.
@return The new node performing a Roll operation on the input tensor.
"""
inputs = as_nodes(data, shift, axes)
return _get_node_factory_opset7().create("Roll", inputs)

View File

@ -163,3 +163,5 @@ xfail_issue_49753 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1
xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE")
xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadcast_cpu - "
"Not equal to tolerance")
xfail_issue_49391 = xfail_test(reason="Roll is not implemented in CPU plugin.")

View File

@ -0,0 +1,20 @@
import ngraph as ng
import numpy as np
from tests import xfail_issue_49391
from tests.runtime import get_runtime
@xfail_issue_49391
def test_roll():
runtime = get_runtime()
input = np.reshape(np.arange(10), (2, 5))
input_tensor = ng.constant(input)
input_shift = ng.constant(np.array([-10, 7], dtype=np.int32))
input_axes = ng.constant(np.array([-1, 0], dtype=np.int32))
roll_node = ng.roll(input_tensor, input_shift, input_axes)
computation = runtime.computation(roll_node)
roll_results = computation()
expected_results = np.roll(input, shift=(-10, 7), axis=(-1, 0))
assert np.allclose(roll_results, expected_results)