[Unique-10] ConvertPrecision transformation and new attribute for output type (#14229)
* Add attribute for last output element type * Add convert precision transformation and tests * Update Unique python API with new attribute * Update Unique-10 op specification * Update docstrings * Update visitor tests * Add type prop tests for the new attribute * Check axis constant before shape, compare with partial shape
This commit is contained in:
@@ -141,6 +141,7 @@ def unique(
|
||||
axis: Optional[NodeInput] = None,
|
||||
sorted: Optional[bool] = True,
|
||||
index_element_type: Optional[str] = "i64",
|
||||
count_element_type: Optional[str] = "i64",
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Operator which selects and returns unique elements or unique slices of the input tensor.
|
||||
@@ -154,6 +155,8 @@ def unique(
|
||||
Default value: True.
|
||||
:param index_element_type: (Optional) The data type set for outputs containing indices.
|
||||
Default value: "i64".
|
||||
:param count_element_type: (Optional) The data type set for the output with repetition count.
|
||||
Default value: "i64".
|
||||
:param name: (Optional) A name for the output node. Default value: None.
|
||||
:return: Node representing Unique operation.
|
||||
"""
|
||||
@@ -165,5 +168,6 @@ def unique(
|
||||
attributes = {
|
||||
"sorted": sorted,
|
||||
"index_element_type": index_element_type,
|
||||
"count_element_type": count_element_type,
|
||||
}
|
||||
return _get_node_factory_opset10().create("Unique", inputs, attributes)
|
||||
|
||||
@@ -141,6 +141,7 @@ def unique(
|
||||
axis: Optional[NodeInput] = None,
|
||||
sorted: Optional[bool] = True,
|
||||
index_element_type: Optional[str] = "i64",
|
||||
count_element_type: Optional[str] = "i64",
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Operator which selects and returns unique elements or unique slices of the input tensor.
|
||||
@@ -154,6 +155,8 @@ def unique(
|
||||
Default value: True.
|
||||
:param index_element_type: (Optional) The data type set for outputs containing indices.
|
||||
Default value: "i64".
|
||||
:param count_element_type: (Optional) The data type set for the output with repetition count.
|
||||
Default value: "i64".
|
||||
:param name: (Optional) A name for the output node. Default value: None.
|
||||
:return: Node representing Unique operation.
|
||||
"""
|
||||
@@ -165,5 +168,6 @@ def unique(
|
||||
attributes = {
|
||||
"sorted": sorted,
|
||||
"index_element_type": index_element_type,
|
||||
"count_element_type": count_element_type,
|
||||
}
|
||||
return _get_node_factory_opset10().create("Unique", inputs, attributes)
|
||||
|
||||
@@ -2294,7 +2294,7 @@ def test_unique_opset10():
|
||||
assert node.get_output_element_type(3) == Type.i64
|
||||
|
||||
# Axis default, means flattened result
|
||||
node = ov_opset10.unique(input_node, None, False, "i32")
|
||||
node = ov_opset10.unique(input_node, None, False, "i32", "i32")
|
||||
|
||||
assert node.get_type_name() == "Unique"
|
||||
assert node.get_sorted() is False
|
||||
@@ -2308,7 +2308,7 @@ def test_unique_opset10():
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
assert node.get_output_element_type(1) == Type.i32
|
||||
assert node.get_output_element_type(2) == Type.i32
|
||||
assert node.get_output_element_type(3) == Type.i64
|
||||
assert node.get_output_element_type(3) == Type.i32
|
||||
|
||||
# All arguments default
|
||||
node = ov_opset10.unique(input_node)
|
||||
|
||||
@@ -2359,7 +2359,7 @@ def test_unique_opset10():
|
||||
assert node.get_output_element_type(3) == Type.i64
|
||||
|
||||
# Axis default, means flattened result
|
||||
node = ng_opset10.unique(input_node, None, False, "i32")
|
||||
node = ng_opset10.unique(input_node, None, False, "i32", "i32")
|
||||
|
||||
assert node.get_type_name() == "Unique"
|
||||
assert node.get_sorted() is False
|
||||
@@ -2373,7 +2373,7 @@ def test_unique_opset10():
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
assert node.get_output_element_type(1) == Type.i32
|
||||
assert node.get_output_element_type(2) == Type.i32
|
||||
assert node.get_output_element_type(3) == Type.i64
|
||||
assert node.get_output_element_type(3) == Type.i32
|
||||
|
||||
# All arguments default
|
||||
node = ng_opset10.unique(input_node)
|
||||
|
||||
Reference in New Issue
Block a user