TopK 11 exposed to Python (#16501)

This commit is contained in:
Tomasz Dołbniak
2023-03-23 16:33:54 +01:00
committed by GitHub
parent 44d6d97871
commit de0a4e16fb
6 changed files with 89 additions and 3 deletions

View File

@@ -170,7 +170,7 @@ from ngraph.opset1.ops import tan
from ngraph.opset1.ops import tanh
from ngraph.opset1.ops import tensor_iterator
from ngraph.opset1.ops import tile
from ngraph.opset3.ops import topk
from ngraph.opset11.ops import topk
from ngraph.opset1.ops import transpose
from ngraph.opset10.ops import unique
from ngraph.opset1.ops import unsqueeze

View File

@@ -34,7 +34,7 @@ def interpolate(
axes: Optional[NodeInput] = None,
name: Optional[str] = None,
) -> Node:
"""Perfors the interpolation of the input tensor.
"""Performs the interpolation of the input tensor.
:param image: The node providing input tensor with data for interpolation.
:param scales_or_sizes:
@@ -75,3 +75,33 @@ def interpolate(
inputs = as_nodes(image, scales_or_sizes) if axes is None else as_nodes(image, scales_or_sizes, axes)
return _get_node_factory_opset11().create("Interpolate", inputs, attrs)
@nameable_op
def topk(
data: NodeInput,
k: NodeInput,
axis: int,
mode: str,
sort: str,
index_element_type: str = "i32",
stable: bool = False,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs TopK.
:param data: Input data.
:param k: K.
:param axis: TopK Axis.
:param mode: Compute TopK largest ('max') or smallest ('min')
:param sort: Order of output elements (sort by: 'none', 'index' or 'value')
:param index_element_type: Type of output tensor with indices.
:param stable: Specifies whether the equivalent elements should maintain
their relative order from the input tensor during sorting.
:return: The new node which performs TopK
"""
return _get_node_factory_opset11().create(
"TopK",
as_nodes(data, k),
{"axis": axis, "mode": mode, "sort": sort, "index_element_type": index_element_type, "stable": stable},
)

View File

@@ -171,7 +171,7 @@ from openvino.runtime.opset1.ops import tan
from openvino.runtime.opset1.ops import tanh
from openvino.runtime.opset1.ops import tensor_iterator
from openvino.runtime.opset1.ops import tile
from openvino.runtime.opset3.ops import topk
from openvino.runtime.opset11.ops import topk
from openvino.runtime.opset1.ops import transpose
from openvino.runtime.opset10.ops import unique
from openvino.runtime.opset1.ops import unsqueeze

View File

@@ -75,3 +75,33 @@ def interpolate(
inputs = as_nodes(image, scales_or_sizes) if axes is None else as_nodes(image, scales_or_sizes, axes)
return _get_node_factory_opset11().create("Interpolate", inputs, attrs)
@nameable_op
def topk(
data: NodeInput,
k: NodeInput,
axis: int,
mode: str,
sort: str,
index_element_type: str = "i32",
stable: bool = False,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs TopK.
:param data: Input data.
:param k: K.
:param axis: TopK Axis.
:param mode: Compute TopK largest ('max') or smallest ('min')
:param sort: Order of output elements (sort by: 'none', 'index' or 'value')
:param index_element_type: Type of output tensor with indices.
:param stable: Specifies whether the equivalent elements should maintain
their relative order from the input tensor during sorting.
:return: The new node which performs TopK
"""
return _get_node_factory_opset11().create(
"TopK",
as_nodes(data, k),
{"axis": axis, "mode": mode, "sort": sort, "index_element_type": index_element_type, "stable": stable},
)

View File

@@ -2300,3 +2300,16 @@ def test_unique_opset10():
assert node.get_output_element_type(1) == Type.i64
assert node.get_output_element_type(2) == Type.i64
assert node.get_output_element_type(3) == Type.i64
def test_topk_opset11():
data_shape = [1, 3, 256]
data = ov.parameter(data_shape, dtype=np.int32, name="Data")
k_val = np.int32(3)
axis = np.int32(-1)
node = ov.topk(data, k_val, axis, "min", "value", stable=True)
assert node.get_type_name() == "TopK"
assert node.get_output_size() == 2
assert list(node.get_output_shape(0)) == [1, 3, 3]
assert list(node.get_output_shape(1)) == [1, 3, 3]

View File

@@ -2412,3 +2412,16 @@ def test_unique_opset10():
assert node.get_output_element_type(1) == Type.i64
assert node.get_output_element_type(2) == Type.i64
assert node.get_output_element_type(3) == Type.i64
def test_topk_opset11():
data_shape = [1, 3, 256]
data = ng.parameter(data_shape, dtype=np.int32, name="Data")
k_val = np.int32(3)
axis = np.int32(-1)
node = ng_opset11.topk(data, k_val, axis, "min", "value", stable=True)
assert node.get_type_name() == "TopK"
assert node.get_output_size() == 2
assert list(node.get_output_shape(0)) == [1, 3, 3]
assert list(node.get_output_shape(1)) == [1, 3, 3]