Extend openvino (ngraph) Python API for operation Slice-8 (#7965)

This commit is contained in:
Katarzyna Mitrus 2021-10-21 10:22:51 +02:00 committed by GitHub
parent d39fe50470
commit e0062fc274
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 0 deletions

View File

@ -140,6 +140,7 @@ from openvino.opset1.ops import sigmoid
from openvino.opset1.ops import sign
from openvino.opset1.ops import sin
from openvino.opset1.ops import sinh
from openvino.opset8.ops import slice
from openvino.opset1.ops import softmax
from openvino.opset4.ops import softplus
from openvino.opset2.ops import space_to_batch

View File

@ -367,3 +367,27 @@ def random_uniform(
"op_seed": op_seed,
}
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
@nameable_op
def slice(
data: NodeInput,
start: NodeInput,
stop: NodeInput,
step: NodeInput,
axes: NodeInput = None
) -> Node:
"""Return a node which generates Slice operation.
@param data: The node providing input data.
@param start: The node providing start indices (inclusively).
@param stop: The node providing stop indices (exclusively).
@param step: The node providing step values.
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
"""
if axes is None:
inputs = as_nodes(data, start, stop, step)
else:
inputs = as_nodes(data, start, stop, step, axes)
return _get_node_factory_opset8().create("Slice", inputs)

View File

@ -1923,3 +1923,31 @@ def test_matrix_nms():
assert nms_node.get_output_element_type(0) == Type.f32
assert nms_node.get_output_element_type(1) == Type.i32
assert nms_node.get_output_element_type(2) == Type.i32
def test_slice():
data_shape = [10, 7, 2, 13]
data = ov.parameter(data_shape, name="input", dtype=np.float32)
start = ov.constant(np.array([2, 0, 0], dtype=np.int32))
stop = ov.constant(np.array([9, 7, 2], dtype=np.int32))
step = ov.constant(np.array([2, 1, 1], dtype=np.int32))
node_default_axes = ov.slice(data, start, stop, step)
assert node_default_axes.get_type_name() == "Slice"
assert node_default_axes.get_output_size() == 1
assert node_default_axes.get_output_element_type(0) == Type.f32
assert tuple(node_default_axes.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape
start = ov.constant(np.array([0, 2], dtype=np.int32))
stop = ov.constant(np.array([2, 9], dtype=np.int32))
step = ov.constant(np.array([1, 2], dtype=np.int32))
axes = ov.constant(np.array([-2, 0], dtype=np.int32))
node = ov.slice(data, start, stop, step, axes)
assert node.get_type_name() == "Slice"
assert node.get_output_size() == 1
assert node.get_output_element_type(0) == Type.f32
assert tuple(node.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape