[Python API] Binding ov::Any (#8996)
This commit is contained in:
parent
6bbb2c86e5
commit
64c6ca05ed
@ -25,7 +25,7 @@ using PyRTMap = ngraph::RTMap;
|
||||
PYBIND11_MAKE_OPAQUE(PyRTMap);
|
||||
|
||||
void regclass_pyngraph_PyRTMap(py::module m) {
|
||||
auto py_map = py::bind_map<PyRTMap>(m, "PyRTMap");
|
||||
auto py_map = py::bind_map<PyRTMap>(m, "PyRTMap", py::module_local());
|
||||
py_map.doc() = "ngraph.impl.PyRTMap makes bindings for std::map<std::string, "
|
||||
"std::shared_ptr<ngraph::Variant>>, which can later be used as ngraph::Node::RTMap";
|
||||
|
||||
|
@ -22,8 +22,10 @@ if sys.platform == "win32":
|
||||
# add the location of openvino dlls to your system PATH.
|
||||
#
|
||||
# looking for the libs in the pip installation path by default.
|
||||
openvino_libs = [os.path.join(os.path.dirname(__file__), "..", "..", ".."),
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "openvino", "libs")]
|
||||
openvino_libs = [
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", ".."),
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "openvino", "libs"),
|
||||
]
|
||||
# setupvars.bat script set all libs paths to OPENVINO_LIB_PATHS environment variable.
|
||||
openvino_libs_installer = os.getenv("OPENVINO_LIB_PATHS")
|
||||
if openvino_libs_installer:
|
||||
@ -35,7 +37,9 @@ if sys.platform == "win32":
|
||||
if (3, 8) <= sys.version_info:
|
||||
os.add_dll_directory(os.path.abspath(lib_path))
|
||||
else:
|
||||
os.environ["PATH"] = os.path.abspath(lib_path) + ";" + os.environ["PATH"]
|
||||
os.environ["PATH"] = (
|
||||
os.path.abspath(lib_path) + ";" + os.environ["PATH"]
|
||||
)
|
||||
|
||||
|
||||
# Openvino pybind bindings and python extended classes
|
||||
@ -50,6 +54,7 @@ from openvino.runtime.ie_api import Core
|
||||
from openvino.runtime.ie_api import ExecutableNetwork
|
||||
from openvino.runtime.ie_api import InferRequest
|
||||
from openvino.runtime.ie_api import AsyncInferQueue
|
||||
from openvino.runtime.ie_api import Variant
|
||||
from openvino.pyopenvino import Version
|
||||
from openvino.pyopenvino import Parameter
|
||||
from openvino.pyopenvino import Tensor
|
||||
|
@ -11,6 +11,7 @@ from openvino.pyopenvino import ExecutableNetwork as ExecutableNetworkBase
|
||||
from openvino.pyopenvino import InferRequest as InferRequestBase
|
||||
from openvino.pyopenvino import AsyncInferQueue as AsyncInferQueueBase
|
||||
from openvino.pyopenvino import Tensor
|
||||
from openvino.pyopenvino import Variant as VariantBase
|
||||
|
||||
from openvino.runtime.utils.types import get_dtype
|
||||
|
||||
@ -32,7 +33,11 @@ def normalize_inputs(py_dict: dict, py_types: dict) -> dict:
|
||||
raise TypeError("Incompatible key type for tensor named: {}".format(k))
|
||||
except KeyError:
|
||||
raise KeyError("Port for tensor named {} was not found!".format(k))
|
||||
py_dict[k] = val if isinstance(val, Tensor) else Tensor(np.array(val, get_dtype(ov_type)))
|
||||
py_dict[k] = (
|
||||
val
|
||||
if isinstance(val, Tensor)
|
||||
else Tensor(np.array(val, get_dtype(ov_type)))
|
||||
)
|
||||
return py_dict
|
||||
|
||||
|
||||
@ -51,7 +56,9 @@ class InferRequest(InferRequestBase):
|
||||
|
||||
def start_async(self, inputs: dict = None, userdata: Any = None) -> None:
|
||||
"""Asynchronous infer wrapper for InferRequest."""
|
||||
inputs = {} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
inputs = (
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
)
|
||||
super().start_async(inputs, userdata)
|
||||
|
||||
|
||||
@ -80,7 +87,9 @@ class AsyncInferQueue(AsyncInferQueueBase):
|
||||
inputs = (
|
||||
{}
|
||||
if inputs is None
|
||||
else normalize_inputs(inputs, get_input_types(self[self.get_idle_request_id()]))
|
||||
else normalize_inputs(
|
||||
inputs, get_input_types(self[self.get_idle_request_id()])
|
||||
)
|
||||
)
|
||||
super().start_async(inputs, userdata)
|
||||
|
||||
@ -101,7 +110,9 @@ class Core(CoreBase):
|
||||
) -> ExecutableNetwork:
|
||||
"""Compile a model from given model file path."""
|
||||
return ExecutableNetwork(
|
||||
super().import_model(model_file, device_name, {} if config is None else config)
|
||||
super().import_model(
|
||||
model_file, device_name, {} if config is None else config
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -117,3 +128,46 @@ def compile_model(model_path: str) -> ExecutableNetwork:
|
||||
"""Compact method to compile model with AUTO plugin."""
|
||||
core = Core()
|
||||
return ExtendedNetwork(core, core.compile_model(model_path, "AUTO"))
|
||||
|
||||
|
||||
class Variant(VariantBase):
|
||||
"""Variant wrapper.
|
||||
|
||||
Wrapper provides some useful overloads for simple built-in Python types.
|
||||
|
||||
Access to the Variant value is direct if it is a built-in Python data type.
|
||||
Example:
|
||||
@code{.py}
|
||||
variant = Variant([1, 2])
|
||||
print(variant[0])
|
||||
|
||||
Output: 2
|
||||
@endcode
|
||||
|
||||
Otherwise if Variant value is a custom data type (for example user class),
|
||||
access to the value is possible by 'get()' method or property 'value'.
|
||||
Example:
|
||||
@code{.py}
|
||||
class Test:
|
||||
def __init__(self):
|
||||
self.data = "test"
|
||||
|
||||
v = Variant(Test())
|
||||
print(v.value.data)
|
||||
@endcode
|
||||
"""
|
||||
|
||||
def __getitem__(self, key: Union[str, int]) -> Any:
|
||||
return self.value[key]
|
||||
|
||||
def __get__(self) -> Any:
|
||||
return self.value
|
||||
|
||||
def __setitem__(self, key: Union[str, int], val: Any) -> None:
|
||||
self.value[key] = val
|
||||
|
||||
def __set__(self, val: Any) -> None:
|
||||
self.value = val
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.value)
|
||||
|
@ -173,34 +173,44 @@ void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inp
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||
PyAny from_ov_any(const ov::Any& any) {
|
||||
// Check for py::object
|
||||
if (any.is<py::object>()) {
|
||||
return any.as<py::object>();
|
||||
}
|
||||
// Check for std::string
|
||||
if (param.is<std::string>()) {
|
||||
return PyUnicode_FromString(param.as<std::string>().c_str());
|
||||
else if (any.is<std::string>()) {
|
||||
return PyUnicode_FromString(any.as<std::string>().c_str());
|
||||
}
|
||||
// Check for int
|
||||
else if (param.is<int>()) {
|
||||
auto val = param.as<int>();
|
||||
else if (any.is<int>()) {
|
||||
auto val = any.as<int>();
|
||||
return PyLong_FromLong((long)val);
|
||||
} else if (any.is<int64_t>()) {
|
||||
auto val = any.as<int64_t>();
|
||||
return PyLong_FromLong((long)val);
|
||||
}
|
||||
// Check for unsinged int
|
||||
else if (param.is<unsigned int>()) {
|
||||
auto val = param.as<unsigned int>();
|
||||
else if (any.is<unsigned int>()) {
|
||||
auto val = any.as<unsigned int>();
|
||||
return PyLong_FromLong((unsigned long)val);
|
||||
}
|
||||
// Check for float
|
||||
else if (param.is<float>()) {
|
||||
auto val = param.as<float>();
|
||||
else if (any.is<float>()) {
|
||||
auto val = any.as<float>();
|
||||
return PyFloat_FromDouble((double)val);
|
||||
} else if (any.is<double>()) {
|
||||
auto val = any.as<double>();
|
||||
return PyFloat_FromDouble(val);
|
||||
}
|
||||
// Check for bool
|
||||
else if (param.is<bool>()) {
|
||||
auto val = param.as<bool>();
|
||||
else if (any.is<bool>()) {
|
||||
auto val = any.as<bool>();
|
||||
return val ? Py_True : Py_False;
|
||||
}
|
||||
// Check for std::vector<std::string>
|
||||
else if (param.is<std::vector<std::string>>()) {
|
||||
auto val = param.as<std::vector<std::string>>();
|
||||
else if (any.is<std::vector<std::string>>()) {
|
||||
auto val = any.as<std::vector<std::string>>();
|
||||
PyObject* list = PyList_New(0);
|
||||
for (const auto& it : val) {
|
||||
PyObject* str_val = PyUnicode_FromString(it.c_str());
|
||||
@ -209,8 +219,17 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||
return list;
|
||||
}
|
||||
// Check for std::vector<int>
|
||||
else if (param.is<std::vector<int>>()) {
|
||||
auto val = param.as<std::vector<int>>();
|
||||
else if (any.is<std::vector<int>>()) {
|
||||
auto val = any.as<std::vector<int>>();
|
||||
PyObject* list = PyList_New(0);
|
||||
for (const auto& it : val) {
|
||||
PyList_Append(list, PyLong_FromLong(it));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
// Check for std::vector<int64_t>
|
||||
else if (any.is<std::vector<int64_t>>()) {
|
||||
auto val = any.as<std::vector<int64_t>>();
|
||||
PyObject* list = PyList_New(0);
|
||||
for (const auto& it : val) {
|
||||
PyList_Append(list, PyLong_FromLong(it));
|
||||
@ -218,8 +237,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||
return list;
|
||||
}
|
||||
// Check for std::vector<unsigned int>
|
||||
else if (param.is<std::vector<unsigned int>>()) {
|
||||
auto val = param.as<std::vector<unsigned int>>();
|
||||
else if (any.is<std::vector<unsigned int>>()) {
|
||||
auto val = any.as<std::vector<unsigned int>>();
|
||||
PyObject* list = PyList_New(0);
|
||||
for (const auto& it : val) {
|
||||
PyList_Append(list, PyLong_FromLong(it));
|
||||
@ -227,8 +246,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||
return list;
|
||||
}
|
||||
// Check for std::vector<float>
|
||||
else if (param.is<std::vector<float>>()) {
|
||||
auto val = param.as<std::vector<float>>();
|
||||
else if (any.is<std::vector<float>>()) {
|
||||
auto val = any.as<std::vector<float>>();
|
||||
PyObject* list = PyList_New(0);
|
||||
for (const auto& it : val) {
|
||||
PyList_Append(list, PyFloat_FromDouble((double)it));
|
||||
@ -236,16 +255,16 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||
return list;
|
||||
}
|
||||
// Check for std::tuple<unsigned int, unsigned int>
|
||||
else if (param.is<std::tuple<unsigned int, unsigned int>>()) {
|
||||
auto val = param.as<std::tuple<unsigned int, unsigned int>>();
|
||||
else if (any.is<std::tuple<unsigned int, unsigned int>>()) {
|
||||
auto val = any.as<std::tuple<unsigned int, unsigned int>>();
|
||||
PyObject* tuple = PyTuple_New(2);
|
||||
PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLong((unsigned long)std::get<0>(val)));
|
||||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLong((unsigned long)std::get<1>(val)));
|
||||
return tuple;
|
||||
}
|
||||
// Check for std::tuple<unsigned int, unsigned int, unsigned int>
|
||||
else if (param.is<std::tuple<unsigned int, unsigned int, unsigned int>>()) {
|
||||
auto val = param.as<std::tuple<unsigned int, unsigned int, unsigned int>>();
|
||||
else if (any.is<std::tuple<unsigned int, unsigned int, unsigned int>>()) {
|
||||
auto val = any.as<std::tuple<unsigned int, unsigned int, unsigned int>>();
|
||||
PyObject* tuple = PyTuple_New(3);
|
||||
PyTuple_SetItem(tuple, 0, PyLong_FromUnsignedLong((unsigned long)std::get<0>(val)));
|
||||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLong((unsigned long)std::get<1>(val)));
|
||||
@ -253,8 +272,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||
return tuple;
|
||||
}
|
||||
// Check for std::map<std::string, std::string>
|
||||
else if (param.is<std::map<std::string, std::string>>()) {
|
||||
auto val = param.as<std::map<std::string, std::string>>();
|
||||
else if (any.is<std::map<std::string, std::string>>()) {
|
||||
auto val = any.as<std::map<std::string, std::string>>();
|
||||
PyObject* dict = PyDict_New();
|
||||
for (const auto& it : val) {
|
||||
PyDict_SetItemString(dict, it.first.c_str(), PyUnicode_FromString(it.second.c_str()));
|
||||
@ -262,8 +281,8 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||
return dict;
|
||||
}
|
||||
// Check for std::map<std::string, int>
|
||||
else if (param.is<std::map<std::string, int>>()) {
|
||||
auto val = param.as<std::map<std::string, int>>();
|
||||
else if (any.is<std::map<std::string, int>>()) {
|
||||
auto val = any.as<std::map<std::string, int>>();
|
||||
PyObject* dict = PyDict_New();
|
||||
for (const auto& it : val) {
|
||||
PyDict_SetItemString(dict, it.first.c_str(), PyLong_FromLong((long)it.second));
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "openvino/runtime/executable_network.hpp"
|
||||
#include "openvino/runtime/infer_request.hpp"
|
||||
#include "pyopenvino/core/containers.hpp"
|
||||
#include "pyopenvino/graph/variant.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
@ -40,7 +41,7 @@ namespace Common
|
||||
|
||||
void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inputs);
|
||||
|
||||
PyObject* parse_parameter(const InferenceEngine::Parameter& param);
|
||||
PyAny from_ov_any(const ov::Any& any);
|
||||
|
||||
uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual);
|
||||
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <ie_extension.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <openvino/core/any.hpp>
|
||||
#include <openvino/runtime/core.hpp>
|
||||
#include <pyopenvino/core/tensor.hpp>
|
||||
|
||||
@ -101,16 +102,16 @@ void regclass_Core(py::module m) {
|
||||
|
||||
cls.def(
|
||||
"get_config",
|
||||
[](ov::runtime::Core& self, const std::string& device_name, const std::string& name) -> py::handle {
|
||||
return Common::parse_parameter(self.get_config(device_name, name));
|
||||
[](ov::runtime::Core& self, const std::string& device_name, const std::string& name) -> py::object {
|
||||
return Common::from_ov_any(self.get_config(device_name, name)).as<py::object>();
|
||||
},
|
||||
py::arg("device_name"),
|
||||
py::arg("name"));
|
||||
|
||||
cls.def(
|
||||
"get_metric",
|
||||
[](ov::runtime::Core& self, const std::string device_name, const std::string name) -> py::handle {
|
||||
return Common::parse_parameter(self.get_metric(device_name, name));
|
||||
[](ov::runtime::Core& self, const std::string device_name, const std::string name) -> py::object {
|
||||
return Common::from_ov_any(self.get_metric(device_name, name)).as<py::object>();
|
||||
},
|
||||
py::arg("device_name"),
|
||||
py::arg("name"));
|
||||
|
@ -44,15 +44,15 @@ void regclass_ExecutableNetwork(py::module m) {
|
||||
|
||||
cls.def(
|
||||
"get_config",
|
||||
[](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::handle {
|
||||
return Common::parse_parameter(self.get_config(name));
|
||||
[](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::object {
|
||||
return Common::from_ov_any(self.get_config(name)).as<py::object>();
|
||||
},
|
||||
py::arg("name"));
|
||||
|
||||
cls.def(
|
||||
"get_metric",
|
||||
[](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::handle {
|
||||
return Common::parse_parameter(self.get_metric(name));
|
||||
[](ov::runtime::ExecutableNetwork& self, const std::string& name) -> py::object {
|
||||
return Common::from_ov_any(self.get_metric(name)).as<py::object>();
|
||||
},
|
||||
py::arg("name"));
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "pyopenvino/core/common.hpp"
|
||||
#include "pyopenvino/graph/node.hpp"
|
||||
#include "pyopenvino/graph/variant.hpp"
|
||||
|
||||
@ -25,7 +26,7 @@ using PyRTMap = ov::RTMap;
|
||||
PYBIND11_MAKE_OPAQUE(PyRTMap);
|
||||
|
||||
void regclass_graph_PyRTMap(py::module m) {
|
||||
auto py_map = py::bind_map<PyRTMap>(m, "PyRTMap");
|
||||
auto py_map = py::class_<PyRTMap>(m, "PyRTMap");
|
||||
py_map.doc() = "openvino.impl.PyRTMap makes bindings for std::map<std::string, "
|
||||
"ov::Any, which can later be used as ov::Node::RTMap";
|
||||
|
||||
@ -35,4 +36,44 @@ void regclass_graph_PyRTMap(py::module m) {
|
||||
py_map.def("__setitem__", [](PyRTMap& m, const std::string& k, const int64_t v) {
|
||||
m[k] = v;
|
||||
});
|
||||
py_map.def("__getitem__", [](PyRTMap& m, const std::string& k) -> py::object {
|
||||
return Common::from_ov_any(m[k]).as<py::object>();
|
||||
});
|
||||
py_map.def(
|
||||
"__bool__",
|
||||
[](const PyRTMap& m) -> bool {
|
||||
return !m.empty();
|
||||
},
|
||||
"Check whether the map is nonempty");
|
||||
|
||||
py_map.def(
|
||||
"__iter__",
|
||||
[](PyRTMap& m) {
|
||||
return py::make_key_iterator(m.begin(), m.end());
|
||||
},
|
||||
py::keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
|
||||
);
|
||||
|
||||
py_map.def(
|
||||
"items",
|
||||
[](PyRTMap& m) {
|
||||
return py::make_iterator(m.begin(), m.end());
|
||||
},
|
||||
py::keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
|
||||
);
|
||||
|
||||
py_map.def("__contains__", [](PyRTMap& m, const std::string& k) -> bool {
|
||||
auto it = m.find(k);
|
||||
if (it == m.end())
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
py_map.def("__delitem__", [](PyRTMap& m, const std::string& k) {
|
||||
auto it = m.find(k);
|
||||
if (it == m.end())
|
||||
throw py::key_error();
|
||||
m.erase(it);
|
||||
});
|
||||
|
||||
py_map.def("__len__", &PyRTMap::size);
|
||||
}
|
||||
|
@ -7,38 +7,53 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "openvino/core/any.hpp"
|
||||
#include "pyopenvino/core/common.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_graph_Variant(py::module m) {
|
||||
py::class_<ov::Any> variant_base(m, "Variant", py::module_local());
|
||||
variant_base.doc() = "openvino.impl.Variant wraps ov::Any";
|
||||
py::class_<PyAny, std::shared_ptr<PyAny>> variant(m, "Variant", py::module_local());
|
||||
variant.doc() = "openvino.impl.Variant wraps ov::Any";
|
||||
variant.def(py::init<py::object>());
|
||||
|
||||
variant_base.def(
|
||||
"__eq__",
|
||||
[](const ov::Any& a, const ov::Any& b) {
|
||||
return a == b;
|
||||
},
|
||||
py::is_operator());
|
||||
variant_base.def(
|
||||
"__eq__",
|
||||
[](const ov::Any& a, const std::string& b) {
|
||||
return a.as<std::string>() == b;
|
||||
},
|
||||
py::is_operator());
|
||||
variant_base.def(
|
||||
"__eq__",
|
||||
[](const ov::Any& a, const int64_t& b) {
|
||||
return a.as<int64_t>() == b;
|
||||
},
|
||||
py::is_operator());
|
||||
|
||||
variant_base.def("__repr__", [](const ov::Any self) {
|
||||
variant.def("__repr__", [](const PyAny& self) {
|
||||
std::stringstream ret;
|
||||
self.print(ret);
|
||||
return ret.str();
|
||||
});
|
||||
}
|
||||
variant.def("__eq__", [](const PyAny& a, const PyAny& b) -> bool {
|
||||
return a == b;
|
||||
});
|
||||
variant.def("__eq__", [](const PyAny& a, const ov::Any& b) -> bool {
|
||||
return a == b;
|
||||
});
|
||||
variant.def("__eq__", [](const PyAny& a, py::object b) -> bool {
|
||||
return a == PyAny(b);
|
||||
});
|
||||
variant.def(
|
||||
"get",
|
||||
[](const PyAny& self) -> py::object {
|
||||
return self.as<py::object>();
|
||||
},
|
||||
R"(
|
||||
Returns
|
||||
----------
|
||||
get : Any
|
||||
Value of ov::Any.
|
||||
)");
|
||||
variant.def(
|
||||
"set",
|
||||
[](PyAny& self, py::object value) {
|
||||
self = PyAny(value);
|
||||
},
|
||||
R"(
|
||||
Parameters
|
||||
----------
|
||||
set : Any
|
||||
Value to be set in ov::Any.
|
||||
|
||||
template void regclass_graph_VariantWrapper<std::string>(py::module m, std::string typestring);
|
||||
template void regclass_graph_VariantWrapper<int64_t>(py::module m, std::string typestring);
|
||||
)");
|
||||
variant.def_property_readonly("value", [](const PyAny& self) {
|
||||
return self.as<py::object>();
|
||||
});
|
||||
}
|
||||
|
@ -5,82 +5,21 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "Python.h"
|
||||
#include "openvino/core/any.hpp" // ov::RuntimeAttribute
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_graph_Variant(py::module m);
|
||||
|
||||
template<typename T>
|
||||
struct AnyT : public ov::Any {
|
||||
class PyAny : public ov::Any {
|
||||
public:
|
||||
using ov::Any::Any;
|
||||
PyAny(py::object object) : ov::Any(object) {}
|
||||
PyAny(PyObject* object) : ov::Any(py::reinterpret_borrow<py::object>(object)) {}
|
||||
PyAny(const ov::Any& any) : ov::Any(any) {}
|
||||
};
|
||||
|
||||
template <typename VT>
|
||||
extern void regclass_graph_VariantWrapper(py::module m, std::string typestring)
|
||||
{
|
||||
auto pyclass_name = py::detail::c_str((std::string("Variant") + typestring));
|
||||
py::class_<AnyT<VT>, ov::Any>
|
||||
variant_wrapper(m, pyclass_name, py::module_local());
|
||||
variant_wrapper.doc() =
|
||||
"openvino.impl.Variant[" + typestring + "] wraps ov::Any with " + typestring;
|
||||
|
||||
variant_wrapper.def(py::init<const VT&>());
|
||||
|
||||
variant_wrapper.def(
|
||||
"__eq__",
|
||||
[](const ov::Any& a, const ov::Any& b) {
|
||||
return a.as<VT>() == b.as<VT>();
|
||||
},
|
||||
py::is_operator());
|
||||
variant_wrapper.def(
|
||||
"__eq__",
|
||||
[](const ov::Any& a, const std::string& b) {
|
||||
return a.as<std::string>() == b;
|
||||
},
|
||||
py::is_operator());
|
||||
variant_wrapper.def(
|
||||
"__eq__",
|
||||
[](const ov::Any& a, const int64_t& b) { return a.as<int64_t>() == b; },
|
||||
py::is_operator());
|
||||
|
||||
variant_wrapper.def("__repr__", [](const ov::Any self) {
|
||||
std::stringstream ret;
|
||||
self.print(ret);
|
||||
return ret.str();
|
||||
});
|
||||
|
||||
variant_wrapper.def("get",
|
||||
[] (const ov::Any& self) {
|
||||
return self.as<VT>();
|
||||
},
|
||||
R"(
|
||||
Returns
|
||||
----------
|
||||
get : Variant
|
||||
Value of ov::Any.
|
||||
)");
|
||||
variant_wrapper.def("set",
|
||||
[] (ov::Any& self, const VT value) {
|
||||
self = value;
|
||||
},
|
||||
R"(
|
||||
Parameters
|
||||
----------
|
||||
set : str or int
|
||||
Value to be set in ov::Any.
|
||||
)");
|
||||
|
||||
variant_wrapper.def_property("value",
|
||||
[] (const ov::Any& self) {
|
||||
return self.as<VT>();
|
||||
},
|
||||
[] (ov::Any& self, const VT value) {
|
||||
self = value;
|
||||
});
|
||||
}
|
||||
|
@ -98,8 +98,6 @@ PYBIND11_MODULE(pyopenvino, m) {
|
||||
regmodule_graph_util(m);
|
||||
regmodule_graph_layout_helpers(m);
|
||||
regclass_graph_Variant(m);
|
||||
regclass_graph_VariantWrapper<std::string>(m, std::string("String"));
|
||||
regclass_graph_VariantWrapper<int64_t>(m, std::string("Int"));
|
||||
regclass_graph_Output<ov::Node>(m, std::string(""));
|
||||
regclass_graph_Output<const ov::Node>(m, std::string("Const"));
|
||||
|
||||
|
@ -9,7 +9,7 @@ import pytest
|
||||
import openvino.runtime.opset8 as ops
|
||||
import openvino.runtime as ov
|
||||
|
||||
from openvino.pyopenvino import VariantInt, VariantString
|
||||
from openvino.pyopenvino import Variant
|
||||
|
||||
from openvino.runtime.exceptions import UserInputError
|
||||
from openvino.runtime.impl import Function, PartialShape, Shape, Type, layout_helpers
|
||||
@ -490,8 +490,8 @@ def test_node_target_inputs_soruce_output():
|
||||
|
||||
|
||||
def test_variants():
|
||||
variant_int = VariantInt(32)
|
||||
variant_str = VariantString("test_text")
|
||||
variant_int = Variant(32)
|
||||
variant_str = Variant("test_text")
|
||||
|
||||
assert variant_int.get() == 32
|
||||
assert variant_str.get() == "test_text"
|
||||
@ -512,7 +512,6 @@ def test_runtime_info():
|
||||
runtime_info["affinity"] = "test_affinity"
|
||||
relu_node.set_friendly_name("testReLU")
|
||||
runtime_info_after = relu_node.get_rt_info()
|
||||
|
||||
assert runtime_info_after["affinity"] == "test_affinity"
|
||||
|
||||
|
||||
|
86
src/bindings/python/tests/test_ngraph/test_variant.py
Normal file
86
src/bindings/python/tests/test_ngraph/test_variant.py
Normal file
@ -0,0 +1,86 @@
|
||||
from openvino.runtime import Variant
|
||||
|
||||
|
||||
def test_variant_str():
|
||||
var = Variant("test_string")
|
||||
assert isinstance(var.value, str)
|
||||
assert var == "test_string"
|
||||
|
||||
|
||||
def test_variant_int():
|
||||
var = Variant(2137)
|
||||
assert isinstance(var.value, int)
|
||||
assert var == 2137
|
||||
|
||||
|
||||
def test_variant_float():
|
||||
var = Variant(21.37)
|
||||
assert isinstance(var.value, float)
|
||||
|
||||
|
||||
def test_variant_string_list():
|
||||
var = Variant(["test", "string"])
|
||||
assert isinstance(var.value, list)
|
||||
assert isinstance(var[0], str)
|
||||
assert var[0] == "test"
|
||||
|
||||
|
||||
def test_variant_int_list():
|
||||
v = Variant([21, 37])
|
||||
assert isinstance(v.value, list)
|
||||
assert len(v) == 2
|
||||
assert isinstance(v[0], int)
|
||||
|
||||
|
||||
def test_variant_float_list():
|
||||
v = Variant([21.0, 37.0])
|
||||
assert isinstance(v.value, list)
|
||||
assert len(v) == 2
|
||||
assert isinstance(v[0], float)
|
||||
|
||||
|
||||
def test_variant_tuple():
|
||||
v = Variant((2, 1))
|
||||
assert isinstance(v.value, tuple)
|
||||
|
||||
|
||||
def test_variant_bool():
|
||||
v = Variant(False)
|
||||
assert isinstance(v.value, bool)
|
||||
assert v is not True
|
||||
|
||||
|
||||
def test_variant_dict_str():
|
||||
v = Variant({"key": "value"})
|
||||
assert isinstance(v.value, dict)
|
||||
assert v["key"] == "value"
|
||||
|
||||
|
||||
def test_variant_dict_str_int():
|
||||
v = Variant({"key": 2})
|
||||
assert isinstance(v.value, dict)
|
||||
assert v["key"] == 2
|
||||
|
||||
|
||||
def test_variant_int_dict():
|
||||
v = Variant({1: 2})
|
||||
assert isinstance(v.value, dict)
|
||||
assert v[1] == 2
|
||||
|
||||
|
||||
def test_variant_set_new_value():
|
||||
v = Variant(int(1))
|
||||
assert isinstance(v.value, int)
|
||||
v = Variant("test")
|
||||
assert isinstance(v.value, str)
|
||||
assert v == "test"
|
||||
|
||||
|
||||
def test_variant_class():
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.text = "test"
|
||||
|
||||
v = Variant(TestClass())
|
||||
assert isinstance(v.value, TestClass)
|
||||
assert v.value.text == "test"
|
Loading…
Reference in New Issue
Block a user