[PYTHON] Improve API with compact functions and Python wrappers (#8603)
* Added new Tensor dispatch and improvements; replaced injections with inheritance of pybind objects; removed Blob from the public Python API.
* Cleaned up tests and the API, removing unused classes.
* Removed an unused import.
* Updated AsyncInferQueue with a Python wrapper.
* C++ code-style fixes.
* Applied review comments.
* Common tensor setting for requests.
Commit 4247a9d6b0 (parent c171be238f)
src/bindings/python runtime __init__.py
@@ -6,47 +6,32 @@
 from pkg_resources import get_distribution, DistributionNotFound

-__path__ = __import__('pkgutil').extend_path(__path__, __name__)  # type: ignore  # mypy issue #1422
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore  # mypy issue #1422

 try:
     __version__ = get_distribution("openvino-core").version
 except DistributionNotFound:
     __version__ = "0.0.0.dev0"

-from openvino.ie_api import BlobWrapper
-from openvino.ie_api import infer
-from openvino.ie_api import start_async
-from openvino.ie_api import blob_from_file
-from openvino.ie_api import tensor_from_file
-from openvino.ie_api import infer_new_request

+# Openvino pybind bindings and python extended classes
 from openvino.impl import Dimension
 from openvino.impl import Function
 from openvino.impl import Node
 from openvino.impl import PartialShape
 from openvino.impl import Layout

-from openvino.pyopenvino import Core
-from openvino.pyopenvino import IENetwork
-from openvino.pyopenvino import ExecutableNetwork
+from openvino.ie_api import Core
+from openvino.ie_api import ExecutableNetwork
+from openvino.ie_api import InferRequest
+from openvino.ie_api import AsyncInferQueue
 from openvino.pyopenvino import Version
 from openvino.pyopenvino import Parameter
-from openvino.pyopenvino import InputInfoPtr
-from openvino.pyopenvino import InputInfoCPtr
-from openvino.pyopenvino import DataPtr
-from openvino.pyopenvino import TensorDesc
-from openvino.pyopenvino import get_version
-from openvino.pyopenvino import AsyncInferQueue
-from openvino.pyopenvino import InferRequest  # TODO: move to ie_api?
-from openvino.pyopenvino import Blob
-from openvino.pyopenvino import PreProcessInfo
-from openvino.pyopenvino import MeanVariant
-from openvino.pyopenvino import ResizeAlgorithm
-from openvino.pyopenvino import ColorFormat
-from openvino.pyopenvino import PreProcessChannel
 from openvino.pyopenvino import Tensor
 from openvino.pyopenvino import ProfilingInfo
+from openvino.pyopenvino import get_version

+# Import opsets
 from openvino import opset1
 from openvino import opset2
 from openvino import opset3
@@ -56,6 +41,10 @@ from openvino import opset6
 from openvino import opset7
 from openvino import opset8

+# Helper functions for openvino module
+from openvino.ie_api import tensor_from_file
+from openvino.ie_api import compile_model
+
 # Extend Node class to support binary operators
 Node.__add__ = opset8.add
 Node.__sub__ = opset8.subtract
@@ -73,15 +62,3 @@ Node.__lt__ = opset8.less
 Node.__le__ = opset8.less_equal
 Node.__gt__ = opset8.greater
 Node.__ge__ = opset8.greater_equal
-
-# Patching for Blob class
-# flake8: noqa: F811
-# this class will be removed
-Blob = BlobWrapper
-# Patching ExecutableNetwork
-ExecutableNetwork.infer_new_request = infer_new_request
-# Patching InferRequest
-InferRequest.infer = infer
-InferRequest.start_async = start_async
-# Patching AsyncInferQueue
-AsyncInferQueue.start_async = start_async
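With the monkey-patching removed, the package-level helpers exported above are the intended entry points. A minimal usage sketch (the model and input file paths are placeholders, not part of this change):

import openvino

# Compact helper: creates a Core internally and compiles with the AUTO plugin.
exec_net = openvino.compile_model("model.xml")

# Read a raw binary file into a uint8 Tensor.
input_tensor = openvino.tensor_from_file("input.bin")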
ie_api.py
@@ -3,126 +3,109 @@

 import numpy as np
 import copy
-from typing import List, Union
+from typing import Any, List, Union

-from openvino.pyopenvino import TBlobFloat32
-from openvino.pyopenvino import TBlobFloat64
-from openvino.pyopenvino import TBlobInt64
-from openvino.pyopenvino import TBlobUint64
-from openvino.pyopenvino import TBlobInt32
-from openvino.pyopenvino import TBlobUint32
-from openvino.pyopenvino import TBlobInt16
-from openvino.pyopenvino import TBlobUint16
-from openvino.pyopenvino import TBlobInt8
-from openvino.pyopenvino import TBlobUint8
-from openvino.pyopenvino import TensorDesc
-from openvino.pyopenvino import InferRequest
-from openvino.pyopenvino import AsyncInferQueue
-from openvino.pyopenvino import ExecutableNetwork
+from openvino.pyopenvino import Function
+from openvino.pyopenvino import Core as CoreBase
+from openvino.pyopenvino import ExecutableNetwork as ExecutableNetworkBase
+from openvino.pyopenvino import InferRequest as InferRequestBase
+from openvino.pyopenvino import AsyncInferQueue as AsyncInferQueueBase
 from openvino.pyopenvino import Tensor

+from openvino.utils.types import get_dtype

-precision_map = {"FP32": np.float32,
-                 "FP64": np.float64,
-                 "FP16": np.int16,
-                 "BF16": np.int16,
-                 "I16": np.int16,
-                 "I8": np.int8,
-                 "BIN": np.int8,
-                 "I32": np.int32,
-                 "I64": np.int64,
-                 "U8": np.uint8,
-                 "BOOL": np.uint8,
-                 "U16": np.uint16,
-                 "U32": np.uint32,
-                 "U64": np.uint64}
-
-
-def normalize_inputs(py_dict: dict) -> dict:
-    """Normalize a dictionary of inputs to contiguous numpy arrays."""
-    return {k: (Tensor(v) if isinstance(v, np.ndarray) else v)
-            for k, v in py_dict.items()}
-
-
-# flake8: noqa: D102
-def infer(request: InferRequest, inputs: dict = {}) -> np.ndarray:
-    res = request._infer(inputs=normalize_inputs(inputs))
-    # Required to return list since np.ndarray forces all of tensors data to match in
-    # dimensions. This results in errors when running ops like variadic split.
-    return [copy.deepcopy(tensor.data) for tensor in res]
-
-
-def infer_new_request(exec_net: ExecutableNetwork, inputs: dict = None) -> List[np.ndarray]:
-    res = exec_net._infer_new_request(inputs=normalize_inputs(inputs if inputs is not None else {}))
-    # Required to return list since np.ndarray forces all of tensors data to match in
-    # dimensions. This results in errors when running ops like variadic split.
-    return [copy.deepcopy(tensor.data) for tensor in res]
-
-
-# flake8: noqa: D102
-def start_async(request: Union[InferRequest, AsyncInferQueue], inputs: dict = {}, userdata: dict = None) -> None:  # type: ignore
-    request._start_async(inputs=normalize_inputs(inputs), userdata=userdata)
-
-
-# flake8: noqa: C901
-# Dispatch Blob types on Python side.
-class BlobWrapper:
-    def __new__(cls, tensor_desc: TensorDesc, arr: np.ndarray = None):  # type: ignore
-        arr_size = 0
-        precision = ""
-        if tensor_desc is not None:
-            tensor_desc_size = int(np.prod(tensor_desc.dims))
-            precision = tensor_desc.precision
-            if arr is not None:
-                arr = np.array(arr)  # Keeping array as numpy array
-                arr_size = int(np.prod(arr.shape))
-                if np.isfortran(arr):
-                    arr = arr.ravel(order="F")
-                else:
-                    arr = arr.ravel(order="C")
-                if arr_size != tensor_desc_size:
-                    raise AttributeError(f"Number of elements in provided numpy array "
-                                         f"{arr_size} and required by TensorDesc "
-                                         f"{tensor_desc_size} are not equal")
-                if arr.dtype != precision_map[precision]:
-                    raise ValueError(f"Data type {arr.dtype} of provided numpy array "
-                                     f"doesn't match to TensorDesc precision {precision}")
-                if not arr.flags["C_CONTIGUOUS"]:
-                    arr = np.ascontiguousarray(arr)
-            elif arr is None:
-                arr = np.empty(0, dtype=precision_map[precision])
-        else:
-            raise AttributeError("TensorDesc can't be None")
-
-        if precision in ["FP32"]:
-            return TBlobFloat32(tensor_desc, arr, arr_size)
-        elif precision in ["FP64"]:
-            return TBlobFloat64(tensor_desc, arr, arr_size)
-        elif precision in ["FP16", "BF16"]:
-            return TBlobInt16(tensor_desc, arr.view(dtype=np.int16), arr_size)
-        elif precision in ["I64"]:
-            return TBlobInt64(tensor_desc, arr, arr_size)
-        elif precision in ["U64"]:
-            return TBlobUint64(tensor_desc, arr, arr_size)
-        elif precision in ["I32"]:
-            return TBlobInt32(tensor_desc, arr, arr_size)
-        elif precision in ["U32"]:
-            return TBlobUint32(tensor_desc, arr, arr_size)
-        elif precision in ["I16"]:
-            return TBlobInt16(tensor_desc, arr, arr_size)
-        elif precision in ["U16"]:
-            return TBlobUint16(tensor_desc, arr, arr_size)
-        elif precision in ["I8", "BIN"]:
-            return TBlobInt8(tensor_desc, arr, arr_size)
-        elif precision in ["U8", "BOOL"]:
-            return TBlobUint8(tensor_desc, arr, arr_size)
-        else:
-            raise AttributeError(f"Unsupported precision {precision} for Blob")
-
-
-# flake8: noqa: D102
-def blob_from_file(path_to_bin_file: str) -> BlobWrapper:
-    array = np.fromfile(path_to_bin_file, dtype=np.uint8)
-    tensor_desc = TensorDesc("U8", array.shape, "C")
-    return BlobWrapper(tensor_desc, array)
-
-
-# flake8: noqa: D102
 def tensor_from_file(path: str) -> Tensor:
-    """The data will be read with dtype of unit8"""
+    """Create Tensor from file. Data will be read with dtype of unit8."""
     return Tensor(np.fromfile(path, dtype=np.uint8))


+def normalize_inputs(py_dict: dict, py_types: dict) -> dict:
+    """Normalize a dictionary of inputs to Tensors."""
+    for k, val in py_dict.items():
+        try:
+            if isinstance(k, int):
+                ov_type = list(py_types.values())[k]
+            elif isinstance(k, str):
+                ov_type = py_types[k]
+            else:
+                raise TypeError("Incompatible key type for tensor named: {}".format(k))
+        except KeyError:
+            raise KeyError("Port for tensor named {} was not found!".format(k))
+        py_dict[k] = val if isinstance(val, Tensor) else Tensor(np.array(val, get_dtype(ov_type)))
+    return py_dict
+
+
+def get_input_types(obj: Union[InferRequestBase, ExecutableNetworkBase]) -> dict:
+    """Get all precisions from object inputs."""
+    return {i.get_node().get_friendly_name(): i.get_node().get_element_type() for i in obj.inputs}
+
+
+class InferRequest(InferRequestBase):
+    """InferRequest wrapper."""
+
+    def infer(self, inputs: dict = {}) -> List[np.ndarray]:  # noqa: B006
+        """Infer wrapper for InferRequest."""
+        res = super().infer(inputs=normalize_inputs(inputs, get_input_types(self)))
+        # Required to return list since np.ndarray forces all of tensors data to match in
+        # dimensions. This results in errors when running ops like variadic split.
+        return [copy.deepcopy(tensor.data) for tensor in res]
+
+    def start_async(self, inputs: dict = {}, userdata: Any = None) -> None:  # noqa: B006, type: ignore
+        """Asynchronous infer wrapper for InferRequest."""
+        super().start_async(inputs=normalize_inputs(inputs, get_input_types(self)), userdata=userdata)
+
+
+class ExecutableNetwork(ExecutableNetworkBase):
+    """ExecutableNetwork wrapper."""
+
+    def create_infer_request(self) -> InferRequest:
+        """Create new InferRequest object."""
+        return InferRequest(super().create_infer_request())
+
+    def infer_new_request(self, inputs: dict = {}) -> List[np.ndarray]:  # noqa: B006
+        """Infer wrapper for ExecutableNetwork."""
+        res = super().infer_new_request(inputs=normalize_inputs(inputs, get_input_types(self)))
+        # Required to return list since np.ndarray forces all of tensors data to match in
+        # dimensions. This results in errors when running ops like variadic split.
+        return [copy.deepcopy(tensor.data) for tensor in res]
+
+
+class AsyncInferQueue(AsyncInferQueueBase):
+    """AsyncInferQueue wrapper."""
+
+    def __getitem__(self, i: int) -> InferRequest:
+        """Return i-th InferRequest from AsyncInferQueue."""
+        return InferRequest(super().__getitem__(i))
+
+    def start_async(
+        self, inputs: dict = {}, userdata: Any = None  # noqa: B006
+    ) -> None:  # type: ignore
+        """Asynchronous infer wrapper for AsyncInferQueue."""
+        super().start_async(
+            inputs=normalize_inputs(
+                inputs, get_input_types(self[self.get_idle_request_id()])
+            ),
+            userdata=userdata,
+        )
+
+
+class Core(CoreBase):
+    """Core wrapper."""
+
+    def compile_model(
+        self, model: Function, device_name: str, config: dict = {}  # noqa: B006
+    ) -> ExecutableNetwork:
+        """Compile a model from given Function."""
+        return ExecutableNetwork(super().compile_model(model, device_name, config))
+
+    def import_model(
+        self, model_file: str, device_name: str, config: dict = {}  # noqa: B006
+    ) -> ExecutableNetwork:
+        """Compile a model from given model file path."""
+        return ExecutableNetwork(super().import_model(model_file, device_name, config))
+
+
+def compile_model(model_path: str) -> ExecutableNetwork:
+    """Compact method to compile model with AUTO plugin."""
+    return Core().compile_model(model_path, "AUTO")
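The wrappers above replace the old free functions: inputs may be keyed by tensor name or by input index, values are converted to Tensor objects with the port's element type via normalize_inputs(), and results come back as a list of copied numpy arrays. A sketch of the intended call pattern (the model path, input name and shape are assumptions for illustration):

import numpy as np
from openvino import Core

core = Core()
exec_net = core.compile_model("model.xml", "CPU")   # placeholder model path
request = exec_net.create_infer_request()

data = np.ones((1, 3, 224, 224), dtype=np.float32)
outputs = request.infer({"data": data})   # key by tensor name ...
outputs = request.infer({0: data})        # ... or by input index

# Each output is a deep copy of the result tensor's data, so differently
# shaped outputs (e.g. from variadic split) stay independent arrays.
for out in outputs:
    print(out.shape)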
async_infer_queue.cpp
@@ -53,7 +53,6 @@ public:
         });

         return _idle_handles.front();
-        ;
     }

     void wait_all() {
@@ -135,7 +134,7 @@ void regclass_AsyncInferQueue(py::module m) {
             py::arg("jobs") = 0);

    cls.def(
-        "_start_async",
+        "start_async",
        [](AsyncInferQueue& self, const py::dict inputs, py::object userdata) {
            // getIdleRequestId function has an intention to block InferQueue
            // until there is at least one idle (free to use) InferRequest
@@ -144,19 +143,7 @@ void regclass_AsyncInferQueue(py::module m) {
            // Set new inputs label/id from user
            self._user_ids[handle] = userdata;
            // Update inputs if there are any
-            if (!inputs.empty()) {
-                if (py::isinstance<std::string>(inputs.begin()->first)) {
-                    auto inputs_map = Common::cast_to_tensor_name_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        self._requests[handle]._request.set_tensor(input.first, input.second);
-                    }
-                } else if (py::isinstance<int>(inputs.begin()->first)) {
-                    auto inputs_map = Common::cast_to_tensor_index_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        self._requests[handle]._request.set_input_tensor(input.first, input.second);
-                    }
-                }
-            }
+            Common::set_request_tensors(self._requests[handle]._request, inputs);
            // Now GIL can be released - we are NOT working with Python objects in this block
            {
                py::gil_scoped_release release;
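On the Python side the renamed start_async binding is reached through the AsyncInferQueue wrapper shown earlier; set_request_tensors() fills the idle request before the GIL is released. A usage sketch (the model path, job count and input layout are assumptions):

import numpy as np
from openvino import AsyncInferQueue, compile_model

exec_net = compile_model("model.xml")   # placeholder path
queue = AsyncInferQueue(exec_net, 4)    # pool of 4 InferRequests (0 lets the plugin choose)

for frame_id in range(16):
    frame = np.random.rand(1, 3, 224, 224).astype(np.float32)
    # Blocks until an idle request is available, then submits the inputs
    # and remembers frame_id as userdata for that request.
    queue.start_async({0: frame}, userdata=frame_id)

queue.wait_all()   # wait until every submitted request has completed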
@ -6,6 +6,8 @@
|
|||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
|
||||||
|
|
||||||
namespace Common {
|
namespace Common {
|
||||||
const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
|
const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
|
||||||
static const std::map<ov::element::Type, py::dtype> ov_type_to_dtype_mapping = {
|
static const std::map<ov::element::Type, py::dtype> ov_type_to_dtype_mapping = {
|
||||||
@ -45,42 +47,130 @@ const std::map<py::str, ov::element::Type>& dtype_to_ov_type() {
|
|||||||
return dtype_to_ov_type_mapping;
|
return dtype_to_ov_type_mapping;
|
||||||
}
|
}
|
||||||
|
|
||||||
InferenceEngine::Layout get_layout_from_string(const std::string& layout) {
|
ov::runtime::Tensor tensor_from_numpy(py::array& array, bool shared_memory) {
|
||||||
static const std::unordered_map<std::string, InferenceEngine::Layout> layout_str_to_enum = {
|
// Check if passed array has C-style contiguous memory layout.
|
||||||
{"ANY", InferenceEngine::Layout::ANY},
|
bool is_contiguous = C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS);
|
||||||
{"NHWC", InferenceEngine::Layout::NHWC},
|
auto type = Common::dtype_to_ov_type().at(py::str(array.dtype()));
|
||||||
{"NCHW", InferenceEngine::Layout::NCHW},
|
std::vector<size_t> shape(array.shape(), array.shape() + array.ndim());
|
||||||
{"NCDHW", InferenceEngine::Layout::NCDHW},
|
|
||||||
{"NDHWC", InferenceEngine::Layout::NDHWC},
|
// If memory is going to be shared it needs to be contiguous before
|
||||||
{"OIHW", InferenceEngine::Layout::OIHW},
|
// passing to the constructor. This case should be handled by advanced
|
||||||
{"GOIHW", InferenceEngine::Layout::GOIHW},
|
// users on their side of the code.
|
||||||
{"OIDHW", InferenceEngine::Layout::OIDHW},
|
if (shared_memory) {
|
||||||
{"GOIDHW", InferenceEngine::Layout::GOIDHW},
|
if (is_contiguous) {
|
||||||
{"SCALAR", InferenceEngine::Layout::SCALAR},
|
std::vector<size_t> strides(array.strides(), array.strides() + array.ndim());
|
||||||
{"C", InferenceEngine::Layout::C},
|
return ov::runtime::Tensor(type, shape, const_cast<void*>(array.data(0)), strides);
|
||||||
{"CHW", InferenceEngine::Layout::CHW},
|
} else {
|
||||||
{"HW", InferenceEngine::Layout::HW},
|
throw ov::Exception("Tensor with shared memory must be C contiguous!");
|
||||||
{"NC", InferenceEngine::Layout::NC},
|
}
|
||||||
{"CN", InferenceEngine::Layout::CN},
|
}
|
||||||
{"BLOCKED", InferenceEngine::Layout::BLOCKED}};
|
// Convert to contiguous array if not already C-style.
|
||||||
return layout_str_to_enum.at(layout);
|
if (!is_contiguous) {
|
||||||
|
array = Common::as_contiguous(array, type);
|
||||||
|
}
|
||||||
|
// Create actual Tensor and copy data.
|
||||||
|
auto tensor = ov::runtime::Tensor(type, shape);
|
||||||
|
// If ndim of py::array is 0, array is a numpy scalar. That results in size to be equal to 0.
|
||||||
|
// To gain access to actual raw/low-level data, it is needed to use buffer protocol.
|
||||||
|
py::buffer_info buf = array.request();
|
||||||
|
std::memcpy(tensor.data(), buf.ptr, buf.ndim == 0 ? buf.itemsize : buf.itemsize * buf.size);
|
||||||
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string& get_layout_from_enum(const InferenceEngine::Layout& layout) {
|
py::array as_contiguous(py::array& array, ov::element::Type type) {
|
||||||
static const std::unordered_map<int, std::string> layout_int_to_str_map = {{0, "ANY"},
|
switch (type) {
|
||||||
{1, "NCHW"},
|
// floating
|
||||||
{2, "NHWC"},
|
case ov::element::f64:
|
||||||
{3, "NCDHW"},
|
return array.cast<py::array_t<double, py::array::c_style | py::array::forcecast>>();
|
||||||
{4, "NDHWC"},
|
case ov::element::f32:
|
||||||
{64, "OIHW"},
|
return array.cast<py::array_t<float, py::array::c_style | py::array::forcecast>>();
|
||||||
{95, "SCALAR"},
|
// signed
|
||||||
{96, "C"},
|
case ov::element::i64:
|
||||||
{128, "CHW"},
|
return array.cast<py::array_t<int64_t, py::array::c_style | py::array::forcecast>>();
|
||||||
{192, "HW"},
|
case ov::element::i32:
|
||||||
{193, "NC"},
|
return array.cast<py::array_t<int32_t, py::array::c_style | py::array::forcecast>>();
|
||||||
{194, "CN"},
|
case ov::element::i16:
|
||||||
{200, "BLOCKED"}};
|
return array.cast<py::array_t<int16_t, py::array::c_style | py::array::forcecast>>();
|
||||||
return layout_int_to_str_map.at(layout);
|
case ov::element::i8:
|
||||||
|
return array.cast<py::array_t<int8_t, py::array::c_style | py::array::forcecast>>();
|
||||||
|
// unsigned
|
||||||
|
case ov::element::u64:
|
||||||
|
return array.cast<py::array_t<uint64_t, py::array::c_style | py::array::forcecast>>();
|
||||||
|
case ov::element::u32:
|
||||||
|
return array.cast<py::array_t<uint32_t, py::array::c_style | py::array::forcecast>>();
|
||||||
|
case ov::element::u16:
|
||||||
|
return array.cast<py::array_t<uint16_t, py::array::c_style | py::array::forcecast>>();
|
||||||
|
case ov::element::u8:
|
||||||
|
return array.cast<py::array_t<uint8_t, py::array::c_style | py::array::forcecast>>();
|
||||||
|
// other
|
||||||
|
case ov::element::boolean:
|
||||||
|
return array.cast<py::array_t<bool, py::array::c_style | py::array::forcecast>>();
|
||||||
|
case ov::element::u1:
|
||||||
|
return array.cast<py::array_t<uint8_t, py::array::c_style | py::array::forcecast>>();
|
||||||
|
// need to create a view on array to cast it correctly
|
||||||
|
case ov::element::f16:
|
||||||
|
case ov::element::bf16:
|
||||||
|
return array.view("int16").cast<py::array_t<int16_t, py::array::c_style | py::array::forcecast>>();
|
||||||
|
default:
|
||||||
|
throw ov::Exception("Tensor cannot be created as contiguous!");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor) {
|
||||||
|
return tensor.cast<const ov::runtime::Tensor&>();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs) {
|
||||||
|
Containers::TensorNameMap result_map;
|
||||||
|
for (auto&& input : inputs) {
|
||||||
|
std::string name;
|
||||||
|
if (py::isinstance<py::str>(input.first)) {
|
||||||
|
name = input.first.cast<std::string>();
|
||||||
|
} else {
|
||||||
|
throw py::type_error("incompatible function arguments!");
|
||||||
|
}
|
||||||
|
if (py::isinstance<ov::runtime::Tensor>(input.second)) {
|
||||||
|
auto tensor = Common::cast_to_tensor(input.second);
|
||||||
|
result_map[name] = tensor;
|
||||||
|
} else {
|
||||||
|
throw ov::Exception("Unable to cast tensor " + name + "!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) {
|
||||||
|
Containers::TensorIndexMap result_map;
|
||||||
|
for (auto&& input : inputs) {
|
||||||
|
int idx;
|
||||||
|
if (py::isinstance<py::int_>(input.first)) {
|
||||||
|
idx = input.first.cast<int>();
|
||||||
|
} else {
|
||||||
|
throw py::type_error("incompatible function arguments!");
|
||||||
|
}
|
||||||
|
if (py::isinstance<ov::runtime::Tensor>(input.second)) {
|
||||||
|
auto tensor = Common::cast_to_tensor(input.second);
|
||||||
|
result_map[idx] = tensor;
|
||||||
|
} else {
|
||||||
|
throw ov::Exception("Unable to cast tensor " + std::to_string(idx) + "!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inputs) {
|
||||||
|
if (!inputs.empty()) {
|
||||||
|
for (auto&& input : inputs) {
|
||||||
|
if (py::isinstance<py::str>(input.first)) {
|
||||||
|
request.set_tensor(input.first.cast<std::string>(), Common::cast_to_tensor(input.second));
|
||||||
|
} else if (py::isinstance<py::int_>(input.first)) {
|
||||||
|
request.set_input_tensor(input.first.cast<size_t>(), Common::cast_to_tensor(input.second));
|
||||||
|
} else {
|
||||||
|
throw py::type_error("Incompatible key type for tensor named: " + input.first.cast<std::string>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
||||||
@ -185,142 +275,6 @@ PyObject* parse_parameter(const InferenceEngine::Parameter& param) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_TBlob(const py::handle& blob) {
|
|
||||||
if (py::isinstance<InferenceEngine::TBlob<float>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<double>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int8_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int16_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int32_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int64_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint8_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint16_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint32_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint64_t>>(blob)) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor) {
|
|
||||||
return tensor.cast<const ov::runtime::Tensor&>();
|
|
||||||
}
|
|
||||||
|
|
||||||
const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs) {
|
|
||||||
Containers::TensorNameMap result_map;
|
|
||||||
for (auto&& input : inputs) {
|
|
||||||
std::string name;
|
|
||||||
if (py::isinstance<py::str>(input.first)) {
|
|
||||||
name = input.first.cast<std::string>();
|
|
||||||
} else {
|
|
||||||
throw py::type_error("incompatible function arguments!");
|
|
||||||
}
|
|
||||||
if (py::isinstance<ov::runtime::Tensor>(input.second)) {
|
|
||||||
auto tensor = Common::cast_to_tensor(input.second);
|
|
||||||
result_map[name] = tensor;
|
|
||||||
} else {
|
|
||||||
throw ov::Exception("Unable to cast tensor " + name + "!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result_map;
|
|
||||||
}
|
|
||||||
|
|
||||||
const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) {
|
|
||||||
Containers::TensorIndexMap result_map;
|
|
||||||
for (auto&& input : inputs) {
|
|
||||||
int idx;
|
|
||||||
if (py::isinstance<py::int_>(input.first)) {
|
|
||||||
idx = input.first.cast<int>();
|
|
||||||
} else {
|
|
||||||
throw py::type_error("incompatible function arguments!");
|
|
||||||
}
|
|
||||||
if (py::isinstance<ov::runtime::Tensor>(input.second)) {
|
|
||||||
auto tensor = Common::cast_to_tensor(input.second);
|
|
||||||
result_map[idx] = tensor;
|
|
||||||
} else {
|
|
||||||
throw ov::Exception("Unable to cast tensor " + std::to_string(idx) + "!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result_map;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::shared_ptr<InferenceEngine::Blob> cast_to_blob(const py::handle& blob) {
|
|
||||||
if (py::isinstance<InferenceEngine::TBlob<float>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<float>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<double>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<double>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int8_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<int8_t>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int16_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<int16_t>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int32_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<int32_t>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<int64_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<int64_t>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint8_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<uint8_t>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint16_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<uint16_t>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint32_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<uint32_t>>&>();
|
|
||||||
} else if (py::isinstance<InferenceEngine::TBlob<uint64_t>>(blob)) {
|
|
||||||
return blob.cast<const std::shared_ptr<InferenceEngine::TBlob<uint64_t>>&>();
|
|
||||||
} else {
|
|
||||||
IE_THROW() << "Unsupported data type for when casting to blob!";
|
|
||||||
// return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void blob_from_numpy(const py::handle& arr, InferenceEngine::Blob::Ptr blob) {
|
|
||||||
if (py::isinstance<py::array_t<float>>(arr)) {
|
|
||||||
Common::fill_blob<float>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<double>>(arr)) {
|
|
||||||
Common::fill_blob<double>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<bool>>(arr)) {
|
|
||||||
Common::fill_blob<bool>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<int8_t>>(arr)) {
|
|
||||||
Common::fill_blob<int8_t>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<int16_t>>(arr)) {
|
|
||||||
Common::fill_blob<int16_t>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<int32_t>>(arr)) {
|
|
||||||
Common::fill_blob<int32_t>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<int64_t>>(arr)) {
|
|
||||||
Common::fill_blob<int64_t>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<uint8_t>>(arr)) {
|
|
||||||
Common::fill_blob<uint8_t>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<uint16_t>>(arr)) {
|
|
||||||
Common::fill_blob<uint16_t>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<uint32_t>>(arr)) {
|
|
||||||
Common::fill_blob<uint32_t>(arr, blob);
|
|
||||||
} else if (py::isinstance<py::array_t<uint64_t>>(arr)) {
|
|
||||||
Common::fill_blob<uint64_t>(arr, blob);
|
|
||||||
} else {
|
|
||||||
IE_THROW() << "Unsupported data type for when filling blob!";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_request_blobs(InferenceEngine::InferRequest& request, const py::dict& dictonary) {
|
|
||||||
for (auto&& pair : dictonary) {
|
|
||||||
const std::string& name = pair.first.cast<std::string>();
|
|
||||||
if (py::isinstance<py::array>(pair.second)) {
|
|
||||||
Common::blob_from_numpy(pair.second, request.GetBlob(name));
|
|
||||||
} else if (is_TBlob(pair.second)) {
|
|
||||||
request.SetBlob(name, Common::cast_to_blob(pair.second));
|
|
||||||
} else {
|
|
||||||
IE_THROW() << "Unable to set blob " << name << "!";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual) {
|
uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual) {
|
||||||
try {
|
try {
|
||||||
auto parameter_value = actual.get_metric(METRIC_KEY(SUPPORTED_METRICS));
|
auto parameter_value = actual.get_metric(METRIC_KEY(SUPPORTED_METRICS));
|
||||||
|
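The new tensor_from_numpy() helper only shares memory with C-contiguous arrays and otherwise copies the data after an as_contiguous() cast. The same distinction can be checked from numpy before handing data over (pure numpy sketch, independent of the bindings):

import numpy as np

a = np.arange(12, dtype=np.float32).reshape(3, 4)
view = a.T                                  # transposed view: not C-contiguous
print(view.flags["C_CONTIGUOUS"])           # False -> rejected for zero-copy sharing

fixed = np.ascontiguousarray(view)          # what the C++ as_contiguous() path does
print(fixed.flags["C_CONTIGUOUS"])          # True -> eligible for sharing without a copy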
@ -4,62 +4,43 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cpp/ie_executable_network.hpp>
|
#include <string>
|
||||||
#include <cpp/ie_infer_request.hpp>
|
|
||||||
#include <ie_plugin_config.hpp>
|
#include <pybind11/stl.h>
|
||||||
#include <ie_blob.h>
|
|
||||||
#include <ie_parameter.hpp>
|
|
||||||
#include <pybind11/numpy.h>
|
#include <pybind11/numpy.h>
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <string>
|
|
||||||
|
#include <ie_plugin_config.hpp>
|
||||||
|
#include <ie_parameter.hpp>
|
||||||
|
#include <openvino/core/type/element_type.hpp>
|
||||||
#include "Python.h"
|
#include "Python.h"
|
||||||
#include "ie_common.h"
|
#include "ie_common.h"
|
||||||
#include "openvino/runtime/tensor.hpp"
|
#include "openvino/runtime/tensor.hpp"
|
||||||
#include "openvino/runtime/executable_network.hpp"
|
#include "openvino/runtime/executable_network.hpp"
|
||||||
|
#include "openvino/runtime/infer_request.hpp"
|
||||||
#include "pyopenvino/core/containers.hpp"
|
#include "pyopenvino/core/containers.hpp"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace Common
|
namespace Common
|
||||||
{
|
{
|
||||||
template <typename T>
|
|
||||||
void fill_blob(const py::handle& py_array, InferenceEngine::Blob::Ptr blob)
|
|
||||||
{
|
|
||||||
py::array_t<T> arr = py::cast<py::array>(py_array);
|
|
||||||
if (arr.size() != 0) {
|
|
||||||
// blob->allocate();
|
|
||||||
InferenceEngine::MemoryBlob::Ptr mem_blob = InferenceEngine::as<InferenceEngine::MemoryBlob>(blob);
|
|
||||||
std::copy(
|
|
||||||
arr.data(0), arr.data(0) + arr.size(), mem_blob->rwmap().as<T*>());
|
|
||||||
} else {
|
|
||||||
py::print("Empty array!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype();
|
const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype();
|
||||||
|
|
||||||
const std::map<py::str, ov::element::Type>& dtype_to_ov_type();
|
const std::map<py::str, ov::element::Type>& dtype_to_ov_type();
|
||||||
|
|
||||||
InferenceEngine::Layout get_layout_from_string(const std::string& layout);
|
ov::runtime::Tensor tensor_from_numpy(py::array& array, bool shared_memory);
|
||||||
|
|
||||||
const std::string& get_layout_from_enum(const InferenceEngine::Layout& layout);
|
py::array as_contiguous(py::array& array, ov::element::Type type);
|
||||||
|
|
||||||
PyObject* parse_parameter(const InferenceEngine::Parameter& param);
|
const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor);
|
||||||
|
|
||||||
PyObject* parse_parameter(const InferenceEngine::Parameter& param);
|
|
||||||
|
|
||||||
bool is_TBlob(const py::handle& blob);
|
|
||||||
|
|
||||||
const std::shared_ptr<InferenceEngine::Blob> cast_to_blob(const py::handle& blob);
|
|
||||||
|
|
||||||
const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs);
|
const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs);
|
||||||
|
|
||||||
const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs);
|
const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs);
|
||||||
|
|
||||||
const ov::runtime::Tensor& cast_to_tensor(const py::handle& tensor);
|
void set_request_tensors(ov::runtime::InferRequest& request, const py::dict& inputs);
|
||||||
|
|
||||||
void blob_from_numpy(const py::handle& _arr, InferenceEngine::Blob::Ptr &blob);
|
PyObject* parse_parameter(const InferenceEngine::Parameter& param);
|
||||||
|
|
||||||
void set_request_blobs(InferenceEngine::InferRequest& request, const py::dict& dictonary);
|
|
||||||
|
|
||||||
uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual);
|
uint32_t get_optimal_number_of_requests(const ov::runtime::ExecutableNetwork& actual);
|
||||||
}; // namespace Common
|
}; // namespace Common
|
||||||
|
core.cpp
@@ -23,6 +23,7 @@ std::string to_string(py::handle handle) {

 void regclass_Core(py::module m) {
     py::class_<ov::runtime::Core, std::shared_ptr<ov::runtime::Core>> cls(m, "Core");

     cls.def(py::init<const std::string&>(), py::arg("xml_config_file") = "");

     cls.def("set_config",
@@ -35,7 +36,15 @@ void regclass_Core(py::module m) {
             (ov::runtime::ExecutableNetwork(
                 ov::runtime::Core::*)(const std::shared_ptr<const ov::Function>&, const std::string&, const ConfigMap&)) &
                 ov::runtime::Core::compile_model,
-            py::arg("network"),
+            py::arg("model"),
+            py::arg("device_name"),
+            py::arg("config") = py::dict());
+
+    cls.def("compile_model",
+            (ov::runtime::ExecutableNetwork(
+                ov::runtime::Core::*)(const std::string&, const std::string&, const ConfigMap&)) &
+                ov::runtime::Core::compile_model,
+            py::arg("model_path"),
             py::arg("device_name"),
             py::arg("config") = py::dict());

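The rename from py::arg("network") to py::arg("model") and the new path overload change how compile_model is called from Python; a short sketch (file path and device are placeholders):

from openvino import Core

core = Core()

# Keyword form now uses "model" instead of "network", e.g.:
# exec_net = core.compile_model(model=func, device_name="CPU")

# New overload added here: compile straight from a model file path.
exec_net = core.compile_model("model.xml", "CPU")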
executable_network.cpp
@@ -20,31 +20,21 @@ void regclass_ExecutableNetwork(py::module m) {
        m,
        "ExecutableNetwork");

+    cls.def(py::init([](ov::runtime::ExecutableNetwork& other) {
+                return other;
+            }),
+            py::arg("other"));
+
     cls.def("create_infer_request", [](ov::runtime::ExecutableNetwork& self) {
         return InferRequestWrapper(self.create_infer_request(), self.inputs(), self.outputs());
     });

     cls.def(
-        "_infer_new_request",
+        "infer_new_request",
        [](ov::runtime::ExecutableNetwork& self, const py::dict& inputs) {
            auto request = self.create_infer_request();
-            const auto key = inputs.begin()->first;
-            if (!inputs.empty()) {
-                if (py::isinstance<py::str>(key)) {
-                    auto inputs_map = Common::cast_to_tensor_name_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        request.set_tensor(input.first, input.second);
-                    }
-                } else if (py::isinstance<py::int_>(key)) {
-                    auto inputs_map = Common::cast_to_tensor_index_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        request.set_input_tensor(input.first, input.second);
-                    }
-                } else {
-                    throw py::type_error("Incompatible key type! Supported types are string and int.");
-                }
-            }
+            // Update inputs if there are any
+            Common::set_request_tensors(request, inputs);

            request.infer();

            Containers::InferResults results;
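infer_new_request creates and runs a fresh request inside every call, while create_infer_request returns a reusable wrapped InferRequest; a sketch of the trade-off (model path and input index are placeholders):

import numpy as np
from openvino import compile_model

exec_net = compile_model("model.xml")   # placeholder path, AUTO device
data = np.zeros((1, 3, 224, 224), dtype=np.float32)

# One-shot: a new InferRequest is created internally for this single call.
outputs = exec_net.infer_new_request({0: data})

# Reusable: create the request once and call infer() per frame.
request = exec_net.create_infer_request()
for _ in range(10):
    outputs = request.infer({0: data})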
@ -1,20 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "ie_blob.h"
|
|
||||||
|
|
||||||
#include <ie_blob.h>
|
|
||||||
#include <ie_common.h>
|
|
||||||
#include <ie_layouts.h>
|
|
||||||
|
|
||||||
#include <ie_precision.hpp>
|
|
||||||
|
|
||||||
#include "pyopenvino/core/ie_blob.hpp"
|
|
||||||
#include "pyopenvino/core/tensor_description.hpp"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
void regclass_Blob(py::module m) {
|
|
||||||
py::class_<InferenceEngine::Blob, std::shared_ptr<InferenceEngine::Blob>> cls(m, "Blob");
|
|
||||||
}
|
|
@ -1,56 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include <pybind11/numpy.h>
|
|
||||||
#include <pybind11/pybind11.h>
|
|
||||||
#include <pybind11/stl.h>
|
|
||||||
|
|
||||||
#include "ie_blob.h"
|
|
||||||
#include "ie_common.h"
|
|
||||||
#include "ie_layouts.h"
|
|
||||||
#include "ie_precision.hpp"
|
|
||||||
|
|
||||||
#include "pyopenvino/core/tensor_description.hpp"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
void regclass_Blob(py::module m);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void regclass_TBlob(py::module m, std::string typestring)
|
|
||||||
{
|
|
||||||
auto pyclass_name = py::detail::c_str((std::string("TBlob") + typestring));
|
|
||||||
|
|
||||||
py::class_<InferenceEngine::TBlob<T>, std::shared_ptr<InferenceEngine::TBlob<T>>> cls(
|
|
||||||
m, pyclass_name);
|
|
||||||
|
|
||||||
cls.def(py::init(
|
|
||||||
[](const InferenceEngine::TensorDesc& tensorDesc, py::array_t<T>& arr, size_t size = 0) {
|
|
||||||
auto blob = InferenceEngine::make_shared_blob<T>(tensorDesc);
|
|
||||||
blob->allocate();
|
|
||||||
if (size != 0) {
|
|
||||||
std::copy(arr.data(0), arr.data(0) + size, blob->rwmap().template as<T *>());
|
|
||||||
}
|
|
||||||
return blob;
|
|
||||||
}));
|
|
||||||
|
|
||||||
cls.def_property_readonly("buffer", [](InferenceEngine::TBlob<T>& self) {
|
|
||||||
auto blob_ptr = self.buffer().template as<T*>();
|
|
||||||
auto shape = self.getTensorDesc().getDims();
|
|
||||||
return py::array_t<T>(shape, &blob_ptr[0], py::cast(self));
|
|
||||||
});
|
|
||||||
|
|
||||||
cls.def_property_readonly("tensor_desc",
|
|
||||||
[](InferenceEngine::TBlob<T>& self) { return self.getTensorDesc(); });
|
|
||||||
|
|
||||||
cls.def("__str__", [](InferenceEngine::TBlob<T>& self) -> std::string {
|
|
||||||
std::stringstream ss;
|
|
||||||
auto blob_ptr = self.buffer().template as<T*>();
|
|
||||||
auto shape = self.getTensorDesc().getDims();
|
|
||||||
auto py_arr = py::array_t<T>(shape, &blob_ptr[0], py::cast(self));
|
|
||||||
ss << py_arr;
|
|
||||||
return ss.str();
|
|
||||||
});
|
|
||||||
}
|
|
@ -1,39 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "pyopenvino/core/ie_data.hpp"
|
|
||||||
|
|
||||||
#include <ie_data.h>
|
|
||||||
#include <pybind11/stl.h>
|
|
||||||
|
|
||||||
#include "common.hpp"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
void regclass_Data(py::module m) {
|
|
||||||
py::class_<InferenceEngine::Data, std::shared_ptr<InferenceEngine::Data>> cls(m, "DataPtr");
|
|
||||||
|
|
||||||
cls.def_property(
|
|
||||||
"layout",
|
|
||||||
[](InferenceEngine::Data& self) {
|
|
||||||
return Common::get_layout_from_enum(self.getLayout());
|
|
||||||
},
|
|
||||||
[](InferenceEngine::Data& self, const std::string& layout) {
|
|
||||||
self.setLayout(Common::get_layout_from_string(layout));
|
|
||||||
});
|
|
||||||
|
|
||||||
cls.def_property(
|
|
||||||
"precision",
|
|
||||||
[](InferenceEngine::Data& self) {
|
|
||||||
return self.getPrecision().name();
|
|
||||||
},
|
|
||||||
[](InferenceEngine::Data& self, const std::string& precision) {
|
|
||||||
self.setPrecision(InferenceEngine::Precision::FromStr(precision));
|
|
||||||
});
|
|
||||||
|
|
||||||
cls.def_property_readonly("shape", &InferenceEngine::Data::getDims);
|
|
||||||
|
|
||||||
cls.def_property_readonly("name", &InferenceEngine::Data::getName);
|
|
||||||
// cls.def_property_readonly("initialized", );
|
|
||||||
}
|
|
@ -1,11 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
void regclass_Data(py::module m);
|
|
@ -1,78 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "pyopenvino/core/ie_input_info.hpp"
|
|
||||||
|
|
||||||
#include <pybind11/stl.h>
|
|
||||||
|
|
||||||
#include <ie_input_info.hpp>
|
|
||||||
|
|
||||||
#include "common.hpp"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
class ConstInputInfoWrapper {
|
|
||||||
public:
|
|
||||||
ConstInputInfoWrapper() = default;
|
|
||||||
~ConstInputInfoWrapper() = default;
|
|
||||||
const InferenceEngine::InputInfo& cref() const {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
const InferenceEngine::InputInfo& ref() {
|
|
||||||
return this->value;
|
|
||||||
}
|
|
||||||
const InferenceEngine::InputInfo value = InferenceEngine::InputInfo();
|
|
||||||
};
|
|
||||||
|
|
||||||
void regclass_InputInfo(py::module m) {
|
|
||||||
// Workaround for constant class
|
|
||||||
py::class_<ConstInputInfoWrapper, std::shared_ptr<ConstInputInfoWrapper>> cls_const(m, "InputInfoCPtr");
|
|
||||||
|
|
||||||
cls_const.def(py::init<>());
|
|
||||||
|
|
||||||
cls_const.def_property_readonly("input_data", [](const ConstInputInfoWrapper& self) {
|
|
||||||
return self.cref().getInputData();
|
|
||||||
});
|
|
||||||
cls_const.def_property_readonly("precision", [](const ConstInputInfoWrapper& self) {
|
|
||||||
return self.cref().getPrecision().name();
|
|
||||||
});
|
|
||||||
cls_const.def_property_readonly("tensor_desc", [](const ConstInputInfoWrapper& self) {
|
|
||||||
return self.cref().getTensorDesc();
|
|
||||||
});
|
|
||||||
cls_const.def_property_readonly("name", [](const ConstInputInfoWrapper& self) {
|
|
||||||
return self.cref().name();
|
|
||||||
});
|
|
||||||
// Mutable version
|
|
||||||
py::class_<InferenceEngine::InputInfo, std::shared_ptr<InferenceEngine::InputInfo>> cls(m, "InputInfoPtr");
|
|
||||||
|
|
||||||
cls.def(py::init<>());
|
|
||||||
|
|
||||||
cls.def_property("input_data",
|
|
||||||
&InferenceEngine::InputInfo::getInputData,
|
|
||||||
&InferenceEngine::InputInfo::setInputData);
|
|
||||||
cls.def_property(
|
|
||||||
"layout",
|
|
||||||
[](InferenceEngine::InputInfo& self) {
|
|
||||||
return Common::get_layout_from_enum(self.getLayout());
|
|
||||||
},
|
|
||||||
[](InferenceEngine::InputInfo& self, const std::string& layout) {
|
|
||||||
self.setLayout(Common::get_layout_from_string(layout));
|
|
||||||
});
|
|
||||||
cls.def_property(
|
|
||||||
"precision",
|
|
||||||
[](InferenceEngine::InputInfo& self) {
|
|
||||||
return self.getPrecision().name();
|
|
||||||
},
|
|
||||||
[](InferenceEngine::InputInfo& self, const std::string& precision) {
|
|
||||||
self.setPrecision(InferenceEngine::Precision::FromStr(precision));
|
|
||||||
});
|
|
||||||
cls.def_property_readonly("tensor_desc", &InferenceEngine::InputInfo::getTensorDesc);
|
|
||||||
cls.def_property_readonly("name", &InferenceEngine::InputInfo::name);
|
|
||||||
cls.def_property_readonly("preprocess_info", [](InferenceEngine::InputInfo& self) {
|
|
||||||
InferenceEngine::PreProcessInfo& preprocess = self.getPreProcess();
|
|
||||||
return preprocess;
|
|
||||||
});
|
|
||||||
}
|
|
@ -1,11 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
void regclass_InputInfo(py::module m);
|
|
@ -1,87 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "pyopenvino/core/ie_network.hpp"
|
|
||||||
|
|
||||||
#include <cpp/ie_cnn_network.h>
|
|
||||||
#include <pybind11/stl.h>
|
|
||||||
#include <pybind11/stl_bind.h>
|
|
||||||
|
|
||||||
#include <ie_input_info.hpp>
|
|
||||||
|
|
||||||
#include "openvino/core/function.hpp"
|
|
||||||
#include "pyopenvino/core/containers.hpp"
|
|
||||||
#include "pyopenvino/core/ie_input_info.hpp"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
void regclass_IENetwork(py::module m) {
|
|
||||||
py::class_<InferenceEngine::CNNNetwork, std::shared_ptr<InferenceEngine::CNNNetwork>> cls(m, "IENetwork");
|
|
||||||
cls.def(py::init());
|
|
||||||
|
|
||||||
cls.def(py::init([](std::shared_ptr<ov::Function>& function) {
|
|
||||||
InferenceEngine::CNNNetwork cnnNetwork(function);
|
|
||||||
return std::make_shared<InferenceEngine::CNNNetwork>(cnnNetwork);
|
|
||||||
}));
|
|
||||||
|
|
||||||
cls.def("reshape",
|
|
||||||
[](InferenceEngine::CNNNetwork& self, const std::map<std::string, std::vector<size_t>>& input_shapes) {
|
|
||||||
self.reshape(input_shapes);
|
|
||||||
});
|
|
||||||
|
|
||||||
cls.def(
|
|
||||||
"add_outputs",
|
|
||||||
[](InferenceEngine::CNNNetwork& self, py::handle& outputs) {
|
|
||||||
int i = 0;
|
|
||||||
py::list _outputs;
|
|
||||||
if (!py::isinstance<py::list>(outputs)) {
|
|
||||||
if (py::isinstance<py::str>(outputs)) {
|
|
||||||
_outputs.append(outputs.cast<py::str>());
|
|
||||||
} else if (py::isinstance<py::tuple>(outputs)) {
|
|
||||||
_outputs.append(outputs.cast<py::tuple>());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_outputs = outputs.cast<py::list>();
|
|
||||||
}
|
|
||||||
for (py::handle output : _outputs) {
|
|
||||||
if (py::isinstance<py::str>(_outputs[i])) {
|
|
||||||
self.addOutput(output.cast<std::string>(), 0);
|
|
||||||
} else if (py::isinstance<py::tuple>(output)) {
|
|
||||||
py::tuple output_tuple = output.cast<py::tuple>();
|
|
||||||
self.addOutput(output_tuple[0].cast<std::string>(), output_tuple[1].cast<int>());
|
|
||||||
} else {
|
|
||||||
IE_THROW() << "Incorrect type " << output.get_type() << "for layer to add at index " << i
|
|
||||||
<< ". Expected string with layer name or tuple with two elements: layer name as "
|
|
||||||
"first element and port id as second";
|
|
||||||
}
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
py::arg("outputs"));
|
|
||||||
cls.def("add_output", &InferenceEngine::CNNNetwork::addOutput, py::arg("layer_name"), py::arg("output_index") = 0);
|
|
||||||
|
|
||||||
cls.def(
|
|
||||||
"serialize",
|
|
||||||
[](InferenceEngine::CNNNetwork& self, const std::string& path_to_xml, const std::string& path_to_bin) {
|
|
||||||
self.serialize(path_to_xml, path_to_bin);
|
|
||||||
},
|
|
||||||
py::arg("path_to_xml"),
|
|
||||||
py::arg("path_to_bin") = "");
|
|
||||||
|
|
||||||
cls.def("get_function", [](InferenceEngine::CNNNetwork& self) {
|
|
||||||
return self.getFunction();
|
|
||||||
});
|
|
||||||
|
|
||||||
cls.def("get_ov_name_for_tensor", &InferenceEngine::CNNNetwork::getOVNameForTensor, py::arg("orig_name"));
|
|
||||||
|
|
||||||
cls.def_property("batch_size",
|
|
||||||
&InferenceEngine::CNNNetwork::getBatchSize,
|
|
||||||
&InferenceEngine::CNNNetwork::setBatchSize);
|
|
||||||
|
|
||||||
cls.def_property_readonly("outputs", [](InferenceEngine::CNNNetwork& self) {
|
|
||||||
return self.getOutputsInfo();
|
|
||||||
});
|
|
||||||
|
|
||||||
cls.def_property_readonly("name", &InferenceEngine::CNNNetwork::getName);
|
|
||||||
}
|
|
@ -1,11 +0,0 @@
|
|||||||
// Copyright (C) 2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
void regclass_IENetwork(py::module m);
|
|
@@ -1,70 +0,0 @@
-// Copyright (C) 2021 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "pyopenvino/core/ie_preprocess_info.hpp"
-
-#include <ie_common.h>
-#include <pybind11/stl.h>
-
-#include <ie_preprocess.hpp>
-
-#include "pyopenvino/core/common.hpp"
-
-namespace py = pybind11;
-
-void regclass_PreProcessInfo(py::module m) {
-    py::class_<InferenceEngine::PreProcessChannel, std::shared_ptr<InferenceEngine::PreProcessChannel>>(
-        m,
-        "PreProcessChannel")
-        .def_readwrite("std_scale", &InferenceEngine::PreProcessChannel::stdScale)
-        .def_readwrite("mean_value", &InferenceEngine::PreProcessChannel::meanValue)
-        .def_readwrite("mean_data", &InferenceEngine::PreProcessChannel::meanData);
-
-    py::class_<InferenceEngine::PreProcessInfo> cls(m, "PreProcessInfo");
-
-    cls.def(py::init());
-    cls.def("__getitem__", [](InferenceEngine::PreProcessInfo& self, size_t& index) {
-        return self[index];
-    });
-    cls.def("get_number_of_channels", &InferenceEngine::PreProcessInfo::getNumberOfChannels);
-    cls.def("init", &InferenceEngine::PreProcessInfo::init);
-    cls.def("set_mean_image", [](InferenceEngine::PreProcessInfo& self, py::handle meanImage) {
-        self.setMeanImage(Common::cast_to_blob(meanImage));
-    });
-    cls.def("set_mean_image_for_channel",
-            [](InferenceEngine::PreProcessInfo& self, py::handle meanImage, const size_t channel) {
-                self.setMeanImageForChannel(Common::cast_to_blob(meanImage), channel);
-            });
-    cls.def_property("mean_variant",
-                     &InferenceEngine::PreProcessInfo::getMeanVariant,
-                     &InferenceEngine::PreProcessInfo::setVariant);
-    cls.def_property("resize_algorithm",
-                     &InferenceEngine::PreProcessInfo::getResizeAlgorithm,
-                     &InferenceEngine::PreProcessInfo::setResizeAlgorithm);
-    cls.def_property("color_format",
-                     &InferenceEngine::PreProcessInfo::getColorFormat,
-                     &InferenceEngine::PreProcessInfo::setColorFormat);
-
-    py::enum_<InferenceEngine::MeanVariant>(m, "MeanVariant")
-        .value("MEAN_IMAGE", InferenceEngine::MeanVariant::MEAN_IMAGE)
-        .value("MEAN_VALUE", InferenceEngine::MeanVariant::MEAN_VALUE)
-        .value("NONE", InferenceEngine::MeanVariant::NONE)
-        .export_values();
-
-    py::enum_<InferenceEngine::ResizeAlgorithm>(m, "ResizeAlgorithm")
-        .value("NO_RESIZE", InferenceEngine::ResizeAlgorithm::NO_RESIZE)
-        .value("RESIZE_BILINEAR", InferenceEngine::ResizeAlgorithm::RESIZE_BILINEAR)
-        .value("RESIZE_AREA", InferenceEngine::ResizeAlgorithm::RESIZE_AREA)
-        .export_values();
-
-    py::enum_<InferenceEngine::ColorFormat>(m, "ColorFormat")
-        .value("RAW", InferenceEngine::ColorFormat::RAW)
-        .value("RGB", InferenceEngine::ColorFormat::RGB)
-        .value("BGR", InferenceEngine::ColorFormat::BGR)
-        .value("RGBX", InferenceEngine::ColorFormat::RGBX)
-        .value("BGRX", InferenceEngine::ColorFormat::BGRX)
-        .value("NV12", InferenceEngine::ColorFormat::NV12)
-        .value("I420", InferenceEngine::ColorFormat::I420)
-        .export_values();
-}
@@ -1,11 +0,0 @@
-// Copyright (C) 2021 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <pybind11/pybind11.h>
-
-namespace py = pybind11;
-
-void regclass_PreProcessInfo(py::module m);
@@ -20,6 +20,12 @@ namespace py = pybind11;
 void regclass_InferRequest(py::module m) {
     py::class_<InferRequestWrapper, std::shared_ptr<InferRequestWrapper>> cls(m, "InferRequest");
 
+    cls.def(py::init([](InferRequestWrapper& other) {
+                return other;
+            }),
+            py::arg("other"));
+
     cls.def(
         "set_tensors",
         [](InferRequestWrapper& self, const py::dict& inputs) {
@@ -51,22 +57,10 @@ void regclass_InferRequest(py::module m) {
         py::arg("inputs"));
 
     cls.def(
-        "_infer",
+        "infer",
         [](InferRequestWrapper& self, const py::dict& inputs) {
             // Update inputs if there are any
-            if (!inputs.empty()) {
-                if (py::isinstance<py::str>(inputs.begin()->first)) {
-                    auto inputs_map = Common::cast_to_tensor_name_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        self._request.set_tensor(input.first, input.second);
-                    }
-                } else if (py::isinstance<py::int_>(inputs.begin()->first)) {
-                    auto inputs_map = Common::cast_to_tensor_index_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        self._request.set_input_tensor(input.first, input.second);
-                    }
-                }
-            }
+            Common::set_request_tensors(self._request, inputs);
             // Call Infer function
             self._start_time = Time::now();
             self._request.infer();
@@ -80,23 +74,11 @@ void regclass_InferRequest(py::module m) {
         py::arg("inputs"));
 
     cls.def(
-        "_start_async",
+        "start_async",
        [](InferRequestWrapper& self, const py::dict& inputs, py::object& userdata) {
             // Update inputs if there are any
-            if (!inputs.empty()) {
-                if (py::isinstance<std::string>(inputs.begin()->first)) {
-                    auto inputs_map = Common::cast_to_tensor_name_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        self._request.set_tensor(input.first, input.second);
-                    }
-                } else if (py::isinstance<int>(inputs.begin()->first)) {
-                    auto inputs_map = Common::cast_to_tensor_index_map(inputs);
-                    for (auto&& input : inputs_map) {
-                        self._request.set_input_tensor(input.first, input.second);
-                    }
-                }
-            }
-            if (userdata != py::none()) {
+            Common::set_request_tensors(self._request, inputs);
+            if (!userdata.is(py::none())) {
                 if (self.user_callback_defined) {
                     self.userdata = userdata;
                 } else {
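For reference, a minimal sketch of how the renamed infer/start_async bindings are driven from Python after this change; the model path and the "data" input name are placeholders, and the exact wrapper signature may differ slightly:

    import numpy as np
    from openvino import Core, Tensor

    core = Core()
    func = core.read_model("model.xml")           # placeholder path
    model = core.compile_model(func, "CPU")
    request = model.create_infer_request()

    img = np.zeros((1, 3, 32, 32), dtype=np.float32)
    # Keys may be tensor names or input indices; both paths now go
    # through Common::set_request_tensors on the C++ side.
    results = request.infer({"data": Tensor(img)})
    request.start_async({0: Tensor(img)}, userdata=None)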
@@ -10,28 +10,13 @@
 #include "openvino/runtime/tensor.hpp"
 #include "pyopenvino/core/common.hpp"
 
-#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
-
 namespace py = pybind11;
 
 void regclass_Tensor(py::module m) {
     py::class_<ov::runtime::Tensor, std::shared_ptr<ov::runtime::Tensor>> cls(m, "Tensor");
 
     cls.def(py::init([](py::array& array, bool shared_memory) {
-                auto type = Common::dtype_to_ov_type().at(py::str(array.dtype()));
-                std::vector<size_t> shape(array.shape(), array.shape() + array.ndim());
-                if (shared_memory) {
-                    if (C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS)) {
-                        std::vector<size_t> strides(array.strides(), array.strides() + array.ndim());
-                        return ov::runtime::Tensor(type, shape, const_cast<void*>(array.data(0)), strides);
-                    } else {
-                        IE_THROW() << "Tensor with shared memory must be C contiguous!";
-                    }
-                }
-                array = py::module::import("numpy").attr("ascontiguousarray")(array).cast<py::array>();
-                auto tensor = ov::runtime::Tensor(type, shape);
-                std::memcpy(tensor.data(), array.data(0), array.nbytes());
-                return tensor;
+                return Common::tensor_from_numpy(array, shared_memory);
             }),
             py::arg("array"),
             py::arg("shared_memory") = false);
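The Python-visible behavior of this constructor, as exercised elsewhere in the tests, is roughly the following sketch (the array contents are illustrative only):

    import numpy as np
    from openvino import Tensor

    arr = np.ones((1, 3, 2, 2), dtype=np.float32)
    t_copied = Tensor(arr)                        # data is copied into the tensor
    t_shared = Tensor(arr, shared_memory=True)    # zero-copy; the array must be C-contiguous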
@@ -1,58 +0,0 @@
-// Copyright (C) 2021 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "pyopenvino/core/tensor_description.hpp"
-
-#include <ie_common.h>
-#include <ie_layouts.h>
-#include <pybind11/stl.h>
-
-#include <ie_precision.hpp>
-
-#include "common.hpp"
-
-namespace py = pybind11;
-using namespace InferenceEngine;
-
-void regclass_TensorDecription(py::module m) {
-    py::class_<TensorDesc, std::shared_ptr<TensorDesc>> cls(m, "TensorDesc");
-    cls.def(py::init<const Precision&, const SizeVector&, Layout>());
-    cls.def(py::init([](const std::string& precision, const SizeVector& dims, const std::string& layout) {
-        return TensorDesc(Precision::FromStr(precision), dims, Common::get_layout_from_string(layout));
-    }));
-
-    cls.def_property(
-        "layout",
-        [](TensorDesc& self) {
-            return Common::get_layout_from_enum(self.getLayout());
-        },
-        [](TensorDesc& self, const std::string& layout) {
-            self.setLayout(Common::get_layout_from_string(layout));
-        });
-
-    cls.def_property(
-        "precision",
-        [](TensorDesc& self) {
-            return self.getPrecision().name();
-        },
-        [](TensorDesc& self, const std::string& precision) {
-            self.setPrecision(InferenceEngine::Precision::FromStr(precision));
-        });
-
-    cls.def_property(
-        "dims",
-        [](TensorDesc& self) {
-            return self.getDims();
-        },
-        [](TensorDesc& self, const SizeVector& dims) {
-            self.setDims(dims);
-        });
-
-    cls.def(
-        "__eq__",
-        [](const TensorDesc& a, const TensorDesc b) {
-            return a == b;
-        },
-        py::is_operator());
-}
@@ -1,11 +0,0 @@
-// Copyright (C) 2021 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <pybind11/pybind11.h>
-
-namespace py = pybind11;
-
-void regclass_TensorDecription(py::module m);
@@ -23,17 +23,11 @@
 #include "pyopenvino/core/containers.hpp"
 #include "pyopenvino/core/core.hpp"
 #include "pyopenvino/core/executable_network.hpp"
-#include "pyopenvino/core/ie_blob.hpp"
-#include "pyopenvino/core/ie_data.hpp"
-#include "pyopenvino/core/ie_input_info.hpp"
-#include "pyopenvino/core/ie_network.hpp"
 #include "pyopenvino/core/ie_parameter.hpp"
-#include "pyopenvino/core/ie_preprocess_info.hpp"
 #include "pyopenvino/core/infer_request.hpp"
 #include "pyopenvino/core/offline_transformations.hpp"
 #include "pyopenvino/core/profiling_info.hpp"
 #include "pyopenvino/core/tensor.hpp"
-#include "pyopenvino/core/tensor_description.hpp"
 #include "pyopenvino/core/version.hpp"
 #include "pyopenvino/graph/dimension.hpp"
 #include "pyopenvino/graph/layout.hpp"
@@ -96,28 +90,7 @@ PYBIND11_MODULE(pyopenvino, m) {
     regclass_graph_Output<const ov::Node>(m, std::string("Const"));
 
     regclass_Core(m);
-    regclass_IENetwork(m);
-
-    regclass_Data(m);
-    regclass_TensorDecription(m);
-
-    // Blob will be removed
-    // Registering template of Blob
-    regclass_Blob(m);
-    // Registering specific types of Blobs
-    regclass_TBlob<float>(m, "Float32");
-    regclass_TBlob<double>(m, "Float64");
-    regclass_TBlob<int64_t>(m, "Int64");
-    regclass_TBlob<uint64_t>(m, "Uint64");
-    regclass_TBlob<int32_t>(m, "Int32");
-    regclass_TBlob<uint32_t>(m, "Uint32");
-    regclass_TBlob<int16_t>(m, "Int16");
-    regclass_TBlob<uint16_t>(m, "Uint16");
-    regclass_TBlob<int8_t>(m, "Int8");
-    regclass_TBlob<uint8_t>(m, "Uint8");
-
     regclass_Tensor(m);
 
     // Registering specific types of containers
     Containers::regclass_TensorIndexMap(m);
     Containers::regclass_TensorNameMap(m);
@@ -126,10 +99,8 @@ PYBIND11_MODULE(pyopenvino, m) {
     regclass_InferRequest(m);
     regclass_Version(m);
     regclass_Parameter(m);
-    regclass_InputInfo(m);
     regclass_AsyncInferQueue(m);
     regclass_ProfilingInfo(m);
-    regclass_PreProcessInfo(m);
 
     regmodule_offline_transformations(m);
 }
@@ -3,6 +3,7 @@
 
 import os
 import pytest
+import numpy as np
 
 import tests
 
@@ -15,6 +16,19 @@ def image_path():
     return path_to_img
 
 
+def read_image():
+    import cv2
+    n, c, h, w = (1, 3, 32, 32)
+    image = cv2.imread(image_path())
+    if image is None:
+        raise FileNotFoundError("Input image not found")
+
+    image = cv2.resize(image, (h, w)) / 255
+    image = image.transpose((2, 0, 1)).astype(np.float32)
+    image = image.reshape((n, c, h, w))
+    return image
+
+
 def model_path(is_myriad=False):
     path_to_repo = os.environ["MODELS_PATH"]
     if not is_myriad:
@@ -99,14 +99,7 @@ class Computation(object):
                 "Expected %s params, received not enough %s values.", len(self.parameters), len(input_values)
             )
 
-        param_types = [param.get_element_type() for param in self.parameters]
         param_names = [param.friendly_name for param in self.parameters]
-
-        # ignore not needed input values
-        input_values = [
-            np.array(input_value[0], dtype=get_dtype(input_value[1]))
-            for input_value in zip(input_values[: len(self.parameters)], param_types)
-        ]
         input_shapes = [get_shape(input_value) for input_value in input_values]
 
         if self.network_cache.get(str(input_shapes)) is None:
@@ -121,8 +114,8 @@ class Computation(object):
 
         for parameter, input in zip(self.parameters, input_values):
             parameter_shape = parameter.get_output_partial_shape(0)
-            input_shape = PartialShape(input.shape)
-            if len(input.shape) > 0 and not parameter_shape.compatible(input_shape):
+            input_shape = PartialShape([]) if isinstance(input, (int, float)) else PartialShape(input.shape)
+            if not parameter_shape.compatible(input_shape):
                 raise UserInputError(
                     "Provided tensor's shape: %s does not match the expected: %s.",
                     input_shape,
@@ -8,37 +8,34 @@ from sys import platform
 from pathlib import Path
 
 import openvino.opset8 as ov
-from openvino import Core, IENetwork, ExecutableNetwork, tensor_from_file
-from openvino.impl import Function
-from openvino import TensorDesc, Blob
-
-from ..conftest import model_path, model_onnx_path, plugins_path
+from openvino import Function, Core, ExecutableNetwork, Tensor, tensor_from_file, compile_model
+from ..conftest import model_path, model_onnx_path, plugins_path, read_image
 
 test_net_xml, test_net_bin = model_path()
 test_net_onnx = model_onnx_path()
 plugins_xml, plugins_win_xml, plugins_osx_xml = plugins_path()
 
 
-def test_blobs():
-    input_shape = [1, 3, 4, 4]
-    input_data_float32 = (np.random.rand(*input_shape) - 0.5).astype(np.float32)
-
-    td = TensorDesc("FP32", input_shape, "NCHW")
-
-    input_blob_float32 = Blob(td, input_data_float32)
-
-    assert np.all(np.equal(input_blob_float32.buffer, input_data_float32))
-
-    input_data_int16 = (np.random.rand(*input_shape) + 0.5).astype(np.int16)
-
-    td = TensorDesc("I16", input_shape, "NCHW")
-
-    input_blob_i16 = Blob(td, input_data_int16)
-
-    assert np.all(np.equal(input_blob_i16.buffer, input_data_int16))
-
-
-@pytest.mark.skip(reason="Fix")
+def test_compact_api_xml():
+    img = read_image()
+
+    model = compile_model(test_net_xml)
+    assert(isinstance(model, ExecutableNetwork))
+    results = model.infer_new_request({"data": img})
+    assert np.argmax(results) == 2
+
+
+def test_compact_api_onnx():
+    img = read_image()
+
+    model = compile_model(test_net_onnx)
+    assert(isinstance(model, ExecutableNetwork))
+    results = model.infer_new_request({"data": img})
+    assert np.argmax(results) == 2
+
+
 def test_core_class():
     input_shape = [1, 3, 4, 4]
     param = ov.parameter(input_shape, np.float32, name="parameter")
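The compact API these tests exercise reduces the read/compile/infer flow to a couple of calls; a minimal sketch, using a placeholder model path and assuming the model's input tensor is named "data":

    import numpy as np
    from openvino import compile_model

    model = compile_model("model.xml")            # placeholder path
    input_data = np.zeros((1, 3, 32, 32), dtype=np.float32)
    results = model.infer_new_request({"data": input_data})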
@@ -46,29 +43,18 @@ def test_core_class():
     func = Function([relu], [param], "test")
     func.get_ordered_ops()[2].friendly_name = "friendly"
 
-    cnn_network = IENetwork(func)
-
     core = Core()
-    core.set_config({}, device_name="CPU")
-    executable_network = core.compile_model(cnn_network, "CPU", {})
-
-    td = TensorDesc("FP32", input_shape, "NCHW")
-
-    # from IPython import embed; embed()
-
-    request = executable_network.create_infer_request()
-    input_data = np.random.rand(*input_shape) - 0.5
+    model = core.compile_model(func, "CPU", {})
+
+    request = model.create_infer_request()
+    input_data = np.random.rand(*input_shape).astype(np.float32) - 0.5
 
     expected_output = np.maximum(0.0, input_data)
 
-    input_blob = Blob(td, input_data)
-
-    request.set_input({"parameter": input_blob})
-    request.infer()
-
-    result = request.get_blob("relu").buffer
-
-    assert np.allclose(result, expected_output)
+    input_tensor = Tensor(input_data)
+    results = request.infer({"parameter": input_tensor})
+
+    assert np.allclose(results, expected_output)
 
 
 def test_compile_model(device):
@@ -119,15 +105,15 @@ def test_read_model_from_onnx_as_path():
     assert isinstance(func, Function)
 
 
-@pytest.mark.xfail("68212")
-def test_read_net_from_buffer():
-    core = Core()
-    with open(test_net_bin, "rb") as f:
-        bin = f.read()
-    with open(model_path()[0], "rb") as f:
-        xml = f.read()
-    func = core.read_model(model=xml, weights=bin)
-    assert isinstance(func, IENetwork)
+# @pytest.mark.xfail("68212")
+# def test_read_net_from_buffer():
+#     core = Core()
+#     with open(test_net_bin, "rb") as f:
+#         bin = f.read()
+#     with open(model_path()[0], "rb") as f:
+#         xml = f.read()
+#     func = core.read_model(model=xml, weights=bin)
+#     assert isinstance(func, IENetwork)
 
 
 @pytest.mark.xfail("68212")
@@ -5,7 +5,7 @@ import os
 import pytest
 import numpy as np
 
-from ..conftest import model_path, image_path
+from ..conftest import model_path, read_image
 from openvino.impl import Function, ConstOutput, Shape
 
 from openvino import Core, Tensor
@@ -14,19 +14,6 @@ is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
 test_net_xml, test_net_bin = model_path(is_myriad)
 
 
-def read_image():
-    import cv2
-    n, c, h, w = (1, 3, 32, 32)
-    image = cv2.imread(image_path())
-    if image is None:
-        raise FileNotFoundError("Input image not found")
-
-    image = cv2.resize(image, (h, w)) / 255
-    image = image.transpose((2, 0, 1)).astype(np.float32)
-    image = image.reshape((n, c, h, w))
-    return image
-
-
 def test_get_metric(device):
     core = Core()
     func = core.read_model(model=test_net_xml, weights=test_net_bin)
@@ -278,9 +265,9 @@ def test_infer_new_request_wrong_port_name(device):
     img = read_image()
     tensor = Tensor(img)
     exec_net = ie.compile_model(func, device)
-    with pytest.raises(RuntimeError) as e:
+    with pytest.raises(KeyError) as e:
         exec_net.infer_new_request({"_data_": tensor})
-    assert "Port for tensor name _data_ was not found." in str(e.value)
+    assert "Port for tensor named _data_ was not found!" in str(e.value)
 
 
 def test_infer_tensor_wrong_input_data(device):
@@ -291,5 +278,5 @@ def test_infer_tensor_wrong_input_data(device):
     tensor = Tensor(img, shared_memory=True)
     exec_net = ie.compile_model(func, device)
     with pytest.raises(TypeError) as e:
-        exec_net.infer_new_request({4.5: tensor})
-    assert "Incompatible key type!" in str(e.value)
+        exec_net.infer_new_request({0.: tensor})
+    assert "Incompatible key type for tensor named: 0." in str(e.value)
@@ -7,26 +7,13 @@ import pytest
 import datetime
 import time
 
-from ..conftest import image_path, model_path
+from ..conftest import model_path, read_image
 from openvino import Core, AsyncInferQueue, Tensor, ProfilingInfo
 
 is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
 test_net_xml, test_net_bin = model_path(is_myriad)
 
 
-def read_image():
-    import cv2
-    n, c, h, w = (1, 3, 32, 32)
-    image = cv2.imread(image_path())
-    if image is None:
-        raise FileNotFoundError("Input image not found")
-
-    image = cv2.resize(image, (h, w)) / 255
-    image = image.transpose((2, 0, 1)).astype(np.float32)
-    image = image.reshape((n, c, h, w))
-    return image
-
-
 def test_get_profiling_info(device):
     core = Core()
     func = core.read_model(test_net_xml, test_net_bin)
@@ -48,8 +35,8 @@ def test_get_profiling_info(device):
 def test_tensor_setter(device):
     core = Core()
     func = core.read_model(test_net_xml, test_net_bin)
-    exec_net_1 = core.compile_model(network=func, device_name=device)
-    exec_net_2 = core.compile_model(network=func, device_name=device)
+    exec_net_1 = core.compile_model(model=func, device_name=device)
+    exec_net_2 = core.compile_model(model=func, device_name=device)
 
     img = read_image()
     tensor = Tensor(img)
@@ -177,18 +164,17 @@ def test_infer_mixed_keys(device):
     core = Core()
     func = core.read_model(test_net_xml, test_net_bin)
     core.set_config({"PERF_COUNT": "YES"}, device)
-    exec_net = core.compile_model(func, device)
+    model = core.compile_model(func, device)
 
     img = read_image()
     tensor = Tensor(img)
 
-    data2 = np.ones(shape=(1, 10), dtype=np.float32)
+    data2 = np.ones(shape=img.shape, dtype=np.float32)
     tensor2 = Tensor(data2)
 
-    request = exec_net.create_infer_request()
-    with pytest.raises(TypeError) as e:
-        request.infer({0: tensor, "fc_out": tensor2})
-    assert "incompatible function arguments!" in str(e.value)
+    request = model.create_infer_request()
+    res = request.infer({0: tensor2, "data": tensor})
+    assert np.argmax(res) == 2
 
 
 def test_infer_queue(device):
@@ -4,25 +4,11 @@
 import numpy as np
 import pytest
 
-from ..conftest import image_path
+from ..conftest import read_image
 from openvino import Tensor
 import openvino as ov
 
 
-def read_image():
-    import cv2
-
-    n, c, h, w = (1, 3, 32, 32)
-    image = cv2.imread(image_path())
-    if image is None:
-        raise FileNotFoundError("Input image not found")
-
-    image = cv2.resize(image, (h, w)) / 255
-    image = image.transpose((2, 0, 1)).astype(np.float32)
-    image = image.reshape((n, c, h, w))
-    return image
-
-
 @pytest.mark.parametrize("ov_type, numpy_dtype", [
     (ov.impl.Type.f32, np.float32),
     (ov.impl.Type.f64, np.float64),