[Python API] Add missing opest8 ops to compatibility python API (#8659)

This commit is contained in:
Katarzyna Mitrus 2021-11-19 17:32:03 +01:00 committed by GitHub
parent 7e457bfd72
commit 83991607c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 100 additions and 2 deletions

View File

@ -166,6 +166,7 @@ from ngraph.opset8 import sigmoid
from ngraph.opset8 import sign
from ngraph.opset8 import sin
from ngraph.opset8 import sinh
from ngraph.opset8 import slice
from ngraph.opset8 import softmax
from ngraph.opset8 import softplus
from ngraph.opset8 import space_to_batch

View File

@ -55,7 +55,7 @@ from ngraph.opset1.ops import floor
from ngraph.opset1.ops import floor_mod
from ngraph.opset8.ops import gather
from ngraph.opset6.ops import gather_elements
from ngraph.opset5.ops import gather_nd
from ngraph.opset8.ops import gather_nd
from ngraph.opset1.ops import gather_tree
from ngraph.opset7.ops import gelu
from ngraph.opset1.ops import greater
@ -140,6 +140,7 @@ from ngraph.opset1.ops import sigmoid
from ngraph.opset1.ops import sign
from ngraph.opset1.ops import sin
from ngraph.opset1.ops import sinh
from ngraph.opset8.ops import slice
from ngraph.opset1.ops import softmax
from ngraph.opset4.ops import softplus
from ngraph.opset2.ops import space_to_batch

View File

@ -367,3 +367,53 @@ def random_uniform(
"op_seed": op_seed,
}
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
@nameable_op
def slice(
data: NodeInput,
start: NodeInput,
stop: NodeInput,
step: NodeInput,
axes: Optional[NodeInput] = None,
name: Optional[str] = None,
) -> Node:
"""Return a node which generates Slice operation.
@param data: The node providing input data.
@param start: The node providing start indices (inclusively).
@param stop: The node providing stop indices (exclusively).
@param step: The node providing step values.
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
@param name: The optional name for the created output node.
@return The new node performing Slice operation.
"""
if axes is None:
inputs = as_nodes(data, start, stop, step)
else:
inputs = as_nodes(data, start, stop, step, axes)
return _get_node_factory_opset8().create("Slice", inputs)
@nameable_op
def gather_nd(
data: NodeInput,
indices: NodeInput,
batch_dims: Optional[int] = 0,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs GatherND.
@param data: N-D tensor with data for gathering
@param indices: K-D tensor of tuples with indices by which data is gathered
@param batch_dims: Scalar value of batch dimensions
@return: The new node which performs GatherND
"""
inputs = as_nodes(data, indices)
attributes = {
"batch_dims": batch_dims
}
return _get_node_factory_opset8().create("GatherND", inputs, attributes)

View File

@ -375,7 +375,8 @@ def slice(
start: NodeInput,
stop: NodeInput,
step: NodeInput,
axes: NodeInput = None
axes: Optional[NodeInput] = None,
name: Optional[str] = None,
) -> Node:
"""Return a node which generates Slice operation.
@ -384,6 +385,8 @@ def slice(
@param stop: The node providing stop indices (exclusively).
@param step: The node providing step values.
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
@param name: The optional name for the created output node.
@return The new node performing Slice operation.
"""
if axes is None:
inputs = as_nodes(data, start, stop, step)

View File

@ -1923,3 +1923,31 @@ def test_matrix_nms():
assert nms_node.get_output_element_type(0) == Type.f32
assert nms_node.get_output_element_type(1) == Type.i32
assert nms_node.get_output_element_type(2) == Type.i32
def test_slice():
data_shape = [10, 7, 2, 13]
data = ng.parameter(data_shape, name="input", dtype=np.float32)
start = ng.constant(np.array([2, 0, 0], dtype=np.int32))
stop = ng.constant(np.array([9, 7, 2], dtype=np.int32))
step = ng.constant(np.array([2, 1, 1], dtype=np.int32))
node_default_axes = ng.slice(data, start, stop, step)
assert node_default_axes.get_type_name() == "Slice"
assert node_default_axes.get_output_size() == 1
assert node_default_axes.get_output_element_type(0) == Type.f32
assert tuple(node_default_axes.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape
start = ng.constant(np.array([0, 2], dtype=np.int32))
stop = ng.constant(np.array([2, 9], dtype=np.int32))
step = ng.constant(np.array([1, 2], dtype=np.int32))
axes = ng.constant(np.array([-2, 0], dtype=np.int32))
node = ng.slice(data, start, stop, step, axes)
assert node.get_type_name() == "Slice"
assert node.get_output_size() == 1
assert node.get_output_element_type(0) == Type.f32
assert tuple(node.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape

View File

@ -196,6 +196,21 @@ def test_gather_nd():
batch_dims = 2
expected_shape = [20, 30, 40, 50]
node = ng.opset5.gather_nd(data, indices, batch_dims)
assert node.get_type_name() == "GatherND"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == expected_shape
assert node.get_output_element_type(0) == Type.f32
def test_gather_v8_nd():
indices_type = np.int32
data_dtype = np.float32
data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data")
indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices")
batch_dims = 2
expected_shape = [2, 10, 30, 40, 50]
node = ng.gather_nd(data, indices, batch_dims)
assert node.get_type_name() == "GatherND"
assert node.get_output_size() == 1