From 4247a9d6b020ce0c5a49349674e6933c61b38e7d Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 16 Nov 2021 22:09:47 +0100 Subject: [PATCH] [PYTHON] Improve API with compact functions and Python wrappers (#8603) * Added new Tensor dispatch and improvements, replace injections with inheritance of pybind objects, remove Blob from public python API. * Clean-up tests and API from unused classes * Remove unused import * cpp codestyle * Update AsyncInferQueue with python wrapper * Codestyle cpp * Applying comments * Common tensor setting for requests --- .../bindings/python/src/openvino/__init__.py | 47 +-- .../bindings/python/src/openvino/ie_api.py | 215 ++++++------- .../src/pyopenvino/core/async_infer_queue.cpp | 17 +- .../python/src/pyopenvino/core/common.cpp | 294 ++++++++---------- .../python/src/pyopenvino/core/common.hpp | 47 +-- .../python/src/pyopenvino/core/core.cpp | 11 +- .../pyopenvino/core/executable_network.cpp | 26 +- .../python/src/pyopenvino/core/ie_blob.cpp | 20 -- .../python/src/pyopenvino/core/ie_blob.hpp | 56 ---- .../python/src/pyopenvino/core/ie_data.cpp | 39 --- .../python/src/pyopenvino/core/ie_data.hpp | 11 - .../src/pyopenvino/core/ie_input_info.cpp | 78 ----- .../src/pyopenvino/core/ie_input_info.hpp | 11 - .../python/src/pyopenvino/core/ie_network.cpp | 87 ------ .../python/src/pyopenvino/core/ie_network.hpp | 11 - .../pyopenvino/core/ie_preprocess_info.cpp | 70 ----- .../pyopenvino/core/ie_preprocess_info.hpp | 11 - .../src/pyopenvino/core/infer_request.cpp | 40 +-- .../python/src/pyopenvino/core/tensor.cpp | 17 +- .../pyopenvino/core/tensor_description.cpp | 58 ---- .../pyopenvino/core/tensor_description.hpp | 11 - .../python/src/pyopenvino/pyopenvino.cpp | 29 -- runtime/bindings/python/tests/conftest.py | 14 + runtime/bindings/python/tests/runtime.py | 11 +- .../tests/test_inference_engine/test_core.py | 80 ++--- .../test_executable_network.py | 23 +- .../test_infer_request.py | 30 +- .../test_inference_engine/test_tensor.py | 16 +- 28 files changed, 344 insertions(+), 1036 deletions(-) delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_blob.cpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_blob.hpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_data.cpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_data.hpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_input_info.cpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_input_info.hpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_network.cpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_network.hpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.cpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.hpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/tensor_description.cpp delete mode 100644 runtime/bindings/python/src/pyopenvino/core/tensor_description.hpp diff --git a/runtime/bindings/python/src/openvino/__init__.py b/runtime/bindings/python/src/openvino/__init__.py index aea37b7c58c..ca73e6b4d6a 100644 --- a/runtime/bindings/python/src/openvino/__init__.py +++ b/runtime/bindings/python/src/openvino/__init__.py @@ -6,47 +6,32 @@ from pkg_resources import get_distribution, DistributionNotFound -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore # mypy issue #1422 +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore # mypy 
issue #1422 try: __version__ = get_distribution("openvino-core").version except DistributionNotFound: __version__ = "0.0.0.dev0" -from openvino.ie_api import BlobWrapper -from openvino.ie_api import infer -from openvino.ie_api import start_async -from openvino.ie_api import blob_from_file -from openvino.ie_api import tensor_from_file -from openvino.ie_api import infer_new_request +# Openvino pybind bindings and python extended classes from openvino.impl import Dimension from openvino.impl import Function from openvino.impl import Node from openvino.impl import PartialShape from openvino.impl import Layout -from openvino.pyopenvino import Core -from openvino.pyopenvino import IENetwork -from openvino.pyopenvino import ExecutableNetwork +from openvino.ie_api import Core +from openvino.ie_api import ExecutableNetwork +from openvino.ie_api import InferRequest +from openvino.ie_api import AsyncInferQueue from openvino.pyopenvino import Version from openvino.pyopenvino import Parameter -from openvino.pyopenvino import InputInfoPtr -from openvino.pyopenvino import InputInfoCPtr -from openvino.pyopenvino import DataPtr -from openvino.pyopenvino import TensorDesc -from openvino.pyopenvino import get_version -from openvino.pyopenvino import AsyncInferQueue -from openvino.pyopenvino import InferRequest # TODO: move to ie_api? -from openvino.pyopenvino import Blob -from openvino.pyopenvino import PreProcessInfo -from openvino.pyopenvino import MeanVariant -from openvino.pyopenvino import ResizeAlgorithm -from openvino.pyopenvino import ColorFormat -from openvino.pyopenvino import PreProcessChannel from openvino.pyopenvino import Tensor from openvino.pyopenvino import ProfilingInfo +from openvino.pyopenvino import get_version +# Import opsets from openvino import opset1 from openvino import opset2 from openvino import opset3 @@ -56,6 +41,10 @@ from openvino import opset6 from openvino import opset7 from openvino import opset8 +# Helper functions for openvino module +from openvino.ie_api import tensor_from_file +from openvino.ie_api import compile_model + # Extend Node class to support binary operators Node.__add__ = opset8.add Node.__sub__ = opset8.subtract @@ -73,15 +62,3 @@ Node.__lt__ = opset8.less Node.__le__ = opset8.less_equal Node.__gt__ = opset8.greater Node.__ge__ = opset8.greater_equal - -# Patching for Blob class -# flake8: noqa: F811 -# this class will be removed -Blob = BlobWrapper -# Patching ExecutableNetwork -ExecutableNetwork.infer_new_request = infer_new_request -# Patching InferRequest -InferRequest.infer = infer -InferRequest.start_async = start_async -# Patching AsyncInferQueue -AsyncInferQueue.start_async = start_async diff --git a/runtime/bindings/python/src/openvino/ie_api.py b/runtime/bindings/python/src/openvino/ie_api.py index 4314398feac..6300f88ce7c 100644 --- a/runtime/bindings/python/src/openvino/ie_api.py +++ b/runtime/bindings/python/src/openvino/ie_api.py @@ -3,126 +3,109 @@ import numpy as np import copy -from typing import List, Union +from typing import Any, List, Union -from openvino.pyopenvino import TBlobFloat32 -from openvino.pyopenvino import TBlobFloat64 -from openvino.pyopenvino import TBlobInt64 -from openvino.pyopenvino import TBlobUint64 -from openvino.pyopenvino import TBlobInt32 -from openvino.pyopenvino import TBlobUint32 -from openvino.pyopenvino import TBlobInt16 -from openvino.pyopenvino import TBlobUint16 -from openvino.pyopenvino import TBlobInt8 -from openvino.pyopenvino import TBlobUint8 -from openvino.pyopenvino import TensorDesc -from 
openvino.pyopenvino import InferRequest -from openvino.pyopenvino import AsyncInferQueue -from openvino.pyopenvino import ExecutableNetwork +from openvino.pyopenvino import Function +from openvino.pyopenvino import Core as CoreBase +from openvino.pyopenvino import ExecutableNetwork as ExecutableNetworkBase +from openvino.pyopenvino import InferRequest as InferRequestBase +from openvino.pyopenvino import AsyncInferQueue as AsyncInferQueueBase from openvino.pyopenvino import Tensor - -precision_map = {"FP32": np.float32, - "FP64": np.float64, - "FP16": np.int16, - "BF16": np.int16, - "I16": np.int16, - "I8": np.int8, - "BIN": np.int8, - "I32": np.int32, - "I64": np.int64, - "U8": np.uint8, - "BOOL": np.uint8, - "U16": np.uint16, - "U32": np.uint32, - "U64": np.uint64} +from openvino.utils.types import get_dtype -def normalize_inputs(py_dict: dict) -> dict: - """Normalize a dictionary of inputs to contiguous numpy arrays.""" - return {k: (Tensor(v) if isinstance(v, np.ndarray) else v) - for k, v in py_dict.items()} - -# flake8: noqa: D102 -def infer(request: InferRequest, inputs: dict = {}) -> np.ndarray: - res = request._infer(inputs=normalize_inputs(inputs)) - # Required to return list since np.ndarray forces all of tensors data to match in - # dimensions. This results in errors when running ops like variadic split. - return [copy.deepcopy(tensor.data) for tensor in res] - - -def infer_new_request(exec_net: ExecutableNetwork, inputs: dict = None) -> List[np.ndarray]: - res = exec_net._infer_new_request(inputs=normalize_inputs(inputs if inputs is not None else {})) - # Required to return list since np.ndarray forces all of tensors data to match in - # dimensions. This results in errors when running ops like variadic split. - return [copy.deepcopy(tensor.data) for tensor in res] - -# flake8: noqa: D102 -def start_async(request: Union[InferRequest, AsyncInferQueue], inputs: dict = {}, userdata: dict = None) -> None: # type: ignore - request._start_async(inputs=normalize_inputs(inputs), userdata=userdata) - -# flake8: noqa: C901 -# Dispatch Blob types on Python side. 
-class BlobWrapper: - def __new__(cls, tensor_desc: TensorDesc, arr: np.ndarray = None): # type: ignore - arr_size = 0 - precision = "" - if tensor_desc is not None: - tensor_desc_size = int(np.prod(tensor_desc.dims)) - precision = tensor_desc.precision - if arr is not None: - arr = np.array(arr) # Keeping array as numpy array - arr_size = int(np.prod(arr.shape)) - if np.isfortran(arr): - arr = arr.ravel(order="F") - else: - arr = arr.ravel(order="C") - if arr_size != tensor_desc_size: - raise AttributeError(f"Number of elements in provided numpy array " - f"{arr_size} and required by TensorDesc " - f"{tensor_desc_size} are not equal") - if arr.dtype != precision_map[precision]: - raise ValueError(f"Data type {arr.dtype} of provided numpy array " - f"doesn't match to TensorDesc precision {precision}") - if not arr.flags["C_CONTIGUOUS"]: - arr = np.ascontiguousarray(arr) - elif arr is None: - arr = np.empty(0, dtype=precision_map[precision]) - else: - raise AttributeError("TensorDesc can't be None") - - if precision in ["FP32"]: - return TBlobFloat32(tensor_desc, arr, arr_size) - elif precision in ["FP64"]: - return TBlobFloat64(tensor_desc, arr, arr_size) - elif precision in ["FP16", "BF16"]: - return TBlobInt16(tensor_desc, arr.view(dtype=np.int16), arr_size) - elif precision in ["I64"]: - return TBlobInt64(tensor_desc, arr, arr_size) - elif precision in ["U64"]: - return TBlobUint64(tensor_desc, arr, arr_size) - elif precision in ["I32"]: - return TBlobInt32(tensor_desc, arr, arr_size) - elif precision in ["U32"]: - return TBlobUint32(tensor_desc, arr, arr_size) - elif precision in ["I16"]: - return TBlobInt16(tensor_desc, arr, arr_size) - elif precision in ["U16"]: - return TBlobUint16(tensor_desc, arr, arr_size) - elif precision in ["I8", "BIN"]: - return TBlobInt8(tensor_desc, arr, arr_size) - elif precision in ["U8", "BOOL"]: - return TBlobUint8(tensor_desc, arr, arr_size) - else: - raise AttributeError(f"Unsupported precision {precision} for Blob") - -# flake8: noqa: D102 -def blob_from_file(path_to_bin_file: str) -> BlobWrapper: - array = np.fromfile(path_to_bin_file, dtype=np.uint8) - tensor_desc = TensorDesc("U8", array.shape, "C") - return BlobWrapper(tensor_desc, array) - -# flake8: noqa: D102 def tensor_from_file(path: str) -> Tensor: - """The data will be read with dtype of unit8""" + """Create Tensor from file. 
Data will be read with dtype of uint8."""
     return Tensor(np.fromfile(path, dtype=np.uint8))
+
+
+def normalize_inputs(py_dict: dict, py_types: dict) -> dict:
+    """Normalize a dictionary of inputs to Tensors."""
+    for k, val in py_dict.items():
+        try:
+            if isinstance(k, int):
+                ov_type = list(py_types.values())[k]
+            elif isinstance(k, str):
+                ov_type = py_types[k]
+            else:
+                raise TypeError("Incompatible key type for tensor named: {}".format(k))
+        except KeyError:
+            raise KeyError("Port for tensor named {} was not found!".format(k))
+        py_dict[k] = val if isinstance(val, Tensor) else Tensor(np.array(val, get_dtype(ov_type)))
+    return py_dict
+
+
+def get_input_types(obj: Union[InferRequestBase, ExecutableNetworkBase]) -> dict:
+    """Get all element types from object inputs, keyed by input friendly name."""
+    return {i.get_node().get_friendly_name(): i.get_node().get_element_type() for i in obj.inputs}
+
+
+class InferRequest(InferRequestBase):
+    """InferRequest wrapper."""
+
+    def infer(self, inputs: dict = {}) -> List[np.ndarray]:  # noqa: B006
+        """Infer wrapper for InferRequest."""
+        res = super().infer(inputs=normalize_inputs(inputs, get_input_types(self)))
+        # Required to return a list, since np.ndarray forces all of the tensors' data to match in
+        # dimensions. This results in errors when running ops like variadic split.
+        return [copy.deepcopy(tensor.data) for tensor in res]
+
+    def start_async(self, inputs: dict = {}, userdata: Any = None) -> None:  # noqa: B006, type: ignore
+        """Asynchronous infer wrapper for InferRequest."""
+        super().start_async(inputs=normalize_inputs(inputs, get_input_types(self)), userdata=userdata)
+
+
+class ExecutableNetwork(ExecutableNetworkBase):
+    """ExecutableNetwork wrapper."""
+
+    def create_infer_request(self) -> InferRequest:
+        """Create new InferRequest object."""
+        return InferRequest(super().create_infer_request())
+
+    def infer_new_request(self, inputs: dict = {}) -> List[np.ndarray]:  # noqa: B006
+        """Infer wrapper for ExecutableNetwork."""
+        res = super().infer_new_request(inputs=normalize_inputs(inputs, get_input_types(self)))
+        # Required to return a list, since np.ndarray forces all of the tensors' data to match in
+        # dimensions. This results in errors when running ops like variadic split.
+        return [copy.deepcopy(tensor.data) for tensor in res]
+
+
+class AsyncInferQueue(AsyncInferQueueBase):
+    """AsyncInferQueue wrapper."""
+
+    def __getitem__(self, i: int) -> InferRequest:
+        """Return i-th InferRequest from AsyncInferQueue."""
+        return InferRequest(super().__getitem__(i))
+
+    def start_async(
+        self, inputs: dict = {}, userdata: Any = None  # noqa: B006
+    ) -> None:  # type: ignore
+        """Asynchronous infer wrapper for AsyncInferQueue."""
+        super().start_async(
+            inputs=normalize_inputs(
+                inputs, get_input_types(self[self.get_idle_request_id()])
+            ),
+            userdata=userdata,
+        )
+
+
+class Core(CoreBase):
+    """Core wrapper."""
+
+    def compile_model(
+        self, model: Function, device_name: str, config: dict = {}  # noqa: B006
+    ) -> ExecutableNetwork:
+        """Compile a model from given Function."""
+        return ExecutableNetwork(super().compile_model(model, device_name, config))
+
+    def import_model(
+        self, model_file: str, device_name: str, config: dict = {}  # noqa: B006
+    ) -> ExecutableNetwork:
+        """Import a previously compiled model from the given file path."""
+        return ExecutableNetwork(super().import_model(model_file, device_name, config))
+
+
+def compile_model(model_path: str) -> ExecutableNetwork:
+    """Compact method to compile model with AUTO plugin."""
+    return Core().compile_model(model_path, "AUTO")
diff --git a/runtime/bindings/python/src/pyopenvino/core/async_infer_queue.cpp b/runtime/bindings/python/src/pyopenvino/core/async_infer_queue.cpp
index 8c68958a1ef..160a3ef2307 100644
--- a/runtime/bindings/python/src/pyopenvino/core/async_infer_queue.cpp
+++ b/runtime/bindings/python/src/pyopenvino/core/async_infer_queue.cpp
@@ -53,7 +53,6 @@ public:
         });
 
         return _idle_handles.front();
-        ;
     }
 
     void wait_all() {
@@ -135,7 +134,7 @@ void regclass_AsyncInferQueue(py::module m) {
         py::arg("jobs") = 0);
 
     cls.def(
-        "_start_async",
+        "start_async",
         [](AsyncInferQueue& self, const py::dict inputs, py::object userdata) {
            // getIdleRequestId function has an intention to block InferQueue
            // until there is at least one idle (free to use) InferRequest
@@ -144,19 +143,7 @@ void regclass_AsyncInferQueue(py::module m) {
            // Set new inputs label/id from user
            self._user_ids[handle] = userdata;
            // Update inputs if there are any
-           if (!inputs.empty()) {
-               if (py::isinstance<py::str>(inputs.begin()->first)) {
-                   auto inputs_map = Common::cast_to_tensor_name_map(inputs);
-                   for (auto&& input : inputs_map) {
-                       self._requests[handle]._request.set_tensor(input.first, input.second);
-                   }
-               } else if (py::isinstance<py::int_>(inputs.begin()->first)) {
-                   auto inputs_map = Common::cast_to_tensor_index_map(inputs);
-                   for (auto&& input : inputs_map) {
-                       self._requests[handle]._request.set_input_tensor(input.first, input.second);
-                   }
-               }
-           }
+           Common::set_request_tensors(self._requests[handle]._request, inputs);
            // Now GIL can be released - we are NOT working with Python objects in this block
            {
                py::gil_scoped_release release;
diff --git a/runtime/bindings/python/src/pyopenvino/core/common.cpp b/runtime/bindings/python/src/pyopenvino/core/common.cpp
index 8a15aaf6b92..6ef5774163a 100644
--- a/runtime/bindings/python/src/pyopenvino/core/common.cpp
+++ b/runtime/bindings/python/src/pyopenvino/core/common.cpp
@@ -6,6 +6,8 @@
 
 #include
 
+#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
+
 namespace Common {
 const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
     static const std::map<ov::element::Type, py::dtype> ov_type_to_dtype_mapping = {
@@ -45,42 +47,130 @@ const std::map<std::string, ov::element::Type>& dtype_to_ov_type() {
     return dtype_to_ov_type_mapping;
 }
 
-InferenceEngine::Layout get_layout_from_string(const std::string& layout) {
-    static const std::unordered_map<std::string, InferenceEngine::Layout> layout_str_to_enum = {
-        {"ANY", InferenceEngine::Layout::ANY},
-        {"NHWC", InferenceEngine::Layout::NHWC},
-        {"NCHW", InferenceEngine::Layout::NCHW},
-        {"NCDHW", InferenceEngine::Layout::NCDHW},
-        {"NDHWC", InferenceEngine::Layout::NDHWC},
-        {"OIHW", InferenceEngine::Layout::OIHW},
-        {"GOIHW", InferenceEngine::Layout::GOIHW},
-        {"OIDHW", InferenceEngine::Layout::OIDHW},
-        {"GOIDHW", InferenceEngine::Layout::GOIDHW},
-        {"SCALAR", InferenceEngine::Layout::SCALAR},
-        {"C", InferenceEngine::Layout::C},
-        {"CHW", InferenceEngine::Layout::CHW},
-        {"HW", InferenceEngine::Layout::HW},
-        {"NC", InferenceEngine::Layout::NC},
-        {"CN", InferenceEngine::Layout::CN},
-        {"BLOCKED", InferenceEngine::Layout::BLOCKED}};
-    return layout_str_to_enum.at(layout);
-}
+ov::runtime::Tensor tensor_from_numpy(py::array& array, bool shared_memory) {
+    // Check if passed array has C-style contiguous memory layout.
+    bool is_contiguous = C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS);
+    auto type = Common::dtype_to_ov_type().at(py::str(array.dtype()));
+    std::vector<size_t> shape(array.shape(), array.shape() + array.ndim());
+
+    // If memory is going to be shared it needs to be contiguous before
+    // passing to the constructor. This case should be handled by advanced
+    // users on their side of the code.
+    if (shared_memory) {
+        if (is_contiguous) {
+            std::vector<size_t> strides(array.strides(), array.strides() + array.ndim());
+            return ov::runtime::Tensor(type, shape, const_cast<void*>(array.data(0)), strides);
+        } else {
+            throw ov::Exception("Tensor with shared memory must be C contiguous!");
+        }
+    }
+    // Convert to contiguous array if not already C-style.
+    if (!is_contiguous) {
+        array = Common::as_contiguous(array, type);
+    }
+    // Create actual Tensor and copy data.
+    auto tensor = ov::runtime::Tensor(type, shape);
+    // If ndim of py::array is 0, array is a numpy scalar. That results in size to be equal to 0.
+    // To gain access to actual raw/low-level data, it is needed to use buffer protocol.
+    py::buffer_info buf = array.request();
+    std::memcpy(tensor.data(), buf.ptr, buf.ndim == 0 ? buf.itemsize : buf.itemsize * buf.size);
+    return tensor;
+}
 
-const std::string& get_layout_from_enum(const InferenceEngine::Layout& layout) {
-    static const std::unordered_map<int, std::string> layout_int_to_str_map = {{0, "ANY"},
-                                                                               {1, "NCHW"},
-                                                                               {2, "NHWC"},
-                                                                               {3, "NCDHW"},
-                                                                               {4, "NDHWC"},
-                                                                               {64, "OIHW"},
-                                                                               {95, "SCALAR"},
-                                                                               {96, "C"},
-                                                                               {128, "CHW"},
-                                                                               {192, "HW"},
-                                                                               {193, "NC"},
-                                                                               {194, "CN"},
-                                                                               {200, "BLOCKED"}};
-    return layout_int_to_str_map.at(layout);
-}
+py::array as_contiguous(py::array& array, ov::element::Type type) {
+    switch (type) {
+    // floating
+    case ov::element::f64:
+        return array.cast<py::array_t<double, py::array::c_style | py::array::forcecast>>();
+    case ov::element::f32:
+        return array.cast<py::array_t<float, py::array::c_style | py::array::forcecast>>();
+    // signed
+    case ov::element::i64:
+        return array.cast<py::array_t<int64_t, py::array::c_style | py::array::forcecast>>();
+    case ov::element::i32:
+        return array.cast<py::array_t<int32_t, py::array::c_style | py::array::forcecast>>();
+    case ov::element::i16:
+        return array.cast<py::array_t<int16_t, py::array::c_style | py::array::forcecast>>();
+    case ov::element::i8:
+        return array.cast<py::array_t<int8_t, py::array::c_style | py::array::forcecast>>();
+    // unsigned
+    case ov::element::u64:
+        return array.cast<py::array_t<uint64_t, py::array::c_style | py::array::forcecast>>();
+    case ov::element::u32:
+        return array.cast<py::array_t<uint32_t, py::array::c_style | py::array::forcecast>>();
+    case ov::element::u16:
+        return array.cast<py::array_t<uint16_t, py::array::c_style | py::array::forcecast>>();
+    case ov::element::u8:
+        return array.cast<py::array_t<uint8_t, py::array::c_style | py::array::forcecast>>();
+    // other
+    case ov::element::boolean:
+        return array.cast<py::array_t<bool, py::array::c_style | py::array::forcecast>>();
+    case ov::element::u1:
+        return array.cast<py::array_t<uint8_t, py::array::c_style | py::array::forcecast>>();
+    // need to create a view on array to cast it correctly
+    case ov::element::f16:
+    case ov::element::bf16:
+        return array.view("int16").cast<py::array_t<int16_t, py::array::c_style | py::array::forcecast>>();
+    default:
+        throw ov::Exception("Tensor cannot be created as contiguous!");
+        break;
+    }
+}
+
+const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor) {
+    return tensor.cast<const ov::runtime::Tensor&>();
+}
+
+const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs) {
+    Containers::TensorNameMap result_map;
+    for (auto&& input : inputs) {
+        std::string name;
+        if (py::isinstance<py::str>(input.first)) {
+            name = input.first.cast<std::string>();
+        } else {
+            throw py::type_error("incompatible function arguments!");
+        }
+        if (py::isinstance<ov::runtime::Tensor>(input.second)) {
+            auto tensor = Common::cast_to_tensor(input.second);
+            result_map[name] = tensor;
+        } else {
+            throw ov::Exception("Unable to cast tensor " + name + "!");
+        }
+    }
+    return result_map;
+}
+
+const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) {
+    Containers::TensorIndexMap result_map;
+    for (auto&& input : inputs) {
+        int idx;
+        if (py::isinstance<py::int_>(input.first)) {
+            idx = input.first.cast<int>();
+        } else {
+            throw py::type_error("incompatible function arguments!");
+        }
+        if (py::isinstance<ov::runtime::Tensor>(input.second)) {
+            auto tensor = Common::cast_to_tensor(input.second);
+            result_map[idx] = tensor;
+        } else {
+            throw ov::Exception("Unable to cast tensor " + std::to_string(idx) + "!");
+        }
+    }
+    return result_map;
+}
+
+void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inputs) {
+    if (!inputs.empty()) {
+        for (auto&& input : inputs) {
+            if (py::isinstance<py::str>(input.first)) {
+                request.set_tensor(input.first.cast<std::string>(), Common::cast_to_tensor(input.second));
+            } else if (py::isinstance<py::int_>(input.first)) {
+                request.set_input_tensor(input.first.cast<size_t>(), Common::cast_to_tensor(input.second));
+            } else {
+                throw py::type_error("Incompatible key type for tensor named: " + input.first.cast<std::string>());
+            }
+        }
+    }
+}
 
 PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
@@ -185,142 +275,6 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
     }
 }
 
-bool is_TBlob(const py::handle& blob) {
-    if (py::isinstance<InferenceEngine::TBlob<float>>(blob)) {
-        return true;
-    } else if (py::isinstance<InferenceEngine::TBlob<double>>(blob)) {
-        return true;
-    } else if (py::isinstance<InferenceEngine::TBlob<int64_t>>(blob)) {
-        return true;
-    } else if (py::isinstance<InferenceEngine::TBlob<uint64_t>>(blob)) {
-        return true;
-    } else if
(py::isinstance>(blob)) { - return true; - } else if (py::isinstance>(blob)) { - return true; - } else if (py::isinstance>(blob)) { - return true; - } else if (py::isinstance>(blob)) { - return true; - } else if (py::isinstance>(blob)) { - return true; - } else if (py::isinstance>(blob)) { - return true; - } else { - return false; - } -} - -const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor) { - return tensor.cast(); -} - -const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs) { - Containers::TensorNameMap result_map; - for (auto&& input : inputs) { - std::string name; - if (py::isinstance(input.first)) { - name = input.first.cast(); - } else { - throw py::type_error("incompatible function arguments!"); - } - if (py::isinstance(input.second)) { - auto tensor = Common::cast_to_tensor(input.second); - result_map[name] = tensor; - } else { - throw ov::Exception("Unable to cast tensor " + name + "!"); - } - } - return result_map; -} - -const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) { - Containers::TensorIndexMap result_map; - for (auto&& input : inputs) { - int idx; - if (py::isinstance(input.first)) { - idx = input.first.cast(); - } else { - throw py::type_error("incompatible function arguments!"); - } - if (py::isinstance(input.second)) { - auto tensor = Common::cast_to_tensor(input.second); - result_map[idx] = tensor; - } else { - throw ov::Exception("Unable to cast tensor " + std::to_string(idx) + "!"); - } - } - return result_map; -} - -const std::shared_ptr cast_to_blob(const py::handle& blob) { - if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else if (py::isinstance>(blob)) { - return blob.cast>&>(); - } else { - IE_THROW() << "Unsupported data type for when casting to blob!"; - // return nullptr; - } -} - -void blob_from_numpy(const py::handle& arr, InferenceEngine::Blob::Ptr blob) { - if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else if (py::isinstance>(arr)) { - Common::fill_blob(arr, blob); - } else { - IE_THROW() << "Unsupported data type for when filling blob!"; - } -} - -void set_request_blobs(InferenceEngine::InferRequest& request, const py::dict& dictonary) { - for (auto&& pair : dictonary) { - const std::string& name = pair.first.cast(); - if (py::isinstance(pair.second)) { - Common::blob_from_numpy(pair.second, request.GetBlob(name)); - } else if 
(is_TBlob(pair.second)) { - request.SetBlob(name, Common::cast_to_blob(pair.second)); - } else { - IE_THROW() << "Unable to set blob " << name << "!"; - } - } -} - uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual) { try { auto parameter_value = actual.get_metric(METRIC_KEY(SUPPORTED_METRICS)); diff --git a/runtime/bindings/python/src/pyopenvino/core/common.hpp b/runtime/bindings/python/src/pyopenvino/core/common.hpp index 867330640f3..7e2e0f0dfc9 100644 --- a/runtime/bindings/python/src/pyopenvino/core/common.hpp +++ b/runtime/bindings/python/src/pyopenvino/core/common.hpp @@ -4,62 +4,43 @@ #pragma once -#include -#include -#include -#include -#include +#include + +#include #include #include -#include + +#include +#include +#include #include "Python.h" #include "ie_common.h" #include "openvino/runtime/tensor.hpp" #include "openvino/runtime/executable_network.hpp" +#include "openvino/runtime/infer_request.hpp" #include "pyopenvino/core/containers.hpp" namespace py = pybind11; namespace Common { - template - void fill_blob(const py::handle& py_array, InferenceEngine::Blob::Ptr blob) - { - py::array_t arr = py::cast(py_array); - if (arr.size() != 0) { - // blob->allocate(); - InferenceEngine::MemoryBlob::Ptr mem_blob = InferenceEngine::as(blob); - std::copy( - arr.data(0), arr.data(0) + arr.size(), mem_blob->rwmap().as()); - } else { - py::print("Empty array!"); - } - } - const std::map& ov_type_to_dtype(); + const std::map& dtype_to_ov_type(); - InferenceEngine::Layout get_layout_from_string(const std::string& layout); + ov::runtime::Tensor tensor_from_numpy(py::array& array, bool shared_memory); - const std::string& get_layout_from_enum(const InferenceEngine::Layout& layout); + py::array as_contiguous(py::array& array, ov::element::Type type); - PyObject* parse_parameter(const InferenceEngine::Parameter& param); - - PyObject* parse_parameter(const InferenceEngine::Parameter& param); - - bool is_TBlob(const py::handle& blob); - - const std::shared_ptr cast_to_blob(const py::handle& blob); + const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor); const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs); const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs); - const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor); + void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inputs); - void blob_from_numpy(const py::handle& _arr, InferenceEngine::Blob::Ptr &blob); - - void set_request_blobs(InferenceEngine::InferRequest& request, const py::dict& dictonary); + PyObject* parse_parameter(const InferenceEngine::Parameter& param); uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual); }; // namespace Common diff --git a/runtime/bindings/python/src/pyopenvino/core/core.cpp b/runtime/bindings/python/src/pyopenvino/core/core.cpp index 7538ffd7f79..46eed93d02b 100644 --- a/runtime/bindings/python/src/pyopenvino/core/core.cpp +++ b/runtime/bindings/python/src/pyopenvino/core/core.cpp @@ -23,6 +23,7 @@ std::string to_string(py::handle handle) { void regclass_Core(py::module m) { py::class_> cls(m, "Core"); + cls.def(py::init(), py::arg("xml_config_file") = ""); cls.def("set_config", @@ -35,10 +36,18 @@ void regclass_Core(py::module m) { (ov::runtime::ExecutableNetwork( ov::runtime::Core::*)(const std::shared_ptr&, const std::string&, const ConfigMap&)) & ov::runtime::Core::compile_model, - py::arg("network"), + py::arg("model"), 
py::arg("device_name"), py::arg("config") = py::dict()); + cls.def("compile_model", + (ov::runtime::ExecutableNetwork( + ov::runtime::Core::*)(const std::string&, const std::string&, const ConfigMap&)) & + ov::runtime::Core::compile_model, + py::arg("model_path"), + py::arg("device_name"), + py::arg("config") = py::dict()); + cls.def("get_versions", &ov::runtime::Core::get_versions); cls.def("read_model", diff --git a/runtime/bindings/python/src/pyopenvino/core/executable_network.cpp b/runtime/bindings/python/src/pyopenvino/core/executable_network.cpp index 389a865aea1..37d6811f38d 100644 --- a/runtime/bindings/python/src/pyopenvino/core/executable_network.cpp +++ b/runtime/bindings/python/src/pyopenvino/core/executable_network.cpp @@ -20,31 +20,21 @@ void regclass_ExecutableNetwork(py::module m) { m, "ExecutableNetwork"); + cls.def(py::init([](ov::runtime::ExecutableNetwork& other) { + return other; + }), + py::arg("other")); + cls.def("create_infer_request", [](ov::runtime::ExecutableNetwork& self) { return InferRequestWrapper(self.create_infer_request(), self.inputs(), self.outputs()); }); cls.def( - "_infer_new_request", + "infer_new_request", [](ov::runtime::ExecutableNetwork& self, const py::dict& inputs) { auto request = self.create_infer_request(); - const auto key = inputs.begin()->first; - if (!inputs.empty()) { - if (py::isinstance(key)) { - auto inputs_map = Common::cast_to_tensor_name_map(inputs); - for (auto&& input : inputs_map) { - request.set_tensor(input.first, input.second); - } - } else if (py::isinstance(key)) { - auto inputs_map = Common::cast_to_tensor_index_map(inputs); - for (auto&& input : inputs_map) { - request.set_input_tensor(input.first, input.second); - } - } else { - throw py::type_error("Incompatible key type! Supported types are string and int."); - } - } - + // Update inputs if there are any + Common::set_request_tensors(request, inputs); request.infer(); Containers::InferResults results; diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_blob.cpp b/runtime/bindings/python/src/pyopenvino/core/ie_blob.cpp deleted file mode 100644 index 11e6c7634a3..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_blob.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "ie_blob.h" - -#include -#include -#include - -#include - -#include "pyopenvino/core/ie_blob.hpp" -#include "pyopenvino/core/tensor_description.hpp" - -namespace py = pybind11; - -void regclass_Blob(py::module m) { - py::class_> cls(m, "Blob"); -} diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_blob.hpp b/runtime/bindings/python/src/pyopenvino/core/ie_blob.hpp deleted file mode 100644 index b5efc440386..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_blob.hpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include -#include -#include - -#include "ie_blob.h" -#include "ie_common.h" -#include "ie_layouts.h" -#include "ie_precision.hpp" - -#include "pyopenvino/core/tensor_description.hpp" - -namespace py = pybind11; - -void regclass_Blob(py::module m); - -template -void regclass_TBlob(py::module m, std::string typestring) -{ - auto pyclass_name = py::detail::c_str((std::string("TBlob") + typestring)); - - py::class_, std::shared_ptr>> cls( - m, pyclass_name); - - cls.def(py::init( - [](const InferenceEngine::TensorDesc& tensorDesc, py::array_t& arr, size_t size = 0) { - auto blob = 
InferenceEngine::make_shared_blob(tensorDesc); - blob->allocate(); - if (size != 0) { - std::copy(arr.data(0), arr.data(0) + size, blob->rwmap().template as()); - } - return blob; - })); - - cls.def_property_readonly("buffer", [](InferenceEngine::TBlob& self) { - auto blob_ptr = self.buffer().template as(); - auto shape = self.getTensorDesc().getDims(); - return py::array_t(shape, &blob_ptr[0], py::cast(self)); - }); - - cls.def_property_readonly("tensor_desc", - [](InferenceEngine::TBlob& self) { return self.getTensorDesc(); }); - - cls.def("__str__", [](InferenceEngine::TBlob& self) -> std::string { - std::stringstream ss; - auto blob_ptr = self.buffer().template as(); - auto shape = self.getTensorDesc().getDims(); - auto py_arr = py::array_t(shape, &blob_ptr[0], py::cast(self)); - ss << py_arr; - return ss.str(); - }); -} diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_data.cpp b/runtime/bindings/python/src/pyopenvino/core/ie_data.cpp deleted file mode 100644 index d1fd3bf760d..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_data.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "pyopenvino/core/ie_data.hpp" - -#include -#include - -#include "common.hpp" - -namespace py = pybind11; - -void regclass_Data(py::module m) { - py::class_> cls(m, "DataPtr"); - - cls.def_property( - "layout", - [](InferenceEngine::Data& self) { - return Common::get_layout_from_enum(self.getLayout()); - }, - [](InferenceEngine::Data& self, const std::string& layout) { - self.setLayout(Common::get_layout_from_string(layout)); - }); - - cls.def_property( - "precision", - [](InferenceEngine::Data& self) { - return self.getPrecision().name(); - }, - [](InferenceEngine::Data& self, const std::string& precision) { - self.setPrecision(InferenceEngine::Precision::FromStr(precision)); - }); - - cls.def_property_readonly("shape", &InferenceEngine::Data::getDims); - - cls.def_property_readonly("name", &InferenceEngine::Data::getName); - // cls.def_property_readonly("initialized", ); -} diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_data.hpp b/runtime/bindings/python/src/pyopenvino/core/ie_data.hpp deleted file mode 100644 index 6b1459714ec..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_data.hpp +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace py = pybind11; - -void regclass_Data(py::module m); diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_input_info.cpp b/runtime/bindings/python/src/pyopenvino/core/ie_input_info.cpp deleted file mode 100644 index d47020b537a..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_input_info.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "pyopenvino/core/ie_input_info.hpp" - -#include - -#include - -#include "common.hpp" - -namespace py = pybind11; - -class ConstInputInfoWrapper { -public: - ConstInputInfoWrapper() = default; - ~ConstInputInfoWrapper() = default; - const InferenceEngine::InputInfo& cref() const { - return value; - } - -protected: - const InferenceEngine::InputInfo& ref() { - return this->value; - } - const InferenceEngine::InputInfo value = InferenceEngine::InputInfo(); -}; - -void regclass_InputInfo(py::module m) { - // Workaround for constant class - py::class_> cls_const(m, "InputInfoCPtr"); - - 
cls_const.def(py::init<>()); - - cls_const.def_property_readonly("input_data", [](const ConstInputInfoWrapper& self) { - return self.cref().getInputData(); - }); - cls_const.def_property_readonly("precision", [](const ConstInputInfoWrapper& self) { - return self.cref().getPrecision().name(); - }); - cls_const.def_property_readonly("tensor_desc", [](const ConstInputInfoWrapper& self) { - return self.cref().getTensorDesc(); - }); - cls_const.def_property_readonly("name", [](const ConstInputInfoWrapper& self) { - return self.cref().name(); - }); - // Mutable version - py::class_> cls(m, "InputInfoPtr"); - - cls.def(py::init<>()); - - cls.def_property("input_data", - &InferenceEngine::InputInfo::getInputData, - &InferenceEngine::InputInfo::setInputData); - cls.def_property( - "layout", - [](InferenceEngine::InputInfo& self) { - return Common::get_layout_from_enum(self.getLayout()); - }, - [](InferenceEngine::InputInfo& self, const std::string& layout) { - self.setLayout(Common::get_layout_from_string(layout)); - }); - cls.def_property( - "precision", - [](InferenceEngine::InputInfo& self) { - return self.getPrecision().name(); - }, - [](InferenceEngine::InputInfo& self, const std::string& precision) { - self.setPrecision(InferenceEngine::Precision::FromStr(precision)); - }); - cls.def_property_readonly("tensor_desc", &InferenceEngine::InputInfo::getTensorDesc); - cls.def_property_readonly("name", &InferenceEngine::InputInfo::name); - cls.def_property_readonly("preprocess_info", [](InferenceEngine::InputInfo& self) { - InferenceEngine::PreProcessInfo& preprocess = self.getPreProcess(); - return preprocess; - }); -} diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_input_info.hpp b/runtime/bindings/python/src/pyopenvino/core/ie_input_info.hpp deleted file mode 100644 index 69d17221bc2..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_input_info.hpp +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace py = pybind11; - -void regclass_InputInfo(py::module m); diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_network.cpp b/runtime/bindings/python/src/pyopenvino/core/ie_network.cpp deleted file mode 100644 index e06f9bf79bb..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_network.cpp +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "pyopenvino/core/ie_network.hpp" - -#include -#include -#include - -#include - -#include "openvino/core/function.hpp" -#include "pyopenvino/core/containers.hpp" -#include "pyopenvino/core/ie_input_info.hpp" - -namespace py = pybind11; - -void regclass_IENetwork(py::module m) { - py::class_> cls(m, "IENetwork"); - cls.def(py::init()); - - cls.def(py::init([](std::shared_ptr& function) { - InferenceEngine::CNNNetwork cnnNetwork(function); - return std::make_shared(cnnNetwork); - })); - - cls.def("reshape", - [](InferenceEngine::CNNNetwork& self, const std::map>& input_shapes) { - self.reshape(input_shapes); - }); - - cls.def( - "add_outputs", - [](InferenceEngine::CNNNetwork& self, py::handle& outputs) { - int i = 0; - py::list _outputs; - if (!py::isinstance(outputs)) { - if (py::isinstance(outputs)) { - _outputs.append(outputs.cast()); - } else if (py::isinstance(outputs)) { - _outputs.append(outputs.cast()); - } - } else { - _outputs = outputs.cast(); - } - for (py::handle output : _outputs) { - if (py::isinstance(_outputs[i])) { - 
self.addOutput(output.cast(), 0); - } else if (py::isinstance(output)) { - py::tuple output_tuple = output.cast(); - self.addOutput(output_tuple[0].cast(), output_tuple[1].cast()); - } else { - IE_THROW() << "Incorrect type " << output.get_type() << "for layer to add at index " << i - << ". Expected string with layer name or tuple with two elements: layer name as " - "first element and port id as second"; - } - i++; - } - }, - py::arg("outputs")); - cls.def("add_output", &InferenceEngine::CNNNetwork::addOutput, py::arg("layer_name"), py::arg("output_index") = 0); - - cls.def( - "serialize", - [](InferenceEngine::CNNNetwork& self, const std::string& path_to_xml, const std::string& path_to_bin) { - self.serialize(path_to_xml, path_to_bin); - }, - py::arg("path_to_xml"), - py::arg("path_to_bin") = ""); - - cls.def("get_function", [](InferenceEngine::CNNNetwork& self) { - return self.getFunction(); - }); - - cls.def("get_ov_name_for_tensor", &InferenceEngine::CNNNetwork::getOVNameForTensor, py::arg("orig_name")); - - cls.def_property("batch_size", - &InferenceEngine::CNNNetwork::getBatchSize, - &InferenceEngine::CNNNetwork::setBatchSize); - - cls.def_property_readonly("outputs", [](InferenceEngine::CNNNetwork& self) { - return self.getOutputsInfo(); - }); - - cls.def_property_readonly("name", &InferenceEngine::CNNNetwork::getName); -} diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_network.hpp b/runtime/bindings/python/src/pyopenvino/core/ie_network.hpp deleted file mode 100644 index 9cbd5e43456..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_network.hpp +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace py = pybind11; - -void regclass_IENetwork(py::module m); diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.cpp b/runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.cpp deleted file mode 100644 index 7accf2f2b66..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.cpp +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "pyopenvino/core/ie_preprocess_info.hpp" - -#include -#include - -#include - -#include "pyopenvino/core/common.hpp" - -namespace py = pybind11; - -void regclass_PreProcessInfo(py::module m) { - py::class_>( - m, - "PreProcessChannel") - .def_readwrite("std_scale", &InferenceEngine::PreProcessChannel::stdScale) - .def_readwrite("mean_value", &InferenceEngine::PreProcessChannel::meanValue) - .def_readwrite("mean_data", &InferenceEngine::PreProcessChannel::meanData); - - py::class_ cls(m, "PreProcessInfo"); - - cls.def(py::init()); - cls.def("__getitem__", [](InferenceEngine::PreProcessInfo& self, size_t& index) { - return self[index]; - }); - cls.def("get_number_of_channels", &InferenceEngine::PreProcessInfo::getNumberOfChannels); - cls.def("init", &InferenceEngine::PreProcessInfo::init); - cls.def("set_mean_image", [](InferenceEngine::PreProcessInfo& self, py::handle meanImage) { - self.setMeanImage(Common::cast_to_blob(meanImage)); - }); - cls.def("set_mean_image_for_channel", - [](InferenceEngine::PreProcessInfo& self, py::handle meanImage, const size_t channel) { - self.setMeanImageForChannel(Common::cast_to_blob(meanImage), channel); - }); - cls.def_property("mean_variant", - &InferenceEngine::PreProcessInfo::getMeanVariant, - &InferenceEngine::PreProcessInfo::setVariant); - 
cls.def_property("resize_algorithm", - &InferenceEngine::PreProcessInfo::getResizeAlgorithm, - &InferenceEngine::PreProcessInfo::setResizeAlgorithm); - cls.def_property("color_format", - &InferenceEngine::PreProcessInfo::getColorFormat, - &InferenceEngine::PreProcessInfo::setColorFormat); - - py::enum_(m, "MeanVariant") - .value("MEAN_IMAGE", InferenceEngine::MeanVariant::MEAN_IMAGE) - .value("MEAN_VALUE", InferenceEngine::MeanVariant::MEAN_VALUE) - .value("NONE", InferenceEngine::MeanVariant::NONE) - .export_values(); - - py::enum_(m, "ResizeAlgorithm") - .value("NO_RESIZE", InferenceEngine::ResizeAlgorithm::NO_RESIZE) - .value("RESIZE_BILINEAR", InferenceEngine::ResizeAlgorithm::RESIZE_BILINEAR) - .value("RESIZE_AREA", InferenceEngine::ResizeAlgorithm::RESIZE_AREA) - .export_values(); - - py::enum_(m, "ColorFormat") - .value("RAW", InferenceEngine::ColorFormat::RAW) - .value("RGB", InferenceEngine::ColorFormat::RGB) - .value("BGR", InferenceEngine::ColorFormat::BGR) - .value("RGBX", InferenceEngine::ColorFormat::RGBX) - .value("BGRX", InferenceEngine::ColorFormat::BGRX) - .value("NV12", InferenceEngine::ColorFormat::NV12) - .value("I420", InferenceEngine::ColorFormat::I420) - .export_values(); -} diff --git a/runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.hpp b/runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.hpp deleted file mode 100644 index cc762ada0cb..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/ie_preprocess_info.hpp +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace py = pybind11; - -void regclass_PreProcessInfo(py::module m); \ No newline at end of file diff --git a/runtime/bindings/python/src/pyopenvino/core/infer_request.cpp b/runtime/bindings/python/src/pyopenvino/core/infer_request.cpp index b57d4f7569e..65ed43e8499 100644 --- a/runtime/bindings/python/src/pyopenvino/core/infer_request.cpp +++ b/runtime/bindings/python/src/pyopenvino/core/infer_request.cpp @@ -20,6 +20,12 @@ namespace py = pybind11; void regclass_InferRequest(py::module m) { py::class_> cls(m, "InferRequest"); + + cls.def(py::init([](InferRequestWrapper& other) { + return other; + }), + py::arg("other")); + cls.def( "set_tensors", [](InferRequestWrapper& self, const py::dict& inputs) { @@ -51,22 +57,10 @@ void regclass_InferRequest(py::module m) { py::arg("inputs")); cls.def( - "_infer", + "infer", [](InferRequestWrapper& self, const py::dict& inputs) { // Update inputs if there are any - if (!inputs.empty()) { - if (py::isinstance(inputs.begin()->first)) { - auto inputs_map = Common::cast_to_tensor_name_map(inputs); - for (auto&& input : inputs_map) { - self._request.set_tensor(input.first, input.second); - } - } else if (py::isinstance(inputs.begin()->first)) { - auto inputs_map = Common::cast_to_tensor_index_map(inputs); - for (auto&& input : inputs_map) { - self._request.set_input_tensor(input.first, input.second); - } - } - } + Common::set_request_tensors(self._request, inputs); // Call Infer function self._start_time = Time::now(); self._request.infer(); @@ -80,23 +74,11 @@ void regclass_InferRequest(py::module m) { py::arg("inputs")); cls.def( - "_start_async", + "start_async", [](InferRequestWrapper& self, const py::dict& inputs, py::object& userdata) { // Update inputs if there are any - if (!inputs.empty()) { - if (py::isinstance(inputs.begin()->first)) { - auto inputs_map = Common::cast_to_tensor_name_map(inputs); - for (auto&& input : inputs_map) 
{ - self._request.set_tensor(input.first, input.second); - } - } else if (py::isinstance(inputs.begin()->first)) { - auto inputs_map = Common::cast_to_tensor_index_map(inputs); - for (auto&& input : inputs_map) { - self._request.set_input_tensor(input.first, input.second); - } - } - } - if (userdata != py::none()) { + Common::set_request_tensors(self._request, inputs); + if (!userdata.is(py::none())) { if (self.user_callback_defined) { self.userdata = userdata; } else { diff --git a/runtime/bindings/python/src/pyopenvino/core/tensor.cpp b/runtime/bindings/python/src/pyopenvino/core/tensor.cpp index a0d6952f04e..1f1146c045c 100644 --- a/runtime/bindings/python/src/pyopenvino/core/tensor.cpp +++ b/runtime/bindings/python/src/pyopenvino/core/tensor.cpp @@ -10,28 +10,13 @@ #include "openvino/runtime/tensor.hpp" #include "pyopenvino/core/common.hpp" -#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_ - namespace py = pybind11; void regclass_Tensor(py::module m) { py::class_> cls(m, "Tensor"); cls.def(py::init([](py::array& array, bool shared_memory) { - auto type = Common::dtype_to_ov_type().at(py::str(array.dtype())); - std::vector shape(array.shape(), array.shape() + array.ndim()); - if (shared_memory) { - if (C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS)) { - std::vector strides(array.strides(), array.strides() + array.ndim()); - return ov::runtime::Tensor(type, shape, const_cast(array.data(0)), strides); - } else { - IE_THROW() << "Tensor with shared memory must be C contiguous!"; - } - } - array = py::module::import("numpy").attr("ascontiguousarray")(array).cast(); - auto tensor = ov::runtime::Tensor(type, shape); - std::memcpy(tensor.data(), array.data(0), array.nbytes()); - return tensor; + return Common::tensor_from_numpy(array, shared_memory); }), py::arg("array"), py::arg("shared_memory") = false); diff --git a/runtime/bindings/python/src/pyopenvino/core/tensor_description.cpp b/runtime/bindings/python/src/pyopenvino/core/tensor_description.cpp deleted file mode 100644 index b9382f34e5c..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/tensor_description.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "pyopenvino/core/tensor_description.hpp" - -#include -#include -#include - -#include - -#include "common.hpp" - -namespace py = pybind11; -using namespace InferenceEngine; - -void regclass_TensorDecription(py::module m) { - py::class_> cls(m, "TensorDesc"); - cls.def(py::init()); - cls.def(py::init([](const std::string& precision, const SizeVector& dims, const std::string& layout) { - return TensorDesc(Precision::FromStr(precision), dims, Common::get_layout_from_string(layout)); - })); - - cls.def_property( - "layout", - [](TensorDesc& self) { - return Common::get_layout_from_enum(self.getLayout()); - }, - [](TensorDesc& self, const std::string& layout) { - self.setLayout(Common::get_layout_from_string(layout)); - }); - - cls.def_property( - "precision", - [](TensorDesc& self) { - return self.getPrecision().name(); - }, - [](TensorDesc& self, const std::string& precision) { - self.setPrecision(InferenceEngine::Precision::FromStr(precision)); - }); - - cls.def_property( - "dims", - [](TensorDesc& self) { - return self.getDims(); - }, - [](TensorDesc& self, const SizeVector& dims) { - self.setDims(dims); - }); - - cls.def( - "__eq__", - [](const TensorDesc& a, const TensorDesc b) { - return a == b; - }, - py::is_operator()); -} diff --git 
a/runtime/bindings/python/src/pyopenvino/core/tensor_description.hpp b/runtime/bindings/python/src/pyopenvino/core/tensor_description.hpp deleted file mode 100644 index 806c7b9d3b8..00000000000 --- a/runtime/bindings/python/src/pyopenvino/core/tensor_description.hpp +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) 2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace py = pybind11; - -void regclass_TensorDecription(py::module m); diff --git a/runtime/bindings/python/src/pyopenvino/pyopenvino.cpp b/runtime/bindings/python/src/pyopenvino/pyopenvino.cpp index c1529428789..3e3a3b2aaed 100644 --- a/runtime/bindings/python/src/pyopenvino/pyopenvino.cpp +++ b/runtime/bindings/python/src/pyopenvino/pyopenvino.cpp @@ -23,17 +23,11 @@ #include "pyopenvino/core/containers.hpp" #include "pyopenvino/core/core.hpp" #include "pyopenvino/core/executable_network.hpp" -#include "pyopenvino/core/ie_blob.hpp" -#include "pyopenvino/core/ie_data.hpp" -#include "pyopenvino/core/ie_input_info.hpp" -#include "pyopenvino/core/ie_network.hpp" #include "pyopenvino/core/ie_parameter.hpp" -#include "pyopenvino/core/ie_preprocess_info.hpp" #include "pyopenvino/core/infer_request.hpp" #include "pyopenvino/core/offline_transformations.hpp" #include "pyopenvino/core/profiling_info.hpp" #include "pyopenvino/core/tensor.hpp" -#include "pyopenvino/core/tensor_description.hpp" #include "pyopenvino/core/version.hpp" #include "pyopenvino/graph/dimension.hpp" #include "pyopenvino/graph/layout.hpp" @@ -96,28 +90,7 @@ PYBIND11_MODULE(pyopenvino, m) { regclass_graph_Output(m, std::string("Const")); regclass_Core(m); - regclass_IENetwork(m); - - regclass_Data(m); - regclass_TensorDecription(m); - - // Blob will be removed - // Registering template of Blob - regclass_Blob(m); - // Registering specific types of Blobs - regclass_TBlob(m, "Float32"); - regclass_TBlob(m, "Float64"); - regclass_TBlob(m, "Int64"); - regclass_TBlob(m, "Uint64"); - regclass_TBlob(m, "Int32"); - regclass_TBlob(m, "Uint32"); - regclass_TBlob(m, "Int16"); - regclass_TBlob(m, "Uint16"); - regclass_TBlob(m, "Int8"); - regclass_TBlob(m, "Uint8"); - regclass_Tensor(m); - // Registering specific types of containers Containers::regclass_TensorIndexMap(m); Containers::regclass_TensorNameMap(m); @@ -126,10 +99,8 @@ PYBIND11_MODULE(pyopenvino, m) { regclass_InferRequest(m); regclass_Version(m); regclass_Parameter(m); - regclass_InputInfo(m); regclass_AsyncInferQueue(m); regclass_ProfilingInfo(m); - regclass_PreProcessInfo(m); regmodule_offline_transformations(m); } diff --git a/runtime/bindings/python/tests/conftest.py b/runtime/bindings/python/tests/conftest.py index 55fc20acc4f..9efb9a9ecaa 100644 --- a/runtime/bindings/python/tests/conftest.py +++ b/runtime/bindings/python/tests/conftest.py @@ -3,6 +3,7 @@ import os import pytest +import numpy as np import tests @@ -15,6 +16,19 @@ def image_path(): return path_to_img +def read_image(): + import cv2 + n, c, h, w = (1, 3, 32, 32) + image = cv2.imread(image_path()) + if image is None: + raise FileNotFoundError("Input image not found") + + image = cv2.resize(image, (h, w)) / 255 + image = image.transpose((2, 0, 1)).astype(np.float32) + image = image.reshape((n, c, h, w)) + return image + + def model_path(is_myriad=False): path_to_repo = os.environ["MODELS_PATH"] if not is_myriad: diff --git a/runtime/bindings/python/tests/runtime.py b/runtime/bindings/python/tests/runtime.py index 61712cfd314..a1c0bae384b 100644 --- a/runtime/bindings/python/tests/runtime.py +++ 
b/runtime/bindings/python/tests/runtime.py @@ -99,14 +99,7 @@ class Computation(object): "Expected %s params, received not enough %s values.", len(self.parameters), len(input_values) ) - param_types = [param.get_element_type() for param in self.parameters] param_names = [param.friendly_name for param in self.parameters] - - # ignore not needed input values - input_values = [ - np.array(input_value[0], dtype=get_dtype(input_value[1])) - for input_value in zip(input_values[: len(self.parameters)], param_types) - ] input_shapes = [get_shape(input_value) for input_value in input_values] if self.network_cache.get(str(input_shapes)) is None: @@ -121,8 +114,8 @@ class Computation(object): for parameter, input in zip(self.parameters, input_values): parameter_shape = parameter.get_output_partial_shape(0) - input_shape = PartialShape(input.shape) - if len(input.shape) > 0 and not parameter_shape.compatible(input_shape): + input_shape = PartialShape([]) if isinstance(input, (int, float)) else PartialShape(input.shape) + if not parameter_shape.compatible(input_shape): raise UserInputError( "Provided tensor's shape: %s does not match the expected: %s.", input_shape, diff --git a/runtime/bindings/python/tests/test_inference_engine/test_core.py b/runtime/bindings/python/tests/test_inference_engine/test_core.py index d5ee36a2b2a..400a03b7f5a 100644 --- a/runtime/bindings/python/tests/test_inference_engine/test_core.py +++ b/runtime/bindings/python/tests/test_inference_engine/test_core.py @@ -8,37 +8,34 @@ from sys import platform from pathlib import Path import openvino.opset8 as ov -from openvino import Core, IENetwork, ExecutableNetwork, tensor_from_file -from openvino.impl import Function -from openvino import TensorDesc, Blob +from openvino import Function, Core, ExecutableNetwork, Tensor, tensor_from_file, compile_model + +from ..conftest import model_path, model_onnx_path, plugins_path, read_image -from ..conftest import model_path, model_onnx_path, plugins_path test_net_xml, test_net_bin = model_path() test_net_onnx = model_onnx_path() plugins_xml, plugins_win_xml, plugins_osx_xml = plugins_path() -def test_blobs(): - input_shape = [1, 3, 4, 4] - input_data_float32 = (np.random.rand(*input_shape) - 0.5).astype(np.float32) +def test_compact_api_xml(): + img = read_image() - td = TensorDesc("FP32", input_shape, "NCHW") - - input_blob_float32 = Blob(td, input_data_float32) - - assert np.all(np.equal(input_blob_float32.buffer, input_data_float32)) - - input_data_int16 = (np.random.rand(*input_shape) + 0.5).astype(np.int16) - - td = TensorDesc("I16", input_shape, "NCHW") - - input_blob_i16 = Blob(td, input_data_int16) - - assert np.all(np.equal(input_blob_i16.buffer, input_data_int16)) + model = compile_model(test_net_xml) + assert(isinstance(model, ExecutableNetwork)) + results = model.infer_new_request({"data": img}) + assert np.argmax(results) == 2 + + +def test_compact_api_onnx(): + img = read_image() + + model = compile_model(test_net_onnx) + assert(isinstance(model, ExecutableNetwork)) + results = model.infer_new_request({"data": img}) + assert np.argmax(results) == 2 -@pytest.mark.skip(reason="Fix") def test_core_class(): input_shape = [1, 3, 4, 4] param = ov.parameter(input_shape, np.float32, name="parameter") @@ -46,29 +43,18 @@ def test_core_class(): func = Function([relu], [param], "test") func.get_ordered_ops()[2].friendly_name = "friendly" - cnn_network = IENetwork(func) - core = Core() - core.set_config({}, device_name="CPU") - executable_network = core.compile_model(cnn_network, "CPU", 
diff --git a/runtime/bindings/python/tests/test_inference_engine/test_core.py b/runtime/bindings/python/tests/test_inference_engine/test_core.py
index d5ee36a2b2a..400a03b7f5a 100644
--- a/runtime/bindings/python/tests/test_inference_engine/test_core.py
+++ b/runtime/bindings/python/tests/test_inference_engine/test_core.py
@@ -8,37 +8,34 @@ from sys import platform
 from pathlib import Path
 
 import openvino.opset8 as ov
-from openvino import Core, IENetwork, ExecutableNetwork, tensor_from_file
-from openvino.impl import Function
-from openvino import TensorDesc, Blob
+from openvino import Function, Core, ExecutableNetwork, Tensor, tensor_from_file, compile_model
+
+from ..conftest import model_path, model_onnx_path, plugins_path, read_image
 
-from ..conftest import model_path, model_onnx_path, plugins_path
 
 test_net_xml, test_net_bin = model_path()
 test_net_onnx = model_onnx_path()
 plugins_xml, plugins_win_xml, plugins_osx_xml = plugins_path()
 
 
-def test_blobs():
-    input_shape = [1, 3, 4, 4]
-    input_data_float32 = (np.random.rand(*input_shape) - 0.5).astype(np.float32)
+def test_compact_api_xml():
+    img = read_image()
 
-    td = TensorDesc("FP32", input_shape, "NCHW")
-
-    input_blob_float32 = Blob(td, input_data_float32)
-
-    assert np.all(np.equal(input_blob_float32.buffer, input_data_float32))
-
-    input_data_int16 = (np.random.rand(*input_shape) + 0.5).astype(np.int16)
-
-    td = TensorDesc("I16", input_shape, "NCHW")
-
-    input_blob_i16 = Blob(td, input_data_int16)
-
-    assert np.all(np.equal(input_blob_i16.buffer, input_data_int16))
+    model = compile_model(test_net_xml)
+    assert isinstance(model, ExecutableNetwork)
+    results = model.infer_new_request({"data": img})
+    assert np.argmax(results) == 2
+
+
+def test_compact_api_onnx():
+    img = read_image()
+
+    model = compile_model(test_net_onnx)
+    assert isinstance(model, ExecutableNetwork)
+    results = model.infer_new_request({"data": img})
+    assert np.argmax(results) == 2
 
 
-@pytest.mark.skip(reason="Fix")
 def test_core_class():
     input_shape = [1, 3, 4, 4]
     param = ov.parameter(input_shape, np.float32, name="parameter")
@@ -46,29 +43,18 @@ def test_core_class():
     func = Function([relu], [param], "test")
     func.get_ordered_ops()[2].friendly_name = "friendly"
 
-    cnn_network = IENetwork(func)
-
     core = Core()
-    core.set_config({}, device_name="CPU")
-    executable_network = core.compile_model(cnn_network, "CPU", {})
+    model = core.compile_model(func, "CPU", {})
 
-    td = TensorDesc("FP32", input_shape, "NCHW")
-
-    # from IPython import embed; embed()
-
-    request = executable_network.create_infer_request()
-    input_data = np.random.rand(*input_shape) - 0.5
+    request = model.create_infer_request()
+    input_data = np.random.rand(*input_shape).astype(np.float32) - 0.5
     expected_output = np.maximum(0.0, input_data)
 
-    input_blob = Blob(td, input_data)
+    input_tensor = Tensor(input_data)
+    results = request.infer({"parameter": input_tensor})
 
-    request.set_input({"parameter": input_blob})
-    request.infer()
-
-    result = request.get_blob("relu").buffer
-
-    assert np.allclose(result, expected_output)
+    assert np.allclose(results, expected_output)
 
 
 def test_compile_model(device):
@@ -119,15 +105,15 @@ def test_read_model_from_onnx_as_path():
     assert isinstance(func, Function)
 
 
-@pytest.mark.xfail("68212")
-def test_read_net_from_buffer():
-    core = Core()
-    with open(test_net_bin, "rb") as f:
-        bin = f.read()
-    with open(model_path()[0], "rb") as f:
-        xml = f.read()
-    func = core.read_model(model=xml, weights=bin)
-    assert isinstance(func, IENetwork)
+# @pytest.mark.xfail("68212")
+# def test_read_net_from_buffer():
+#     core = Core()
+#     with open(test_net_bin, "rb") as f:
+#         bin = f.read()
+#     with open(model_path()[0], "rb") as f:
+#         xml = f.read()
+#     func = core.read_model(model=xml, weights=bin)
+#     assert isinstance(func, IENetwork)
 
 
 @pytest.mark.xfail("68212")
diff --git a/runtime/bindings/python/tests/test_inference_engine/test_executable_network.py b/runtime/bindings/python/tests/test_inference_engine/test_executable_network.py
index 2a2e80b6c8c..76d6775fcb4 100644
--- a/runtime/bindings/python/tests/test_inference_engine/test_executable_network.py
+++ b/runtime/bindings/python/tests/test_inference_engine/test_executable_network.py
@@ -5,7 +5,7 @@ import os
 import pytest
 import numpy as np
 
-from ..conftest import model_path, image_path
+from ..conftest import model_path, read_image
 from openvino.impl import Function, ConstOutput, Shape
 from openvino import Core, Tensor
 
@@ -14,19 +14,6 @@ is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
 test_net_xml, test_net_bin = model_path(is_myriad)
 
 
-def read_image():
-    import cv2
-    n, c, h, w = (1, 3, 32, 32)
-    image = cv2.imread(image_path())
-    if image is None:
-        raise FileNotFoundError("Input image not found")
-
-    image = cv2.resize(image, (h, w)) / 255
-    image = image.transpose((2, 0, 1)).astype(np.float32)
-    image = image.reshape((n, c, h, w))
-    return image
-
-
 def test_get_metric(device):
     core = Core()
     func = core.read_model(model=test_net_xml, weights=test_net_bin)
@@ -278,9 +265,9 @@ def test_infer_new_request_wrong_port_name(device):
     img = read_image()
     tensor = Tensor(img)
     exec_net = ie.compile_model(func, device)
-    with pytest.raises(RuntimeError) as e:
+    with pytest.raises(KeyError) as e:
         exec_net.infer_new_request({"_data_": tensor})
-    assert "Port for tensor name _data_ was not found." in str(e.value)
+    assert "Port for tensor named _data_ was not found!" in str(e.value)
 
 
 def test_infer_tensor_wrong_input_data(device):
@@ -291,5 +278,5 @@ def test_infer_tensor_wrong_input_data(device):
     tensor = Tensor(img, shared_memory=True)
     exec_net = ie.compile_model(func, device)
     with pytest.raises(TypeError) as e:
-        exec_net.infer_new_request({4.5: tensor})
-    assert "Incompatible key type!" in str(e.value)
+        exec_net.infer_new_request({0.: tensor})
+    assert "Incompatible key type for tensor named: 0." in str(e.value)
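Editor's note: the two error-path tests above, together with the mixed-keys test below, pin down the new dictionary dispatch of infer_new_request: keys may be tensor names (str) or input indices (int), an unknown name raises KeyError, and any other key type raises TypeError. A usage sketch under those rules, assuming a CPU device and a model whose input tensor is named "data"; the model paths here are illustrative placeholders, not fixtures from the patch:

    import numpy as np
    from openvino import Core, Tensor

    test_net_xml, test_net_bin = "test_model.xml", "test_model.bin"  # placeholder paths

    core = Core()
    func = core.read_model(model=test_net_xml, weights=test_net_bin)
    exec_net = core.compile_model(func, "CPU")

    img = np.random.rand(1, 3, 32, 32).astype(np.float32)

    # Both forms address the same input port:
    results = exec_net.infer_new_request({"data": Tensor(img)})  # by tensor name
    results = exec_net.infer_new_request({0: Tensor(img)})       # by input index
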
diff --git a/runtime/bindings/python/tests/test_inference_engine/test_infer_request.py b/runtime/bindings/python/tests/test_inference_engine/test_infer_request.py
index 98075c099f2..cf9a17f459f 100644
--- a/runtime/bindings/python/tests/test_inference_engine/test_infer_request.py
+++ b/runtime/bindings/python/tests/test_inference_engine/test_infer_request.py
@@ -7,26 +7,13 @@ import pytest
 import datetime
 import time
 
-from ..conftest import image_path, model_path
+from ..conftest import model_path, read_image
 from openvino import Core, AsyncInferQueue, Tensor, ProfilingInfo
 
 is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
 test_net_xml, test_net_bin = model_path(is_myriad)
 
 
-def read_image():
-    import cv2
-    n, c, h, w = (1, 3, 32, 32)
-    image = cv2.imread(image_path())
-    if image is None:
-        raise FileNotFoundError("Input image not found")
-
-    image = cv2.resize(image, (h, w)) / 255
-    image = image.transpose((2, 0, 1)).astype(np.float32)
-    image = image.reshape((n, c, h, w))
-    return image
-
-
 def test_get_profiling_info(device):
     core = Core()
     func = core.read_model(test_net_xml, test_net_bin)
@@ -48,8 +35,8 @@ def test_get_profiling_info(device):
 def test_tensor_setter(device):
     core = Core()
     func = core.read_model(test_net_xml, test_net_bin)
-    exec_net_1 = core.compile_model(network=func, device_name=device)
-    exec_net_2 = core.compile_model(network=func, device_name=device)
+    exec_net_1 = core.compile_model(model=func, device_name=device)
+    exec_net_2 = core.compile_model(model=func, device_name=device)
 
     img = read_image()
     tensor = Tensor(img)
@@ -177,18 +164,17 @@ def test_infer_mixed_keys(device):
     core = Core()
     func = core.read_model(test_net_xml, test_net_bin)
     core.set_config({"PERF_COUNT": "YES"}, device)
-    exec_net = core.compile_model(func, device)
+    model = core.compile_model(func, device)
 
     img = read_image()
     tensor = Tensor(img)
 
-    data2 = np.ones(shape=(1, 10), dtype=np.float32)
+    data2 = np.ones(shape=img.shape, dtype=np.float32)
     tensor2 = Tensor(data2)
 
-    request = exec_net.create_infer_request()
-
-    with pytest.raises(TypeError) as e:
-        request.infer({0: tensor, "fc_out": tensor2})
-    assert "incompatible function arguments!" in str(e.value)
+    request = model.create_infer_request()
+    res = request.infer({0: tensor2, "data": tensor})
+    assert np.argmax(res) == 2
 
 
 def test_infer_queue(device):
diff --git a/runtime/bindings/python/tests/test_inference_engine/test_tensor.py b/runtime/bindings/python/tests/test_inference_engine/test_tensor.py
index 081334013a5..607364d9bc6 100644
--- a/runtime/bindings/python/tests/test_inference_engine/test_tensor.py
+++ b/runtime/bindings/python/tests/test_inference_engine/test_tensor.py
@@ -4,25 +4,11 @@
 import numpy as np
 import pytest
 
-from ..conftest import image_path
+from ..conftest import read_image
 from openvino import Tensor
 import openvino as ov
 
 
-def read_image():
-    import cv2
-
-    n, c, h, w = (1, 3, 32, 32)
-    image = cv2.imread(image_path())
-    if image is None:
-        raise FileNotFoundError("Input image not found")
-
-    image = cv2.resize(image, (h, w)) / 255
-    image = image.transpose((2, 0, 1)).astype(np.float32)
-    image = image.reshape((n, c, h, w))
-    return image
-
-
 @pytest.mark.parametrize("ov_type, numpy_dtype", [
     (ov.impl.Type.f32, np.float32),
     (ov.impl.Type.f64, np.float64),