Extend openvino (ngraph) Python API for operation Slice-8 (#7965)
This commit is contained in:
parent
d39fe50470
commit
e0062fc274
@ -140,6 +140,7 @@ from openvino.opset1.ops import sigmoid
|
||||
from openvino.opset1.ops import sign
|
||||
from openvino.opset1.ops import sin
|
||||
from openvino.opset1.ops import sinh
|
||||
from openvino.opset8.ops import slice
|
||||
from openvino.opset1.ops import softmax
|
||||
from openvino.opset4.ops import softplus
|
||||
from openvino.opset2.ops import space_to_batch
|
||||
|
@ -367,3 +367,27 @@ def random_uniform(
|
||||
"op_seed": op_seed,
|
||||
}
|
||||
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def slice(
|
||||
data: NodeInput,
|
||||
start: NodeInput,
|
||||
stop: NodeInput,
|
||||
step: NodeInput,
|
||||
axes: NodeInput = None
|
||||
) -> Node:
|
||||
"""Return a node which generates Slice operation.
|
||||
|
||||
@param data: The node providing input data.
|
||||
@param start: The node providing start indices (inclusively).
|
||||
@param stop: The node providing stop indices (exclusively).
|
||||
@param step: The node providing step values.
|
||||
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
|
||||
"""
|
||||
if axes is None:
|
||||
inputs = as_nodes(data, start, stop, step)
|
||||
else:
|
||||
inputs = as_nodes(data, start, stop, step, axes)
|
||||
|
||||
return _get_node_factory_opset8().create("Slice", inputs)
|
||||
|
@ -1923,3 +1923,31 @@ def test_matrix_nms():
|
||||
assert nms_node.get_output_element_type(0) == Type.f32
|
||||
assert nms_node.get_output_element_type(1) == Type.i32
|
||||
assert nms_node.get_output_element_type(2) == Type.i32
|
||||
|
||||
|
||||
def test_slice():
|
||||
data_shape = [10, 7, 2, 13]
|
||||
data = ov.parameter(data_shape, name="input", dtype=np.float32)
|
||||
|
||||
start = ov.constant(np.array([2, 0, 0], dtype=np.int32))
|
||||
stop = ov.constant(np.array([9, 7, 2], dtype=np.int32))
|
||||
step = ov.constant(np.array([2, 1, 1], dtype=np.int32))
|
||||
|
||||
node_default_axes = ov.slice(data, start, stop, step)
|
||||
|
||||
assert node_default_axes.get_type_name() == "Slice"
|
||||
assert node_default_axes.get_output_size() == 1
|
||||
assert node_default_axes.get_output_element_type(0) == Type.f32
|
||||
assert tuple(node_default_axes.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape
|
||||
|
||||
start = ov.constant(np.array([0, 2], dtype=np.int32))
|
||||
stop = ov.constant(np.array([2, 9], dtype=np.int32))
|
||||
step = ov.constant(np.array([1, 2], dtype=np.int32))
|
||||
axes = ov.constant(np.array([-2, 0], dtype=np.int32))
|
||||
|
||||
node = ov.slice(data, start, stop, step, axes)
|
||||
|
||||
assert node.get_type_name() == "Slice"
|
||||
assert node.get_output_size() == 1
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
assert tuple(node.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape
|
||||
|
Loading…
Reference in New Issue
Block a user