[Unique-10] ConvertPrecision transformation and new attribute for output type (#14229)

* Add attribute for last output element type

* Add convert precision transformation and tests

* Update Unique python API with new attribute

* Update Unique-10 op specification

* Update docstrings

* Update visitor tests

* Add type prop tests for the new attribute

* Check axis constant before shape, compare with partial shape
This commit is contained in:
Katarzyna Mitrus
2022-11-28 11:56:49 +01:00
committed by GitHub
parent f06c44115f
commit 13a76a8b72
11 changed files with 184 additions and 28 deletions

View File

@@ -141,6 +141,7 @@ def unique(
axis: Optional[NodeInput] = None,
sorted: Optional[bool] = True,
index_element_type: Optional[str] = "i64",
count_element_type: Optional[str] = "i64",
name: Optional[str] = None,
) -> Node:
"""Operator which selects and returns unique elements or unique slices of the input tensor.
@@ -154,6 +155,8 @@ def unique(
Default value: True.
:param index_element_type: (Optional) The data type set for outputs containing indices.
Default value: "i64".
:param count_element_type: (Optional) The data type set for the output with repetition count.
Default value: "i64".
:param name: (Optional) A name for the output node. Default value: None.
:return: Node representing Unique operation.
"""
@@ -165,5 +168,6 @@ def unique(
attributes = {
"sorted": sorted,
"index_element_type": index_element_type,
"count_element_type": count_element_type,
}
return _get_node_factory_opset10().create("Unique", inputs, attributes)

View File

@@ -141,6 +141,7 @@ def unique(
axis: Optional[NodeInput] = None,
sorted: Optional[bool] = True,
index_element_type: Optional[str] = "i64",
count_element_type: Optional[str] = "i64",
name: Optional[str] = None,
) -> Node:
"""Operator which selects and returns unique elements or unique slices of the input tensor.
@@ -154,6 +155,8 @@ def unique(
Default value: True.
:param index_element_type: (Optional) The data type set for outputs containing indices.
Default value: "i64".
:param count_element_type: (Optional) The data type set for the output with repetition count.
Default value: "i64".
:param name: (Optional) A name for the output node. Default value: None.
:return: Node representing Unique operation.
"""
@@ -165,5 +168,6 @@ def unique(
attributes = {
"sorted": sorted,
"index_element_type": index_element_type,
"count_element_type": count_element_type,
}
return _get_node_factory_opset10().create("Unique", inputs, attributes)

View File

@@ -2294,7 +2294,7 @@ def test_unique_opset10():
assert node.get_output_element_type(3) == Type.i64
# Axis default, means flattened result
node = ov_opset10.unique(input_node, None, False, "i32")
node = ov_opset10.unique(input_node, None, False, "i32", "i32")
assert node.get_type_name() == "Unique"
assert node.get_sorted() is False
@@ -2308,7 +2308,7 @@ def test_unique_opset10():
assert node.get_output_element_type(0) == Type.f32
assert node.get_output_element_type(1) == Type.i32
assert node.get_output_element_type(2) == Type.i32
assert node.get_output_element_type(3) == Type.i64
assert node.get_output_element_type(3) == Type.i32
# All arguments default
node = ov_opset10.unique(input_node)

View File

@@ -2359,7 +2359,7 @@ def test_unique_opset10():
assert node.get_output_element_type(3) == Type.i64
# Axis default, means flattened result
node = ng_opset10.unique(input_node, None, False, "i32")
node = ng_opset10.unique(input_node, None, False, "i32", "i32")
assert node.get_type_name() == "Unique"
assert node.get_sorted() is False
@@ -2373,7 +2373,7 @@ def test_unique_opset10():
assert node.get_output_element_type(0) == Type.f32
assert node.get_output_element_type(1) == Type.i32
assert node.get_output_element_type(2) == Type.i32
assert node.get_output_element_type(3) == Type.i64
assert node.get_output_element_type(3) == Type.i32
# All arguments default
node = ng_opset10.unique(input_node)