[PyOV] Add missing element Type and functions (#12614)

This commit is contained in:
Jan Iwaszkiewicz 2022-08-18 13:19:44 +02:00 committed by GitHub
parent e24a5b8ac3
commit 65ac9b5036
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 177 additions and 29 deletions

View File

@ -30,6 +30,8 @@ void regclass_graph_Type(py::module m) {
:rtype: ov.Type
)");
type.attr("undefined") = ov::element::undefined;
type.attr("dynamic") = ov::element::dynamic;
type.attr("boolean") = ov::element::boolean;
type.attr("f16") = ov::element::f16;
type.attr("f32") = ov::element::f32;
@ -46,37 +48,78 @@ void regclass_graph_Type(py::module m) {
type.attr("u32") = ov::element::u32;
type.attr("u64") = ov::element::u64;
type.attr("bf16") = ov::element::bf16;
type.attr("undefined") = ov::element::undefined;
type.def("__hash__", &ov::element::Type::hash);
type.def("__repr__", [](const ov::element::Type& self) {
std::string bitwidth = std::to_string(self.bitwidth());
if (self == ov::element::undefined) {
return "<Type: '" + self.c_type_string() + "'>";
} else if (self.is_signed()) {
if (self == ov::element::f32 || self == ov::element::f64) {
std::string bitwidth = std::to_string(self.bitwidth());
return "<Type: '" + self.c_type_string() + bitwidth + "'>";
}
return "<Type: 'u" + self.c_type_string() + bitwidth + "'>";
return "<Type: '" + self.c_type_string() + "'>";
});
type.def("__hash__", &ov::element::Type::hash);
type.def(
"__eq__",
[](const ov::element::Type& a, const ov::element::Type& b) {
return a == b;
},
py::is_operator());
type.def("is_static", &ov::element::Type::is_static);
type.def("is_dynamic", &ov::element::Type::is_dynamic);
type.def("is_real", &ov::element::Type::is_real);
type.def("is_integral", &ov::element::Type::is_integral);
type.def("is_integral_number", &ov::element::Type::is_integral_number);
type.def("is_signed", &ov::element::Type::is_signed);
type.def("is_quantized", &ov::element::Type::is_quantized);
type.def("get_type_name", &ov::element::Type::get_type_name);
type.def("compatible",
&ov::element::Type::compatible,
py::arg("other"),
R"(
Checks whether this element type is merge-compatible with
`other`.
:param other: The element type to compare this element type to.
:type other: openvino.runtime.Type
:return: `True` if element types are compatible, otherwise `False`.
:rtype: bool
)");
type.def(
"merge",
[](ov::element::Type& self, ov::element::Type& other) {
ov::element::Type dst;
if (ov::element::Type::merge(dst, self, other)) {
return py::cast(dst);
}
return py::none().cast<py::object>();
},
py::arg("other"),
R"(
Merge two element types and return result if successful,
otherwise return None.
:param other: The element type to merge with this element type.
:type other: openvino.runtime.Type
:return: If element types are compatible return the least
restrictive Type, otherwise `None`.
:rtype: Union[openvino.runtime.Type|None]
)");
type.def(
"to_dtype",
[](ov::element::Type& self) {
return Common::ov_type_to_dtype().at(self);
},
R"(
Convert Type to numpy dtype
Convert Type to numpy dtype.
:return: dtype object
:rtype: numpy.dtype
)");
type.def_property_readonly("size", &ov::element::Type::size);
type.def_property_readonly("bitwidth", &ov::element::Type::bitwidth);
type.def_property_readonly("is_real", &ov::element::Type::is_real);
}

View File

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
import numpy as np
from openvino.runtime import Type
@pytest.mark.parametrize(("dtype_string", "dtype", "ovtype"), [
("float16", np.float16, Type.f16),
("float32", np.float32, Type.f32),
("float64", np.float64, Type.f64),
("int8", np.int8, Type.i8),
("int16", np.int16, Type.i16),
("int32", np.int32, Type.i32),
("int64", np.int64, Type.i64),
("uint8", np.uint8, Type.u8),
("uint16", np.uint16, Type.u16),
("uint32", np.uint32, Type.u32),
("uint64", np.uint64, Type.u64),
("bool", np.bool_, Type.boolean),
])
def test_dtype_ovtype_conversion(dtype_string, dtype, ovtype):
assert ovtype.to_dtype() == dtype
assert Type(dtype_string) == ovtype
assert Type(dtype) == ovtype
@pytest.mark.parametrize(("ovtype",
"static_flag",
"dynamic_flag",
"real_flag",
"integral_flag",
"signed_flag",
"quantized_flag",
"type_name",
"type_size",
"type_bitwidth"), [
(Type.f16, True, False, True, False, True, False, "f16", 2, 16),
(Type.f32, True, False, True, False, True, False, "f32", 4, 32),
(Type.f64, True, False, True, False, True, False, "f64", 8, 64),
(Type.i8, True, False, False, True, True, True, "i8", 1, 8),
(Type.i16, True, False, False, True, True, False, "i16", 2, 16),
(Type.i32, True, False, False, True, True, True, "i32", 4, 32),
(Type.i64, True, False, False, True, True, False, "i64", 8, 64),
(Type.u8, True, False, False, True, False, True, "u8", 1, 8),
(Type.u16, True, False, False, True, False, False, "u16", 2, 16),
(Type.u32, True, False, False, True, False, False, "u32", 4, 32),
(Type.u64, True, False, False, True, False, False, "u64", 8, 64),
(Type.boolean, True, False, False, True, True, False, "boolean", 1, 8),
])
def test_basic_ovtypes(ovtype,
static_flag,
dynamic_flag,
real_flag,
integral_flag,
signed_flag,
quantized_flag,
type_name,
type_size,
type_bitwidth):
assert ovtype.is_static() is static_flag
assert ovtype.is_dynamic() is dynamic_flag
assert ovtype.is_real() is real_flag
assert ovtype.is_integral() is integral_flag
assert ovtype.is_signed() is signed_flag
assert ovtype.is_quantized() is quantized_flag
assert ovtype.get_type_name() == type_name
assert ovtype.size == type_size
assert ovtype.bitwidth == type_bitwidth
def test_undefined_ovtype():
ov_type = Type.undefined
assert ov_type.is_static() is True
assert ov_type.is_dynamic() is False
assert ov_type.is_real() is False
assert ov_type.is_integral() is True
assert ov_type.is_signed() is False
assert ov_type.is_quantized() is False
assert ov_type.get_type_name() == "undefined"
assert ov_type.size == 0
# Note: might depend on the system
import sys
assert ov_type.bitwidth == sys.maxsize * 2 + 1
def test_dynamic_ov_type():
ov_type = Type.dynamic
assert ov_type.is_static() is False
assert ov_type.is_dynamic() is True
assert ov_type.is_real() is False
assert ov_type.is_integral() is True
assert ov_type.is_signed() is False
assert ov_type.is_quantized() is False
assert ov_type.get_type_name() == "dynamic"
assert ov_type.size == 0
assert ov_type.bitwidth == 0
@pytest.mark.parametrize(("ovtype_one", "ovtype_two", "expected"), [
(Type.dynamic, Type.dynamic, True),
(Type.f32, Type.dynamic, True),
(Type.dynamic, Type.f32, True),
(Type.f32, Type.f32, True),
(Type.f32, Type.f16, False),
(Type.i16, Type.f32, False),
])
def test_ovtypes_compatibility(ovtype_one, ovtype_two, expected):
assert ovtype_one.compatible(ovtype_two) is expected
@pytest.mark.parametrize(("ovtype_one", "ovtype_two", "expected"), [
(Type.dynamic, Type.dynamic, Type.dynamic),
(Type.f32, Type.dynamic, Type.f32),
(Type.dynamic, Type.f32, Type.f32),
(Type.f32, Type.f32, Type.f32),
(Type.f32, Type.f16, None),
(Type.i16, Type.f32, None),
])
def test_ovtypes_merge(ovtype_one, ovtype_two, expected):
assert ovtype_one.merge(ovtype_two) == expected

View File

@ -122,23 +122,3 @@ def test_serialize_pass():
os.remove(xml_path)
os.remove(bin_path)
@pytest.mark.parametrize(("dtype_string", "dtype", "ovtype"), [
("float16", np.float16, ov.Type.f16),
("float32", np.float32, ov.Type.f32),
("float64", np.float64, ov.Type.f64),
("int8", np.int8, ov.Type.i8),
("int16", np.int16, ov.Type.i16),
("int32", np.int32, ov.Type.i32),
("int64", np.int64, ov.Type.i64),
("uint8", np.uint8, ov.Type.u8),
("uint16", np.uint16, ov.Type.u16),
("uint32", np.uint32, ov.Type.u32),
("uint64", np.uint64, ov.Type.u64),
("bool", np.bool_, ov.Type.boolean),
])
def test_dtype_ovtype_conversion(dtype_string, dtype, ovtype):
assert ovtype.to_dtype() == dtype
assert ov.Type(dtype_string) == ovtype
assert ov.Type(dtype) == ovtype