From 65ac9b5036e744359b7b07e7651904dc209ddae5 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Thu, 18 Aug 2022 13:19:44 +0200 Subject: [PATCH] [PyOV] Add missing element Type and functions (#12614) --- .../pyopenvino/graph/types/element_type.cpp | 61 +++++++-- .../python/tests/test_runtime/test_type.py | 125 ++++++++++++++++++ .../test_public_transformations.py | 20 --- 3 files changed, 177 insertions(+), 29 deletions(-) create mode 100644 src/bindings/python/tests/test_runtime/test_type.py diff --git a/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp b/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp index 8b62b42e59d..a37e8d65809 100644 --- a/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp +++ b/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp @@ -30,6 +30,8 @@ void regclass_graph_Type(py::module m) { :rtype: ov.Type )"); + type.attr("undefined") = ov::element::undefined; + type.attr("dynamic") = ov::element::dynamic; type.attr("boolean") = ov::element::boolean; type.attr("f16") = ov::element::f16; type.attr("f32") = ov::element::f32; @@ -46,37 +48,78 @@ void regclass_graph_Type(py::module m) { type.attr("u32") = ov::element::u32; type.attr("u64") = ov::element::u64; type.attr("bf16") = ov::element::bf16; - type.attr("undefined") = ov::element::undefined; + type.def("__hash__", &ov::element::Type::hash); type.def("__repr__", [](const ov::element::Type& self) { - std::string bitwidth = std::to_string(self.bitwidth()); - if (self == ov::element::undefined) { - return ""; - } else if (self.is_signed()) { + if (self == ov::element::f32 || self == ov::element::f64) { + std::string bitwidth = std::to_string(self.bitwidth()); return ""; } - return ""; + + return ""; }); - type.def("__hash__", &ov::element::Type::hash); type.def( "__eq__", [](const ov::element::Type& a, const ov::element::Type& b) { return a == b; }, py::is_operator()); + + type.def("is_static", &ov::element::Type::is_static); + type.def("is_dynamic", &ov::element::Type::is_dynamic); + type.def("is_real", &ov::element::Type::is_real); + type.def("is_integral", &ov::element::Type::is_integral); + type.def("is_integral_number", &ov::element::Type::is_integral_number); + type.def("is_signed", &ov::element::Type::is_signed); + type.def("is_quantized", &ov::element::Type::is_quantized); type.def("get_type_name", &ov::element::Type::get_type_name); + type.def("compatible", + &ov::element::Type::compatible, + py::arg("other"), + R"( + Checks whether this element type is merge-compatible with + `other`. + + :param other: The element type to compare this element type to. + :type other: openvino.runtime.Type + :return: `True` if element types are compatible, otherwise `False`. + :rtype: bool + )"); + type.def( + "merge", + [](ov::element::Type& self, ov::element::Type& other) { + ov::element::Type dst; + + if (ov::element::Type::merge(dst, self, other)) { + return py::cast(dst); + } + + return py::none().cast(); + }, + py::arg("other"), + R"( + Merge two element types and return result if successful, + otherwise return None. + + :param other: The element type to merge with this element type. + :type other: openvino.runtime.Type + :return: If element types are compatible return the least + restrictive Type, otherwise `None`. + :rtype: Union[openvino.runtime.Type|None] + )"); + type.def( "to_dtype", [](ov::element::Type& self) { return Common::ov_type_to_dtype().at(self); }, R"( - Convert Type to numpy dtype + Convert Type to numpy dtype. :return: dtype object :rtype: numpy.dtype )"); + type.def_property_readonly("size", &ov::element::Type::size); type.def_property_readonly("bitwidth", &ov::element::Type::bitwidth); - type.def_property_readonly("is_real", &ov::element::Type::is_real); } diff --git a/src/bindings/python/tests/test_runtime/test_type.py b/src/bindings/python/tests/test_runtime/test_type.py new file mode 100644 index 00000000000..5e0a881c440 --- /dev/null +++ b/src/bindings/python/tests/test_runtime/test_type.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import numpy as np + +from openvino.runtime import Type + + +@pytest.mark.parametrize(("dtype_string", "dtype", "ovtype"), [ + ("float16", np.float16, Type.f16), + ("float32", np.float32, Type.f32), + ("float64", np.float64, Type.f64), + ("int8", np.int8, Type.i8), + ("int16", np.int16, Type.i16), + ("int32", np.int32, Type.i32), + ("int64", np.int64, Type.i64), + ("uint8", np.uint8, Type.u8), + ("uint16", np.uint16, Type.u16), + ("uint32", np.uint32, Type.u32), + ("uint64", np.uint64, Type.u64), + ("bool", np.bool_, Type.boolean), +]) +def test_dtype_ovtype_conversion(dtype_string, dtype, ovtype): + assert ovtype.to_dtype() == dtype + assert Type(dtype_string) == ovtype + assert Type(dtype) == ovtype + + +@pytest.mark.parametrize(("ovtype", + "static_flag", + "dynamic_flag", + "real_flag", + "integral_flag", + "signed_flag", + "quantized_flag", + "type_name", + "type_size", + "type_bitwidth"), [ + (Type.f16, True, False, True, False, True, False, "f16", 2, 16), + (Type.f32, True, False, True, False, True, False, "f32", 4, 32), + (Type.f64, True, False, True, False, True, False, "f64", 8, 64), + (Type.i8, True, False, False, True, True, True, "i8", 1, 8), + (Type.i16, True, False, False, True, True, False, "i16", 2, 16), + (Type.i32, True, False, False, True, True, True, "i32", 4, 32), + (Type.i64, True, False, False, True, True, False, "i64", 8, 64), + (Type.u8, True, False, False, True, False, True, "u8", 1, 8), + (Type.u16, True, False, False, True, False, False, "u16", 2, 16), + (Type.u32, True, False, False, True, False, False, "u32", 4, 32), + (Type.u64, True, False, False, True, False, False, "u64", 8, 64), + (Type.boolean, True, False, False, True, True, False, "boolean", 1, 8), +]) +def test_basic_ovtypes(ovtype, + static_flag, + dynamic_flag, + real_flag, + integral_flag, + signed_flag, + quantized_flag, + type_name, + type_size, + type_bitwidth): + assert ovtype.is_static() is static_flag + assert ovtype.is_dynamic() is dynamic_flag + assert ovtype.is_real() is real_flag + assert ovtype.is_integral() is integral_flag + assert ovtype.is_signed() is signed_flag + assert ovtype.is_quantized() is quantized_flag + assert ovtype.get_type_name() == type_name + assert ovtype.size == type_size + assert ovtype.bitwidth == type_bitwidth + + +def test_undefined_ovtype(): + ov_type = Type.undefined + assert ov_type.is_static() is True + assert ov_type.is_dynamic() is False + assert ov_type.is_real() is False + assert ov_type.is_integral() is True + assert ov_type.is_signed() is False + assert ov_type.is_quantized() is False + assert ov_type.get_type_name() == "undefined" + assert ov_type.size == 0 + + # Note: might depend on the system + import sys + assert ov_type.bitwidth == sys.maxsize * 2 + 1 + + +def test_dynamic_ov_type(): + ov_type = Type.dynamic + assert ov_type.is_static() is False + assert ov_type.is_dynamic() is True + assert ov_type.is_real() is False + assert ov_type.is_integral() is True + assert ov_type.is_signed() is False + assert ov_type.is_quantized() is False + assert ov_type.get_type_name() == "dynamic" + assert ov_type.size == 0 + assert ov_type.bitwidth == 0 + + +@pytest.mark.parametrize(("ovtype_one", "ovtype_two", "expected"), [ + (Type.dynamic, Type.dynamic, True), + (Type.f32, Type.dynamic, True), + (Type.dynamic, Type.f32, True), + (Type.f32, Type.f32, True), + (Type.f32, Type.f16, False), + (Type.i16, Type.f32, False), +]) +def test_ovtypes_compatibility(ovtype_one, ovtype_two, expected): + assert ovtype_one.compatible(ovtype_two) is expected + + +@pytest.mark.parametrize(("ovtype_one", "ovtype_two", "expected"), [ + (Type.dynamic, Type.dynamic, Type.dynamic), + (Type.f32, Type.dynamic, Type.f32), + (Type.dynamic, Type.f32, Type.f32), + (Type.f32, Type.f32, Type.f32), + (Type.f32, Type.f16, None), + (Type.i16, Type.f32, None), +]) +def test_ovtypes_merge(ovtype_one, ovtype_two, expected): + assert ovtype_one.merge(ovtype_two) == expected diff --git a/src/bindings/python/tests/test_transformations/test_public_transformations.py b/src/bindings/python/tests/test_transformations/test_public_transformations.py index 64262b65759..fb03a01e5a5 100644 --- a/src/bindings/python/tests/test_transformations/test_public_transformations.py +++ b/src/bindings/python/tests/test_transformations/test_public_transformations.py @@ -122,23 +122,3 @@ def test_serialize_pass(): os.remove(xml_path) os.remove(bin_path) - - -@pytest.mark.parametrize(("dtype_string", "dtype", "ovtype"), [ - ("float16", np.float16, ov.Type.f16), - ("float32", np.float32, ov.Type.f32), - ("float64", np.float64, ov.Type.f64), - ("int8", np.int8, ov.Type.i8), - ("int16", np.int16, ov.Type.i16), - ("int32", np.int32, ov.Type.i32), - ("int64", np.int64, ov.Type.i64), - ("uint8", np.uint8, ov.Type.u8), - ("uint16", np.uint16, ov.Type.u16), - ("uint32", np.uint32, ov.Type.u32), - ("uint64", np.uint64, ov.Type.u64), - ("bool", np.bool_, ov.Type.boolean), -]) -def test_dtype_ovtype_conversion(dtype_string, dtype, ovtype): - assert ovtype.to_dtype() == dtype - assert ov.Type(dtype_string) == ovtype - assert ov.Type(dtype) == ovtype