[Python API] Add missing opest8 ops to compatibility python API (#8659)
This commit is contained in:
parent
7e457bfd72
commit
83991607c3
@ -166,6 +166,7 @@ from ngraph.opset8 import sigmoid
|
||||
from ngraph.opset8 import sign
|
||||
from ngraph.opset8 import sin
|
||||
from ngraph.opset8 import sinh
|
||||
from ngraph.opset8 import slice
|
||||
from ngraph.opset8 import softmax
|
||||
from ngraph.opset8 import softplus
|
||||
from ngraph.opset8 import space_to_batch
|
||||
|
@ -55,7 +55,7 @@ from ngraph.opset1.ops import floor
|
||||
from ngraph.opset1.ops import floor_mod
|
||||
from ngraph.opset8.ops import gather
|
||||
from ngraph.opset6.ops import gather_elements
|
||||
from ngraph.opset5.ops import gather_nd
|
||||
from ngraph.opset8.ops import gather_nd
|
||||
from ngraph.opset1.ops import gather_tree
|
||||
from ngraph.opset7.ops import gelu
|
||||
from ngraph.opset1.ops import greater
|
||||
@ -140,6 +140,7 @@ from ngraph.opset1.ops import sigmoid
|
||||
from ngraph.opset1.ops import sign
|
||||
from ngraph.opset1.ops import sin
|
||||
from ngraph.opset1.ops import sinh
|
||||
from ngraph.opset8.ops import slice
|
||||
from ngraph.opset1.ops import softmax
|
||||
from ngraph.opset4.ops import softplus
|
||||
from ngraph.opset2.ops import space_to_batch
|
||||
|
@ -367,3 +367,53 @@ def random_uniform(
|
||||
"op_seed": op_seed,
|
||||
}
|
||||
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def slice(
|
||||
data: NodeInput,
|
||||
start: NodeInput,
|
||||
stop: NodeInput,
|
||||
step: NodeInput,
|
||||
axes: Optional[NodeInput] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a node which generates Slice operation.
|
||||
|
||||
@param data: The node providing input data.
|
||||
@param start: The node providing start indices (inclusively).
|
||||
@param stop: The node providing stop indices (exclusively).
|
||||
@param step: The node providing step values.
|
||||
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
|
||||
@param name: The optional name for the created output node.
|
||||
@return The new node performing Slice operation.
|
||||
"""
|
||||
if axes is None:
|
||||
inputs = as_nodes(data, start, stop, step)
|
||||
else:
|
||||
inputs = as_nodes(data, start, stop, step, axes)
|
||||
|
||||
return _get_node_factory_opset8().create("Slice", inputs)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def gather_nd(
|
||||
data: NodeInput,
|
||||
indices: NodeInput,
|
||||
batch_dims: Optional[int] = 0,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a node which performs GatherND.
|
||||
|
||||
@param data: N-D tensor with data for gathering
|
||||
@param indices: K-D tensor of tuples with indices by which data is gathered
|
||||
@param batch_dims: Scalar value of batch dimensions
|
||||
@return: The new node which performs GatherND
|
||||
"""
|
||||
inputs = as_nodes(data, indices)
|
||||
|
||||
attributes = {
|
||||
"batch_dims": batch_dims
|
||||
}
|
||||
|
||||
return _get_node_factory_opset8().create("GatherND", inputs, attributes)
|
||||
|
@ -375,7 +375,8 @@ def slice(
|
||||
start: NodeInput,
|
||||
stop: NodeInput,
|
||||
step: NodeInput,
|
||||
axes: NodeInput = None
|
||||
axes: Optional[NodeInput] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a node which generates Slice operation.
|
||||
|
||||
@ -384,6 +385,8 @@ def slice(
|
||||
@param stop: The node providing stop indices (exclusively).
|
||||
@param step: The node providing step values.
|
||||
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
|
||||
@param name: The optional name for the created output node.
|
||||
@return The new node performing Slice operation.
|
||||
"""
|
||||
if axes is None:
|
||||
inputs = as_nodes(data, start, stop, step)
|
||||
|
@ -1923,3 +1923,31 @@ def test_matrix_nms():
|
||||
assert nms_node.get_output_element_type(0) == Type.f32
|
||||
assert nms_node.get_output_element_type(1) == Type.i32
|
||||
assert nms_node.get_output_element_type(2) == Type.i32
|
||||
|
||||
|
||||
def test_slice():
|
||||
data_shape = [10, 7, 2, 13]
|
||||
data = ng.parameter(data_shape, name="input", dtype=np.float32)
|
||||
|
||||
start = ng.constant(np.array([2, 0, 0], dtype=np.int32))
|
||||
stop = ng.constant(np.array([9, 7, 2], dtype=np.int32))
|
||||
step = ng.constant(np.array([2, 1, 1], dtype=np.int32))
|
||||
|
||||
node_default_axes = ng.slice(data, start, stop, step)
|
||||
|
||||
assert node_default_axes.get_type_name() == "Slice"
|
||||
assert node_default_axes.get_output_size() == 1
|
||||
assert node_default_axes.get_output_element_type(0) == Type.f32
|
||||
assert tuple(node_default_axes.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape
|
||||
|
||||
start = ng.constant(np.array([0, 2], dtype=np.int32))
|
||||
stop = ng.constant(np.array([2, 9], dtype=np.int32))
|
||||
step = ng.constant(np.array([1, 2], dtype=np.int32))
|
||||
axes = ng.constant(np.array([-2, 0], dtype=np.int32))
|
||||
|
||||
node = ng.slice(data, start, stop, step, axes)
|
||||
|
||||
assert node.get_type_name() == "Slice"
|
||||
assert node.get_output_size() == 1
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
assert tuple(node.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape
|
||||
|
@ -196,6 +196,21 @@ def test_gather_nd():
|
||||
batch_dims = 2
|
||||
expected_shape = [20, 30, 40, 50]
|
||||
|
||||
node = ng.opset5.gather_nd(data, indices, batch_dims)
|
||||
assert node.get_type_name() == "GatherND"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == expected_shape
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
|
||||
|
||||
def test_gather_v8_nd():
|
||||
indices_type = np.int32
|
||||
data_dtype = np.float32
|
||||
data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data")
|
||||
indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices")
|
||||
batch_dims = 2
|
||||
expected_shape = [2, 10, 30, 40, 50]
|
||||
|
||||
node = ng.gather_nd(data, indices, batch_dims)
|
||||
assert node.get_type_name() == "GatherND"
|
||||
assert node.get_output_size() == 1
|
||||
|
Loading…
Reference in New Issue
Block a user