From 64c6ca05edd7a3935cd9c0ff0c307ed8b4e8004b Mon Sep 17 00:00:00 2001 From: Artur Kulikowski Date: Wed, 8 Dec 2021 09:34:31 +0100 Subject: [PATCH] [Python API] Binding ov::Any (#8996) --- .../src/compatibility/pyngraph/rt_map.cpp | 2 +- .../python/src/openvino/runtime/__init__.py | 11 ++- .../python/src/openvino/runtime/ie_api.py | 62 ++++++++++++- .../python/src/pyopenvino/core/common.cpp | 73 ++++++++++------ .../python/src/pyopenvino/core/common.hpp | 3 +- .../python/src/pyopenvino/core/core.cpp | 9 +- .../pyopenvino/core/executable_network.cpp | 8 +- .../python/src/pyopenvino/graph/rt_map.cpp | 43 +++++++++- .../python/src/pyopenvino/graph/variant.cpp | 65 ++++++++------ .../python/src/pyopenvino/graph/variant.hpp | 73 ++-------------- .../python/src/pyopenvino/pyopenvino.cpp | 2 - .../python/tests/test_ngraph/test_basic.py | 7 +- .../python/tests/test_ngraph/test_variant.py | 86 +++++++++++++++++++ 13 files changed, 301 insertions(+), 143 deletions(-) create mode 100644 src/bindings/python/tests/test_ngraph/test_variant.py diff --git a/src/bindings/python/src/compatibility/pyngraph/rt_map.cpp b/src/bindings/python/src/compatibility/pyngraph/rt_map.cpp index 5bcc81423c3..ca08c46c605 100644 --- a/src/bindings/python/src/compatibility/pyngraph/rt_map.cpp +++ b/src/bindings/python/src/compatibility/pyngraph/rt_map.cpp @@ -25,7 +25,7 @@ using PyRTMap = ngraph::RTMap; PYBIND11_MAKE_OPAQUE(PyRTMap); void regclass_pyngraph_PyRTMap(py::module m) { - auto py_map = py::bind_map(m, "PyRTMap"); + auto py_map = py::bind_map(m, "PyRTMap", py::module_local()); py_map.doc() = "ngraph.impl.PyRTMap makes bindings for std::map>, which can later be used as ngraph::Node::RTMap"; diff --git a/src/bindings/python/src/openvino/runtime/__init__.py b/src/bindings/python/src/openvino/runtime/__init__.py index 4975a289736..04dfd7cc534 100644 --- a/src/bindings/python/src/openvino/runtime/__init__.py +++ b/src/bindings/python/src/openvino/runtime/__init__.py @@ -22,8 +22,10 @@ if sys.platform == "win32": # add the location of openvino dlls to your system PATH. # # looking for the libs in the pip installation path by default. - openvino_libs = [os.path.join(os.path.dirname(__file__), "..", "..", ".."), - os.path.join(os.path.dirname(__file__), "..", "..", "openvino", "libs")] + openvino_libs = [ + os.path.join(os.path.dirname(__file__), "..", "..", ".."), + os.path.join(os.path.dirname(__file__), "..", "..", "openvino", "libs"), + ] # setupvars.bat script set all libs paths to OPENVINO_LIB_PATHS environment variable. openvino_libs_installer = os.getenv("OPENVINO_LIB_PATHS") if openvino_libs_installer: @@ -35,7 +37,9 @@ if sys.platform == "win32": if (3, 8) <= sys.version_info: os.add_dll_directory(os.path.abspath(lib_path)) else: - os.environ["PATH"] = os.path.abspath(lib_path) + ";" + os.environ["PATH"] + os.environ["PATH"] = ( + os.path.abspath(lib_path) + ";" + os.environ["PATH"] + ) # Openvino pybind bindings and python extended classes @@ -50,6 +54,7 @@ from openvino.runtime.ie_api import Core from openvino.runtime.ie_api import ExecutableNetwork from openvino.runtime.ie_api import InferRequest from openvino.runtime.ie_api import AsyncInferQueue +from openvino.runtime.ie_api import Variant from openvino.pyopenvino import Version from openvino.pyopenvino import Parameter from openvino.pyopenvino import Tensor diff --git a/src/bindings/python/src/openvino/runtime/ie_api.py b/src/bindings/python/src/openvino/runtime/ie_api.py index 22cf00c85e7..e5fae209c38 100644 --- a/src/bindings/python/src/openvino/runtime/ie_api.py +++ b/src/bindings/python/src/openvino/runtime/ie_api.py @@ -11,6 +11,7 @@ from openvino.pyopenvino import ExecutableNetwork as ExecutableNetworkBase from openvino.pyopenvino import InferRequest as InferRequestBase from openvino.pyopenvino import AsyncInferQueue as AsyncInferQueueBase from openvino.pyopenvino import Tensor +from openvino.pyopenvino import Variant as VariantBase from openvino.runtime.utils.types import get_dtype @@ -32,7 +33,11 @@ def normalize_inputs(py_dict: dict, py_types: dict) -> dict: raise TypeError("Incompatible key type for tensor named: {}".format(k)) except KeyError: raise KeyError("Port for tensor named {} was not found!".format(k)) - py_dict[k] = val if isinstance(val, Tensor) else Tensor(np.array(val, get_dtype(ov_type))) + py_dict[k] = ( + val + if isinstance(val, Tensor) + else Tensor(np.array(val, get_dtype(ov_type))) + ) return py_dict @@ -51,7 +56,9 @@ class InferRequest(InferRequestBase): def start_async(self, inputs: dict = None, userdata: Any = None) -> None: """Asynchronous infer wrapper for InferRequest.""" - inputs = {} if inputs is None else normalize_inputs(inputs, get_input_types(self)) + inputs = ( + {} if inputs is None else normalize_inputs(inputs, get_input_types(self)) + ) super().start_async(inputs, userdata) @@ -80,7 +87,9 @@ class AsyncInferQueue(AsyncInferQueueBase): inputs = ( {} if inputs is None - else normalize_inputs(inputs, get_input_types(self[self.get_idle_request_id()])) + else normalize_inputs( + inputs, get_input_types(self[self.get_idle_request_id()]) + ) ) super().start_async(inputs, userdata) @@ -101,7 +110,9 @@ class Core(CoreBase): ) -> ExecutableNetwork: """Compile a model from given model file path.""" return ExecutableNetwork( - super().import_model(model_file, device_name, {} if config is None else config) + super().import_model( + model_file, device_name, {} if config is None else config + ) ) @@ -117,3 +128,46 @@ def compile_model(model_path: str) -> ExecutableNetwork: """Compact method to compile model with AUTO plugin.""" core = Core() return ExtendedNetwork(core, core.compile_model(model_path, "AUTO")) + + +class Variant(VariantBase): + """Variant wrapper. + + Wrapper provides some useful overloads for simple built-in Python types. + + Access to the Variant value is direct if it is a built-in Python data type. + Example: + @code{.py} + variant = Variant([1, 2]) + print(variant[0]) + + Output: 2 + @endcode + + Otherwise if Variant value is a custom data type (for example user class), + access to the value is possible by 'get()' method or property 'value'. + Example: + @code{.py} + class Test: + def __init__(self): + self.data = "test" + + v = Variant(Test()) + print(v.value.data) + @endcode + """ + + def __getitem__(self, key: Union[str, int]) -> Any: + return self.value[key] + + def __get__(self) -> Any: + return self.value + + def __setitem__(self, key: Union[str, int], val: Any) -> None: + self.value[key] = val + + def __set__(self, val: Any) -> None: + self.value = val + + def __len__(self) -> int: + return len(self.value) diff --git a/src/bindings/python/src/pyopenvino/core/common.cpp b/src/bindings/python/src/pyopenvino/core/common.cpp index afcbb3d7b2a..6515054bea2 100644 --- a/src/bindings/python/src/pyopenvino/core/common.cpp +++ b/src/bindings/python/src/pyopenvino/core/common.cpp @@ -173,34 +173,44 @@ void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inp } } -PyObject* parse_parameter(const InferenceEngine::Parameter& param) { +PyAny from_ov_any(const ov::Any& any) { + // Check for py::object + if (any.is()) { + return any.as(); + } // Check for std::string - if (param.is()) { - return PyUnicode_FromString(param.as().c_str()); + else if (any.is()) { + return PyUnicode_FromString(any.as().c_str()); } // Check for int - else if (param.is()) { - auto val = param.as(); + else if (any.is()) { + auto val = any.as(); + return PyLong_FromLong((long)val); + } else if (any.is()) { + auto val = any.as(); return PyLong_FromLong((long)val); } // Check for unsinged int - else if (param.is()) { - auto val = param.as(); + else if (any.is()) { + auto val = any.as(); return PyLong_FromLong((unsigned long)val); } // Check for float - else if (param.is()) { - auto val = param.as(); + else if (any.is()) { + auto val = any.as(); return PyFloat_FromDouble((double)val); + } else if (any.is()) { + auto val = any.as(); + return PyFloat_FromDouble(val); } // Check for bool - else if (param.is()) { - auto val = param.as(); + else if (any.is()) { + auto val = any.as(); return val ? Py_True : Py_False; } // Check for std::vector - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); PyObject* list = PyList_New(0); for (const auto& it : val) { PyObject* str_val = PyUnicode_FromString(it.c_str()); @@ -209,8 +219,17 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) { return list; } // Check for std::vector - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); + PyObject* list = PyList_New(0); + for (const auto& it : val) { + PyList_Append(list, PyLong_FromLong(it)); + } + return list; + } + // Check for std::vector + else if (any.is>()) { + auto val = any.as>(); PyObject* list = PyList_New(0); for (const auto& it : val) { PyList_Append(list, PyLong_FromLong(it)); @@ -218,8 +237,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) { return list; } // Check for std::vector - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); PyObject* list = PyList_New(0); for (const auto& it : val) { PyList_Append(list, PyLong_FromLong(it)); @@ -227,8 +246,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) { return list; } // Check for std::vector - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); PyObject* list = PyList_New(0); for (const auto& it : val) { PyList_Append(list, PyFloat_FromDouble((double)it)); @@ -236,16 +255,16 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) { return list; } // Check for std::tuple - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); PyObject* tuple = PyTuple_New(2); PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLong((unsigned long)std::get<0>(val))); PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLong((unsigned long)std::get<1>(val))); return tuple; } // Check for std::tuple - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); PyObject* tuple = PyTuple_New(3); PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLong((unsigned long)std::get<0>(val))); PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLong((unsigned long)std::get<1>(val))); @@ -253,8 +272,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) { return tuple; } // Check for std::map - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); PyObject* dict = PyDict_New(); for (const auto& it : val) { PyDict_SetItemString(dict, it.first.c_str(), PyUnicode_FromString(it.second.c_str())); @@ -262,8 +281,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) { return dict; } // Check for std::map - else if (param.is>()) { - auto val = param.as>(); + else if (any.is>()) { + auto val = any.as>(); PyObject* dict = PyDict_New(); for (const auto& it : val) { PyDict_SetItemString(dict, it.first.c_str(), PyLong_FromLong((long)it.second)); diff --git a/src/bindings/python/src/pyopenvino/core/common.hpp b/src/bindings/python/src/pyopenvino/core/common.hpp index 7be421b3434..6d61348770f 100644 --- a/src/bindings/python/src/pyopenvino/core/common.hpp +++ b/src/bindings/python/src/pyopenvino/core/common.hpp @@ -19,6 +19,7 @@ #include "openvino/runtime/executable_network.hpp" #include "openvino/runtime/infer_request.hpp" #include "pyopenvino/core/containers.hpp" +#include "pyopenvino/graph/variant.hpp" namespace py = pybind11; @@ -40,7 +41,7 @@ namespace Common void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inputs); - PyObject* parse_parameter(const InferenceEngine::Parameter& param); + PyAny from_ov_any(const ov::Any& any); uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual); diff --git a/src/bindings/python/src/pyopenvino/core/core.cpp b/src/bindings/python/src/pyopenvino/core/core.cpp index 4d468fdf72c..f48f63895d9 100644 --- a/src/bindings/python/src/pyopenvino/core/core.cpp +++ b/src/bindings/python/src/pyopenvino/core/core.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -101,16 +102,16 @@ void regclass_Core(py::module m) { cls.def( "get_config", - [](ov::runtime::Core& self, const std::string& device_name, const std::string& name) -> py::handle { - return Common::parse_parameter(self.get_config(device_name, name)); + [](ov::runtime::Core& self, const std::string& device_name, const std::string& name) -> py::object { + return Common::from_ov_any(self.get_config(device_name, name)).as(); }, py::arg("device_name"), py::arg("name")); cls.def( "get_metric", - [](ov::runtime::Core& self, const std::string device_name, const std::string name) -> py::handle { - return Common::parse_parameter(self.get_metric(device_name, name)); + [](ov::runtime::Core& self, const std::string device_name, const std::string name) -> py::object { + return Common::from_ov_any(self.get_metric(device_name, name)).as(); }, py::arg("device_name"), py::arg("name")); diff --git a/src/bindings/python/src/pyopenvino/core/executable_network.cpp b/src/bindings/python/src/pyopenvino/core/executable_network.cpp index 116582465fc..d805d749c8b 100644 --- a/src/bindings/python/src/pyopenvino/core/executable_network.cpp +++ b/src/bindings/python/src/pyopenvino/core/executable_network.cpp @@ -44,15 +44,15 @@ void regclass_ExecutableNetwork(py::module m) { cls.def( "get_config", - [](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::handle { - return Common::parse_parameter(self.get_config(name)); + [](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::object { + return Common::from_ov_any(self.get_config(name)).as(); }, py::arg("name")); cls.def( "get_metric", - [](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::handle { - return Common::parse_parameter(self.get_metric(name)); + [](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::object { + return Common::from_ov_any(self.get_metric(name)).as(); }, py::arg("name")); diff --git a/src/bindings/python/src/pyopenvino/graph/rt_map.cpp b/src/bindings/python/src/pyopenvino/graph/rt_map.cpp index ec60f8d0f36..e1a421fdb74 100644 --- a/src/bindings/python/src/pyopenvino/graph/rt_map.cpp +++ b/src/bindings/python/src/pyopenvino/graph/rt_map.cpp @@ -15,6 +15,7 @@ #include "openvino/op/divide.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/subtract.hpp" +#include "pyopenvino/core/common.hpp" #include "pyopenvino/graph/node.hpp" #include "pyopenvino/graph/variant.hpp" @@ -25,7 +26,7 @@ using PyRTMap = ov::RTMap; PYBIND11_MAKE_OPAQUE(PyRTMap); void regclass_graph_PyRTMap(py::module m) { - auto py_map = py::bind_map(m, "PyRTMap"); + auto py_map = py::class_(m, "PyRTMap"); py_map.doc() = "openvino.impl.PyRTMap makes bindings for std::map py::object { + return Common::from_ov_any(m[k]).as(); + }); + py_map.def( + "__bool__", + [](const PyRTMap& m) -> bool { + return !m.empty(); + }, + "Check whether the map is nonempty"); + + py_map.def( + "__iter__", + [](PyRTMap& m) { + return py::make_key_iterator(m.begin(), m.end()); + }, + py::keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); + + py_map.def( + "items", + [](PyRTMap& m) { + return py::make_iterator(m.begin(), m.end()); + }, + py::keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); + + py_map.def("__contains__", [](PyRTMap& m, const std::string& k) -> bool { + auto it = m.find(k); + if (it == m.end()) + return false; + return true; + }); + py_map.def("__delitem__", [](PyRTMap& m, const std::string& k) { + auto it = m.find(k); + if (it == m.end()) + throw py::key_error(); + m.erase(it); + }); + + py_map.def("__len__", &PyRTMap::size); } diff --git a/src/bindings/python/src/pyopenvino/graph/variant.cpp b/src/bindings/python/src/pyopenvino/graph/variant.cpp index ebf388d1c3c..cd7e08ab47e 100644 --- a/src/bindings/python/src/pyopenvino/graph/variant.cpp +++ b/src/bindings/python/src/pyopenvino/graph/variant.cpp @@ -7,38 +7,53 @@ #include #include "openvino/core/any.hpp" +#include "pyopenvino/core/common.hpp" namespace py = pybind11; void regclass_graph_Variant(py::module m) { - py::class_ variant_base(m, "Variant", py::module_local()); - variant_base.doc() = "openvino.impl.Variant wraps ov::Any"; + py::class_> variant(m, "Variant", py::module_local()); + variant.doc() = "openvino.impl.Variant wraps ov::Any"; + variant.def(py::init()); - variant_base.def( - "__eq__", - [](const ov::Any& a, const ov::Any& b) { - return a == b; - }, - py::is_operator()); - variant_base.def( - "__eq__", - [](const ov::Any& a, const std::string& b) { - return a.as() == b; - }, - py::is_operator()); - variant_base.def( - "__eq__", - [](const ov::Any& a, const int64_t& b) { - return a.as() == b; - }, - py::is_operator()); - - variant_base.def("__repr__", [](const ov::Any self) { + variant.def("__repr__", [](const PyAny& self) { std::stringstream ret; self.print(ret); return ret.str(); }); -} + variant.def("__eq__", [](const PyAny& a, const PyAny& b) -> bool { + return a == b; + }); + variant.def("__eq__", [](const PyAny& a, const ov::Any& b) -> bool { + return a == b; + }); + variant.def("__eq__", [](const PyAny& a, py::object b) -> bool { + return a == PyAny(b); + }); + variant.def( + "get", + [](const PyAny& self) -> py::object { + return self.as(); + }, + R"( + Returns + ---------- + get : Any + Value of ov::Any. + )"); + variant.def( + "set", + [](PyAny& self, py::object value) { + self = PyAny(value); + }, + R"( + Parameters + ---------- + set : Any + Value to be set in ov::Any. -template void regclass_graph_VariantWrapper(py::module m, std::string typestring); -template void regclass_graph_VariantWrapper(py::module m, std::string typestring); + )"); + variant.def_property_readonly("value", [](const PyAny& self) { + return self.as(); + }); +} diff --git a/src/bindings/python/src/pyopenvino/graph/variant.hpp b/src/bindings/python/src/pyopenvino/graph/variant.hpp index fe84fd63d30..4789c85b7cf 100644 --- a/src/bindings/python/src/pyopenvino/graph/variant.hpp +++ b/src/bindings/python/src/pyopenvino/graph/variant.hpp @@ -5,82 +5,21 @@ #pragma once #include -#include #include #include +#include "Python.h" #include "openvino/core/any.hpp" // ov::RuntimeAttribute namespace py = pybind11; void regclass_graph_Variant(py::module m); -template -struct AnyT : public ov::Any { +class PyAny : public ov::Any { +public: using ov::Any::Any; + PyAny(py::object object) : ov::Any(object) {} + PyAny(PyObject* object) : ov::Any(py::reinterpret_borrow(object)) {} + PyAny(const ov::Any& any) : ov::Any(any) {} }; - -template -extern void regclass_graph_VariantWrapper(py::module m, std::string typestring) -{ - auto pyclass_name = py::detail::c_str((std::string("Variant") + typestring)); - py::class_, ov::Any> - variant_wrapper(m, pyclass_name, py::module_local()); - variant_wrapper.doc() = - "openvino.impl.Variant[" + typestring + "] wraps ov::Any with " + typestring; - - variant_wrapper.def(py::init()); - - variant_wrapper.def( - "__eq__", - [](const ov::Any& a, const ov::Any& b) { - return a.as() == b.as(); - }, - py::is_operator()); - variant_wrapper.def( - "__eq__", - [](const ov::Any& a, const std::string& b) { - return a.as() == b; - }, - py::is_operator()); - variant_wrapper.def( - "__eq__", - [](const ov::Any& a, const int64_t& b) { return a.as() == b; }, - py::is_operator()); - - variant_wrapper.def("__repr__", [](const ov::Any self) { - std::stringstream ret; - self.print(ret); - return ret.str(); - }); - - variant_wrapper.def("get", - [] (const ov::Any& self) { - return self.as(); - }, - R"( - Returns - ---------- - get : Variant - Value of ov::Any. - )"); - variant_wrapper.def("set", - [] (ov::Any& self, const VT value) { - self = value; - }, - R"( - Parameters - ---------- - set : str or int - Value to be set in ov::Any. - )"); - - variant_wrapper.def_property("value", - [] (const ov::Any& self) { - return self.as(); - }, - [] (ov::Any& self, const VT value) { - self = value; - }); -} diff --git a/src/bindings/python/src/pyopenvino/pyopenvino.cpp b/src/bindings/python/src/pyopenvino/pyopenvino.cpp index 8ef3d6530dc..4d6265d4ccd 100644 --- a/src/bindings/python/src/pyopenvino/pyopenvino.cpp +++ b/src/bindings/python/src/pyopenvino/pyopenvino.cpp @@ -98,8 +98,6 @@ PYBIND11_MODULE(pyopenvino, m) { regmodule_graph_util(m); regmodule_graph_layout_helpers(m); regclass_graph_Variant(m); - regclass_graph_VariantWrapper(m, std::string("String")); - regclass_graph_VariantWrapper(m, std::string("Int")); regclass_graph_Output(m, std::string("")); regclass_graph_Output(m, std::string("Const")); diff --git a/src/bindings/python/tests/test_ngraph/test_basic.py b/src/bindings/python/tests/test_ngraph/test_basic.py index 8fb938be3b6..78b42535984 100644 --- a/src/bindings/python/tests/test_ngraph/test_basic.py +++ b/src/bindings/python/tests/test_ngraph/test_basic.py @@ -9,7 +9,7 @@ import pytest import openvino.runtime.opset8 as ops import openvino.runtime as ov -from openvino.pyopenvino import VariantInt, VariantString +from openvino.pyopenvino import Variant from openvino.runtime.exceptions import UserInputError from openvino.runtime.impl import Function, PartialShape, Shape, Type, layout_helpers @@ -490,8 +490,8 @@ def test_node_target_inputs_soruce_output(): def test_variants(): - variant_int = VariantInt(32) - variant_str = VariantString("test_text") + variant_int = Variant(32) + variant_str = Variant("test_text") assert variant_int.get() == 32 assert variant_str.get() == "test_text" @@ -512,7 +512,6 @@ def test_runtime_info(): runtime_info["affinity"] = "test_affinity" relu_node.set_friendly_name("testReLU") runtime_info_after = relu_node.get_rt_info() - assert runtime_info_after["affinity"] == "test_affinity" diff --git a/src/bindings/python/tests/test_ngraph/test_variant.py b/src/bindings/python/tests/test_ngraph/test_variant.py new file mode 100644 index 00000000000..1a8ee3fed2a --- /dev/null +++ b/src/bindings/python/tests/test_ngraph/test_variant.py @@ -0,0 +1,86 @@ +from openvino.runtime import Variant + + +def test_variant_str(): + var = Variant("test_string") + assert isinstance(var.value, str) + assert var == "test_string" + + +def test_variant_int(): + var = Variant(2137) + assert isinstance(var.value, int) + assert var == 2137 + + +def test_variant_float(): + var = Variant(21.37) + assert isinstance(var.value, float) + + +def test_variant_string_list(): + var = Variant(["test", "string"]) + assert isinstance(var.value, list) + assert isinstance(var[0], str) + assert var[0] == "test" + + +def test_variant_int_list(): + v = Variant([21, 37]) + assert isinstance(v.value, list) + assert len(v) == 2 + assert isinstance(v[0], int) + + +def test_variant_float_list(): + v = Variant([21.0, 37.0]) + assert isinstance(v.value, list) + assert len(v) == 2 + assert isinstance(v[0], float) + + +def test_variant_tuple(): + v = Variant((2, 1)) + assert isinstance(v.value, tuple) + + +def test_variant_bool(): + v = Variant(False) + assert isinstance(v.value, bool) + assert v is not True + + +def test_variant_dict_str(): + v = Variant({"key": "value"}) + assert isinstance(v.value, dict) + assert v["key"] == "value" + + +def test_variant_dict_str_int(): + v = Variant({"key": 2}) + assert isinstance(v.value, dict) + assert v["key"] == 2 + + +def test_variant_int_dict(): + v = Variant({1: 2}) + assert isinstance(v.value, dict) + assert v[1] == 2 + + +def test_variant_set_new_value(): + v = Variant(int(1)) + assert isinstance(v.value, int) + v = Variant("test") + assert isinstance(v.value, str) + assert v == "test" + + +def test_variant_class(): + class TestClass: + def __init__(self): + self.text = "test" + + v = Variant(TestClass()) + assert isinstance(v.value, TestClass) + assert v.value.text == "test"