TopK 11 exposed to Python (#16501)
This commit is contained in:
@@ -170,7 +170,7 @@ from ngraph.opset1.ops import tan
|
||||
from ngraph.opset1.ops import tanh
|
||||
from ngraph.opset1.ops import tensor_iterator
|
||||
from ngraph.opset1.ops import tile
|
||||
from ngraph.opset3.ops import topk
|
||||
from ngraph.opset11.ops import topk
|
||||
from ngraph.opset1.ops import transpose
|
||||
from ngraph.opset10.ops import unique
|
||||
from ngraph.opset1.ops import unsqueeze
|
||||
|
||||
@@ -34,7 +34,7 @@ def interpolate(
|
||||
axes: Optional[NodeInput] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Perfors the interpolation of the input tensor.
|
||||
"""Performs the interpolation of the input tensor.
|
||||
|
||||
:param image: The node providing input tensor with data for interpolation.
|
||||
:param scales_or_sizes:
|
||||
@@ -75,3 +75,33 @@ def interpolate(
|
||||
inputs = as_nodes(image, scales_or_sizes) if axes is None else as_nodes(image, scales_or_sizes, axes)
|
||||
|
||||
return _get_node_factory_opset11().create("Interpolate", inputs, attrs)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def topk(
|
||||
data: NodeInput,
|
||||
k: NodeInput,
|
||||
axis: int,
|
||||
mode: str,
|
||||
sort: str,
|
||||
index_element_type: str = "i32",
|
||||
stable: bool = False,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a node which performs TopK.
|
||||
|
||||
:param data: Input data.
|
||||
:param k: K.
|
||||
:param axis: TopK Axis.
|
||||
:param mode: Compute TopK largest ('max') or smallest ('min')
|
||||
:param sort: Order of output elements (sort by: 'none', 'index' or 'value')
|
||||
:param index_element_type: Type of output tensor with indices.
|
||||
:param stable: Specifies whether the equivalent elements should maintain
|
||||
their relative order from the input tensor during sorting.
|
||||
:return: The new node which performs TopK
|
||||
"""
|
||||
return _get_node_factory_opset11().create(
|
||||
"TopK",
|
||||
as_nodes(data, k),
|
||||
{"axis": axis, "mode": mode, "sort": sort, "index_element_type": index_element_type, "stable": stable},
|
||||
)
|
||||
|
||||
@@ -171,7 +171,7 @@ from openvino.runtime.opset1.ops import tan
|
||||
from openvino.runtime.opset1.ops import tanh
|
||||
from openvino.runtime.opset1.ops import tensor_iterator
|
||||
from openvino.runtime.opset1.ops import tile
|
||||
from openvino.runtime.opset3.ops import topk
|
||||
from openvino.runtime.opset11.ops import topk
|
||||
from openvino.runtime.opset1.ops import transpose
|
||||
from openvino.runtime.opset10.ops import unique
|
||||
from openvino.runtime.opset1.ops import unsqueeze
|
||||
|
||||
@@ -75,3 +75,33 @@ def interpolate(
|
||||
inputs = as_nodes(image, scales_or_sizes) if axes is None else as_nodes(image, scales_or_sizes, axes)
|
||||
|
||||
return _get_node_factory_opset11().create("Interpolate", inputs, attrs)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def topk(
|
||||
data: NodeInput,
|
||||
k: NodeInput,
|
||||
axis: int,
|
||||
mode: str,
|
||||
sort: str,
|
||||
index_element_type: str = "i32",
|
||||
stable: bool = False,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a node which performs TopK.
|
||||
|
||||
:param data: Input data.
|
||||
:param k: K.
|
||||
:param axis: TopK Axis.
|
||||
:param mode: Compute TopK largest ('max') or smallest ('min')
|
||||
:param sort: Order of output elements (sort by: 'none', 'index' or 'value')
|
||||
:param index_element_type: Type of output tensor with indices.
|
||||
:param stable: Specifies whether the equivalent elements should maintain
|
||||
their relative order from the input tensor during sorting.
|
||||
:return: The new node which performs TopK
|
||||
"""
|
||||
return _get_node_factory_opset11().create(
|
||||
"TopK",
|
||||
as_nodes(data, k),
|
||||
{"axis": axis, "mode": mode, "sort": sort, "index_element_type": index_element_type, "stable": stable},
|
||||
)
|
||||
|
||||
@@ -2300,3 +2300,16 @@ def test_unique_opset10():
|
||||
assert node.get_output_element_type(1) == Type.i64
|
||||
assert node.get_output_element_type(2) == Type.i64
|
||||
assert node.get_output_element_type(3) == Type.i64
|
||||
|
||||
|
||||
def test_topk_opset11():
|
||||
data_shape = [1, 3, 256]
|
||||
data = ov.parameter(data_shape, dtype=np.int32, name="Data")
|
||||
k_val = np.int32(3)
|
||||
axis = np.int32(-1)
|
||||
node = ov.topk(data, k_val, axis, "min", "value", stable=True)
|
||||
|
||||
assert node.get_type_name() == "TopK"
|
||||
assert node.get_output_size() == 2
|
||||
assert list(node.get_output_shape(0)) == [1, 3, 3]
|
||||
assert list(node.get_output_shape(1)) == [1, 3, 3]
|
||||
|
||||
@@ -2412,3 +2412,16 @@ def test_unique_opset10():
|
||||
assert node.get_output_element_type(1) == Type.i64
|
||||
assert node.get_output_element_type(2) == Type.i64
|
||||
assert node.get_output_element_type(3) == Type.i64
|
||||
|
||||
|
||||
def test_topk_opset11():
|
||||
data_shape = [1, 3, 256]
|
||||
data = ng.parameter(data_shape, dtype=np.int32, name="Data")
|
||||
k_val = np.int32(3)
|
||||
axis = np.int32(-1)
|
||||
node = ng_opset11.topk(data, k_val, axis, "min", "value", stable=True)
|
||||
|
||||
assert node.get_type_name() == "TopK"
|
||||
assert node.get_output_size() == 2
|
||||
assert list(node.get_output_shape(0)) == [1, 3, 3]
|
||||
assert list(node.get_output_shape(1)) == [1, 3, 3]
|
||||
|
||||
Reference in New Issue
Block a user