[PyOV] Add missing element Type and functions (#12614)
This commit is contained in:
parent
e24a5b8ac3
commit
65ac9b5036
@ -30,6 +30,8 @@ void regclass_graph_Type(py::module m) {
|
||||
:rtype: ov.Type
|
||||
)");
|
||||
|
||||
type.attr("undefined") = ov::element::undefined;
|
||||
type.attr("dynamic") = ov::element::dynamic;
|
||||
type.attr("boolean") = ov::element::boolean;
|
||||
type.attr("f16") = ov::element::f16;
|
||||
type.attr("f32") = ov::element::f32;
|
||||
@ -46,37 +48,78 @@ void regclass_graph_Type(py::module m) {
|
||||
type.attr("u32") = ov::element::u32;
|
||||
type.attr("u64") = ov::element::u64;
|
||||
type.attr("bf16") = ov::element::bf16;
|
||||
type.attr("undefined") = ov::element::undefined;
|
||||
|
||||
type.def("__hash__", &ov::element::Type::hash);
|
||||
type.def("__repr__", [](const ov::element::Type& self) {
|
||||
std::string bitwidth = std::to_string(self.bitwidth());
|
||||
if (self == ov::element::undefined) {
|
||||
return "<Type: '" + self.c_type_string() + "'>";
|
||||
} else if (self.is_signed()) {
|
||||
if (self == ov::element::f32 || self == ov::element::f64) {
|
||||
std::string bitwidth = std::to_string(self.bitwidth());
|
||||
return "<Type: '" + self.c_type_string() + bitwidth + "'>";
|
||||
}
|
||||
return "<Type: 'u" + self.c_type_string() + bitwidth + "'>";
|
||||
|
||||
return "<Type: '" + self.c_type_string() + "'>";
|
||||
});
|
||||
type.def("__hash__", &ov::element::Type::hash);
|
||||
type.def(
|
||||
"__eq__",
|
||||
[](const ov::element::Type& a, const ov::element::Type& b) {
|
||||
return a == b;
|
||||
},
|
||||
py::is_operator());
|
||||
|
||||
type.def("is_static", &ov::element::Type::is_static);
|
||||
type.def("is_dynamic", &ov::element::Type::is_dynamic);
|
||||
type.def("is_real", &ov::element::Type::is_real);
|
||||
type.def("is_integral", &ov::element::Type::is_integral);
|
||||
type.def("is_integral_number", &ov::element::Type::is_integral_number);
|
||||
type.def("is_signed", &ov::element::Type::is_signed);
|
||||
type.def("is_quantized", &ov::element::Type::is_quantized);
|
||||
type.def("get_type_name", &ov::element::Type::get_type_name);
|
||||
type.def("compatible",
|
||||
&ov::element::Type::compatible,
|
||||
py::arg("other"),
|
||||
R"(
|
||||
Checks whether this element type is merge-compatible with
|
||||
`other`.
|
||||
|
||||
:param other: The element type to compare this element type to.
|
||||
:type other: openvino.runtime.Type
|
||||
:return: `True` if element types are compatible, otherwise `False`.
|
||||
:rtype: bool
|
||||
)");
|
||||
type.def(
|
||||
"merge",
|
||||
[](ov::element::Type& self, ov::element::Type& other) {
|
||||
ov::element::Type dst;
|
||||
|
||||
if (ov::element::Type::merge(dst, self, other)) {
|
||||
return py::cast(dst);
|
||||
}
|
||||
|
||||
return py::none().cast<py::object>();
|
||||
},
|
||||
py::arg("other"),
|
||||
R"(
|
||||
Merge two element types and return result if successful,
|
||||
otherwise return None.
|
||||
|
||||
:param other: The element type to merge with this element type.
|
||||
:type other: openvino.runtime.Type
|
||||
:return: If element types are compatible return the least
|
||||
restrictive Type, otherwise `None`.
|
||||
:rtype: Union[openvino.runtime.Type|None]
|
||||
)");
|
||||
|
||||
type.def(
|
||||
"to_dtype",
|
||||
[](ov::element::Type& self) {
|
||||
return Common::ov_type_to_dtype().at(self);
|
||||
},
|
||||
R"(
|
||||
Convert Type to numpy dtype
|
||||
Convert Type to numpy dtype.
|
||||
|
||||
:return: dtype object
|
||||
:rtype: numpy.dtype
|
||||
)");
|
||||
|
||||
type.def_property_readonly("size", &ov::element::Type::size);
|
||||
type.def_property_readonly("bitwidth", &ov::element::Type::bitwidth);
|
||||
type.def_property_readonly("is_real", &ov::element::Type::is_real);
|
||||
}
|
||||
|
125
src/bindings/python/tests/test_runtime/test_type.py
Normal file
125
src/bindings/python/tests/test_runtime/test_type.py
Normal file
@ -0,0 +1,125 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from openvino.runtime import Type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("dtype_string", "dtype", "ovtype"), [
|
||||
("float16", np.float16, Type.f16),
|
||||
("float32", np.float32, Type.f32),
|
||||
("float64", np.float64, Type.f64),
|
||||
("int8", np.int8, Type.i8),
|
||||
("int16", np.int16, Type.i16),
|
||||
("int32", np.int32, Type.i32),
|
||||
("int64", np.int64, Type.i64),
|
||||
("uint8", np.uint8, Type.u8),
|
||||
("uint16", np.uint16, Type.u16),
|
||||
("uint32", np.uint32, Type.u32),
|
||||
("uint64", np.uint64, Type.u64),
|
||||
("bool", np.bool_, Type.boolean),
|
||||
])
|
||||
def test_dtype_ovtype_conversion(dtype_string, dtype, ovtype):
|
||||
assert ovtype.to_dtype() == dtype
|
||||
assert Type(dtype_string) == ovtype
|
||||
assert Type(dtype) == ovtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("ovtype",
|
||||
"static_flag",
|
||||
"dynamic_flag",
|
||||
"real_flag",
|
||||
"integral_flag",
|
||||
"signed_flag",
|
||||
"quantized_flag",
|
||||
"type_name",
|
||||
"type_size",
|
||||
"type_bitwidth"), [
|
||||
(Type.f16, True, False, True, False, True, False, "f16", 2, 16),
|
||||
(Type.f32, True, False, True, False, True, False, "f32", 4, 32),
|
||||
(Type.f64, True, False, True, False, True, False, "f64", 8, 64),
|
||||
(Type.i8, True, False, False, True, True, True, "i8", 1, 8),
|
||||
(Type.i16, True, False, False, True, True, False, "i16", 2, 16),
|
||||
(Type.i32, True, False, False, True, True, True, "i32", 4, 32),
|
||||
(Type.i64, True, False, False, True, True, False, "i64", 8, 64),
|
||||
(Type.u8, True, False, False, True, False, True, "u8", 1, 8),
|
||||
(Type.u16, True, False, False, True, False, False, "u16", 2, 16),
|
||||
(Type.u32, True, False, False, True, False, False, "u32", 4, 32),
|
||||
(Type.u64, True, False, False, True, False, False, "u64", 8, 64),
|
||||
(Type.boolean, True, False, False, True, True, False, "boolean", 1, 8),
|
||||
])
|
||||
def test_basic_ovtypes(ovtype,
|
||||
static_flag,
|
||||
dynamic_flag,
|
||||
real_flag,
|
||||
integral_flag,
|
||||
signed_flag,
|
||||
quantized_flag,
|
||||
type_name,
|
||||
type_size,
|
||||
type_bitwidth):
|
||||
assert ovtype.is_static() is static_flag
|
||||
assert ovtype.is_dynamic() is dynamic_flag
|
||||
assert ovtype.is_real() is real_flag
|
||||
assert ovtype.is_integral() is integral_flag
|
||||
assert ovtype.is_signed() is signed_flag
|
||||
assert ovtype.is_quantized() is quantized_flag
|
||||
assert ovtype.get_type_name() == type_name
|
||||
assert ovtype.size == type_size
|
||||
assert ovtype.bitwidth == type_bitwidth
|
||||
|
||||
|
||||
def test_undefined_ovtype():
|
||||
ov_type = Type.undefined
|
||||
assert ov_type.is_static() is True
|
||||
assert ov_type.is_dynamic() is False
|
||||
assert ov_type.is_real() is False
|
||||
assert ov_type.is_integral() is True
|
||||
assert ov_type.is_signed() is False
|
||||
assert ov_type.is_quantized() is False
|
||||
assert ov_type.get_type_name() == "undefined"
|
||||
assert ov_type.size == 0
|
||||
|
||||
# Note: might depend on the system
|
||||
import sys
|
||||
assert ov_type.bitwidth == sys.maxsize * 2 + 1
|
||||
|
||||
|
||||
def test_dynamic_ov_type():
|
||||
ov_type = Type.dynamic
|
||||
assert ov_type.is_static() is False
|
||||
assert ov_type.is_dynamic() is True
|
||||
assert ov_type.is_real() is False
|
||||
assert ov_type.is_integral() is True
|
||||
assert ov_type.is_signed() is False
|
||||
assert ov_type.is_quantized() is False
|
||||
assert ov_type.get_type_name() == "dynamic"
|
||||
assert ov_type.size == 0
|
||||
assert ov_type.bitwidth == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("ovtype_one", "ovtype_two", "expected"), [
|
||||
(Type.dynamic, Type.dynamic, True),
|
||||
(Type.f32, Type.dynamic, True),
|
||||
(Type.dynamic, Type.f32, True),
|
||||
(Type.f32, Type.f32, True),
|
||||
(Type.f32, Type.f16, False),
|
||||
(Type.i16, Type.f32, False),
|
||||
])
|
||||
def test_ovtypes_compatibility(ovtype_one, ovtype_two, expected):
|
||||
assert ovtype_one.compatible(ovtype_two) is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("ovtype_one", "ovtype_two", "expected"), [
|
||||
(Type.dynamic, Type.dynamic, Type.dynamic),
|
||||
(Type.f32, Type.dynamic, Type.f32),
|
||||
(Type.dynamic, Type.f32, Type.f32),
|
||||
(Type.f32, Type.f32, Type.f32),
|
||||
(Type.f32, Type.f16, None),
|
||||
(Type.i16, Type.f32, None),
|
||||
])
|
||||
def test_ovtypes_merge(ovtype_one, ovtype_two, expected):
|
||||
assert ovtype_one.merge(ovtype_two) == expected
|
@ -122,23 +122,3 @@ def test_serialize_pass():
|
||||
|
||||
os.remove(xml_path)
|
||||
os.remove(bin_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("dtype_string", "dtype", "ovtype"), [
|
||||
("float16", np.float16, ov.Type.f16),
|
||||
("float32", np.float32, ov.Type.f32),
|
||||
("float64", np.float64, ov.Type.f64),
|
||||
("int8", np.int8, ov.Type.i8),
|
||||
("int16", np.int16, ov.Type.i16),
|
||||
("int32", np.int32, ov.Type.i32),
|
||||
("int64", np.int64, ov.Type.i64),
|
||||
("uint8", np.uint8, ov.Type.u8),
|
||||
("uint16", np.uint16, ov.Type.u16),
|
||||
("uint32", np.uint32, ov.Type.u32),
|
||||
("uint64", np.uint64, ov.Type.u64),
|
||||
("bool", np.bool_, ov.Type.boolean),
|
||||
])
|
||||
def test_dtype_ovtype_conversion(dtype_string, dtype, ovtype):
|
||||
assert ovtype.to_dtype() == dtype
|
||||
assert ov.Type(dtype_string) == ovtype
|
||||
assert ov.Type(dtype) == ovtype
|
||||
|
Loading…
Reference in New Issue
Block a user