Roll nGraph Python API (#5237)
* Added Roll to nGraph Python API. * Added empty line at the end of file. * Corrected a typo. * Added Roll test. * Reformat code. * Removed empty line.
This commit is contained in:
parent
a70d13f9e2
commit
f681907fdd
@ -135,6 +135,7 @@ from ngraph.opset7 import rnn_cell
|
||||
from ngraph.opset7 import rnn_sequence
|
||||
from ngraph.opset7 import roi_align
|
||||
from ngraph.opset7 import roi_pooling
|
||||
from ngraph.opset7 import roll
|
||||
from ngraph.opset7 import round
|
||||
from ngraph.opset7 import scatter_elements_update
|
||||
from ngraph.opset7 import scatter_update
|
||||
|
@ -120,6 +120,7 @@ from ngraph.opset3.ops import rnn_cell
|
||||
from ngraph.opset5.ops import rnn_sequence
|
||||
from ngraph.opset3.ops import roi_align
|
||||
from ngraph.opset2.ops import roi_pooling
|
||||
from ngraph.opset7.ops import roll
|
||||
from ngraph.opset5.ops import round
|
||||
from ngraph.opset3.ops import scatter_elements_update
|
||||
from ngraph.opset3.ops import scatter_update
|
||||
|
@ -65,3 +65,21 @@ def gelu(
|
||||
}
|
||||
|
||||
return _get_node_factory_opset7().create("Gelu", inputs, attributes)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def roll(
|
||||
data: NodeInput,
|
||||
shift: NodeInput,
|
||||
axes: NodeInput,
|
||||
) -> Node:
|
||||
"""Return a node which performs Roll operation.
|
||||
|
||||
@param data: The node with data tensor.
|
||||
@param shift: The node with the tensor with numbers of places by which elements are shifted.
|
||||
@param axes: The node with the tensor with axes along which elements are shifted.
|
||||
@return The new node performing a Roll operation on the input tensor.
|
||||
"""
|
||||
inputs = as_nodes(data, shift, axes)
|
||||
|
||||
return _get_node_factory_opset7().create("Roll", inputs)
|
||||
|
@ -163,3 +163,5 @@ xfail_issue_49753 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1
|
||||
xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE")
|
||||
xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadcast_cpu - "
|
||||
"Not equal to tolerance")
|
||||
|
||||
xfail_issue_49391 = xfail_test(reason="Roll is not implemented in CPU plugin.")
|
||||
|
20
ngraph/python/tests/test_ngraph/test_roll.py
Normal file
20
ngraph/python/tests/test_ngraph/test_roll.py
Normal file
@ -0,0 +1,20 @@
|
||||
import ngraph as ng
|
||||
import numpy as np
|
||||
from tests import xfail_issue_49391
|
||||
from tests.runtime import get_runtime
|
||||
|
||||
|
||||
@xfail_issue_49391
|
||||
def test_roll():
|
||||
runtime = get_runtime()
|
||||
input = np.reshape(np.arange(10), (2, 5))
|
||||
input_tensor = ng.constant(input)
|
||||
input_shift = ng.constant(np.array([-10, 7], dtype=np.int32))
|
||||
input_axes = ng.constant(np.array([-1, 0], dtype=np.int32))
|
||||
|
||||
roll_node = ng.roll(input_tensor, input_shift, input_axes)
|
||||
computation = runtime.computation(roll_node)
|
||||
roll_results = computation()
|
||||
expected_results = np.roll(input, shift=(-10, 7), axis=(-1, 0))
|
||||
|
||||
assert np.allclose(roll_results, expected_results)
|
Loading…
Reference in New Issue
Block a user