[PyOV] OVDict class - new return value from inference (#16370)

Jan Iwaszkiewicz 2023-03-22 16:12:07 +01:00, committed by GitHub
parent 8509d0dd82
commit 4561aa7109
13 changed files with 489 additions and 173 deletions

View File

@@ -1 +1,2 @@
numpy>=1.16.6
+singledispatchmethod; python_version<'3.8'

View File

@@ -40,3 +40,4 @@ types-pkg_resources
wheel>=0.38.1
protobuf~=3.18.1
numpy>=1.16.6,<=1.23.4
+singledispatchmethod; python_version<'3.8'

View File

@@ -2,7 +2,6 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

-from functools import singledispatch
from typing import Any, Iterable, Union, Dict, Optional
from pathlib import Path
@@ -16,6 +15,7 @@ from openvino._pyopenvino import ConstOutput
from openvino._pyopenvino import Tensor

from openvino.runtime.utils.data_helpers import (
+    OVDict,
    _InferRequestWrapper,
    _data_dispatch,
    tensor_from_file,
@@ -25,7 +25,7 @@ from openvino.runtime.utils.data_helpers import (
class InferRequest(_InferRequestWrapper):
    """InferRequest class represents infer request which can be run in asynchronous or synchronous manners."""

-    def infer(self, inputs: Any = None, shared_memory: bool = False) -> dict:
+    def infer(self, inputs: Any = None, shared_memory: bool = False) -> OVDict:
        """Infers specified input(s) in synchronous mode.

        Blocks all methods of InferRequest while request is running.
@@ -68,14 +68,14 @@ class InferRequest(_InferRequestWrapper):
                              Default value: False
        :type shared_memory: bool, optional
-        :return: Dictionary of results from output tensors with ports as keys.
-        :rtype: Dict[openvino.runtime.ConstOutput, numpy.ndarray]
+        :return: Dictionary of results from output tensors with port/int/str keys.
+        :rtype: OVDict
        """
-        return super().infer(_data_dispatch(
+        return OVDict(super().infer(_data_dispatch(
            self,
            inputs,
            is_shared=shared_memory,
-        ))
+        )))

    def start_async(
        self,
@@ -138,6 +138,15 @@ class InferRequest(_InferRequestWrapper):
            userdata,
        )

+    @property
+    def results(self) -> OVDict:
+        """Gets all output tensors of this InferRequest.
+
+        :return: Dictionary of results from output tensors with port/int/str keys.
+        :rtype: OVDict
+        """
+        return OVDict(super().results)
+

class CompiledModel(CompiledModelBase):
    """CompiledModel class.
@@ -161,7 +170,7 @@ class CompiledModel(CompiledModelBase):
        """
        return InferRequest(super().create_infer_request())

-    def infer_new_request(self, inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None) -> dict:
+    def infer_new_request(self, inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None) -> OVDict:
        """Infers specified input(s) in synchronous mode.

        Blocks all methods of CompiledModel while request is running.
@@ -187,8 +196,8 @@ class CompiledModel(CompiledModelBase):
        :param inputs: Data to be set on input tensors.
        :type inputs: Union[Dict[keys, values], List[values], Tuple[values], Tensor, numpy.ndarray], optional
-        :return: Dictionary of results from output tensors with ports as keys.
-        :rtype: Dict[openvino.runtime.ConstOutput, numpy.ndarray]
+        :return: Dictionary of results from output tensors with port/int/str keys.
+        :rtype: OVDict
        """
        # It returns a wrapped python InferRequest and then calls upon
        # overloaded functions of the InferRequest class
@@ -196,7 +205,7 @@ class CompiledModel(CompiledModelBase):
    def __call__(self,
                 inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None,
-                 shared_memory: bool = True) -> dict:
+                 shared_memory: bool = True) -> OVDict:
        """Callable infer wrapper for CompiledModel.

        Infers specified input(s) in synchronous mode.
@@ -248,8 +257,8 @@ class CompiledModel(CompiledModelBase):
                              Default value: True
        :type shared_memory: bool, optional
-        :return: Dictionary of results from output tensors with ports as keys.
-        :rtype: Dict[openvino.runtime.ConstOutput, numpy.ndarray]
+        :return: Dictionary of results from output tensors with port/int/str as keys.
+        :rtype: OVDict
        """
        if self._infer_request is None:
            self._infer_request = self.create_infer_request()
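
With these changes, `InferRequest.infer`, `CompiledModel.infer_new_request`, and `CompiledModel.__call__` all return an `OVDict` instead of a plain `dict`. A minimal sketch of what this gives the caller, built on the same opset pattern the new tests use (the "CPU" device string is an assumption, not part of the diff):

```python
import numpy as np
import openvino.runtime.opset10 as ops
from openvino.runtime import Core, Model

# Single-output model, named the same way the new tests name outputs.
param = ops.parameter([1, 20], np.float32, name="data_0")
model = Model(ops.abs(param), [param])
model.output(0).tensor.names = {"output_0"}

compiled_model = Core().compile_model(model, "CPU")  # device is a placeholder
result = compiled_model(np.random.random([1, 20]).astype(np.float32))

# One OVDict, three equivalent key types:
by_port = result[compiled_model.output(0)]  # ConstOutput key
by_index = result[0]                        # int key
by_name = result["output_0"]                # str key
assert np.array_equal(by_port, by_index) and np.array_equal(by_index, by_name)
```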

View File

@@ -5,3 +5,4 @@
from openvino.runtime.utils.data_helpers.data_dispatcher import _data_dispatch
from openvino.runtime.utils.data_helpers.wrappers import tensor_from_file
from openvino.runtime.utils.data_helpers.wrappers import _InferRequestWrapper
+from openvino.runtime.utils.data_helpers.wrappers import OVDict

View File

@@ -4,7 +4,17 @@
import numpy as np

-from openvino._pyopenvino import Tensor
+# TODO: remove this WA and refactor OVDict when Python 3.8
+# becomes the minimal supported version.
+try:
+    from functools import singledispatchmethod
+except ImportError:
+    from singledispatchmethod import singledispatchmethod  # type: ignore[no-redef]
+
+from collections.abc import Mapping
+from typing import Union, Dict, List, Iterator, KeysView, ItemsView, ValuesView
+
+from openvino._pyopenvino import Tensor, ConstOutput
from openvino._pyopenvino import InferRequest as InferRequestBase
@@ -20,3 +30,109 @@ class _InferRequestWrapper(InferRequestBase):
        # Private member to store newly created shared memory data
        self._inputs_data = None
        super().__init__(other)
+
+
+class OVDict(Mapping):
+    """Custom OpenVINO dictionary with inference results.
+
+    This class is a dict-like object. It allows data tensors
+    to be addressed with three key types:
+
+    * `openvino.runtime.ConstOutput` - port of the output
+    * `int` - index of the output
+    * `str` - name of the output
+
+    This class follows the `frozenset`/`tuple` concept of immutability.
+    It is prohibited to assign new items or edit existing ones.
+
+    To revert to the previous behavior, use the `to_dict` method, which
+    returns the underlying dictionary.
+    Note: this removes the extra addressing! The dictionary keeps
+    only `ConstOutput` keys.
+
+    If a tuple of return values is needed, use the `to_tuple` method, which
+    converts the values to a tuple.
+
+    :Example:
+
+    .. code-block:: python
+
+        # Reverts to the previous behavior of the native dict
+        result = request.infer(inputs).to_dict()
+        # or alternatively:
+        result = dict(request.infer(inputs))
+
+    .. code-block:: python
+
+        # To dispatch outputs of multi-output inference:
+        out1, out2, out3, _ = request.infer(inputs).values()
+        # or alternatively:
+        out1, out2, out3, _ = request.infer(inputs).to_tuple()
+    """
+    def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None:
+        self._dict = _dict
+
+    def __iter__(self) -> Iterator:
+        return self._dict.__iter__()
+
+    def __len__(self) -> int:
+        return len(self._dict)
+
+    def __repr__(self) -> str:
+        return self._dict.__repr__()
+
+    def __get_key(self, index: int) -> ConstOutput:
+        return list(self._dict.keys())[index]
+
+    @singledispatchmethod
+    def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
+        raise TypeError("Unknown key type!")
+
+    @__getitem_impl.register
+    def _(self, key: ConstOutput) -> np.ndarray:
+        return self._dict[key]
+
+    @__getitem_impl.register
+    def _(self, key: int) -> np.ndarray:
+        try:
+            return self._dict[self.__get_key(key)]
+        except IndexError:
+            raise KeyError(key)
+
+    @__getitem_impl.register
+    def _(self, key: str) -> np.ndarray:
+        try:
+            return self._dict[self.__get_key(self.names().index(key))]
+        except ValueError:
+            raise KeyError(key)
+
+    def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
+        return self.__getitem_impl(key)
+
+    def keys(self) -> KeysView[ConstOutput]:
+        return self._dict.keys()
+
+    def values(self) -> ValuesView[np.ndarray]:
+        return self._dict.values()
+
+    def items(self) -> ItemsView[ConstOutput, np.ndarray]:
+        return self._dict.items()
+
+    def names(self) -> List[str]:
+        """Return the name of every output key.
+
+        Raises RuntimeError if any of the ConstOutput keys has no name.
+        """
+        return [key.get_any_name() for key in self._dict.keys()]
+
+    def to_dict(self) -> Dict[ConstOutput, np.ndarray]:
+        """Return the underlying native dictionary.
+
+        Note: no copy is performed; modifications to the returned
+        dictionary will affect this object as well.
+        """
+        return self._dict
+
+    def to_tuple(self) -> tuple:
+        """Convert values of this dictionary to a tuple."""
+        return tuple(self._dict.values())
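
The `__getitem_impl` dispatch above relies on `functools.singledispatchmethod` (with the PyPI `singledispatchmethod` backport on Python < 3.8), which picks an implementation from the runtime type of the key argument. A standalone sketch of the same pattern with a toy class (not part of this commit):

```python
from functools import singledispatchmethod  # stdlib on Python >= 3.8


class Registry:
    """Toy read-only lookup that mirrors OVDict's key dispatch."""

    def __init__(self, items):
        self._items = list(items)  # list of (name, value) pairs

    @singledispatchmethod
    def _get_impl(self, key):
        # Fallback for unregistered key types, like OVDict's TypeError.
        raise TypeError("Unknown key type!")

    @_get_impl.register
    def _(self, key: int):
        return self._items[key]

    @_get_impl.register
    def _(self, key: str):
        # Linear scan by name, analogous to OVDict's names().index(key).
        for name, value in self._items:
            if name == key:
                return value
        raise KeyError(key)

    def __getitem__(self, key):
        return self._get_impl(key)


reg = Registry([("a", 1), ("b", 2)])
assert reg[0] == ("a", 1)   # int key -> positional access
assert reg["b"] == 2        # str key -> access by name
```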

View File

@@ -53,6 +53,27 @@ const std::map<std::string, ov::element::Type>& dtype_to_ov_type() {
    return dtype_to_ov_type_mapping;
}

+namespace containers {
+const TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) {
+    TensorIndexMap result_map;
+    for (auto&& input : inputs) {
+        int idx;
+        if (py::isinstance<py::int_>(input.first)) {
+            idx = input.first.cast<int>();
+        } else {
+            throw py::type_error("incompatible function arguments!");
+        }
+        if (py::isinstance<ov::Tensor>(input.second)) {
+            auto tensor = Common::cast_to_tensor(input.second);
+            result_map[idx] = tensor;
+        } else {
+            throw ov::Exception("Unable to cast tensor " + std::to_string(idx) + "!");
+        }
+    }
+    return result_map;
+}
+};  // namespace containers
+
namespace array_helpers {

bool is_contiguous(const py::array& array) {
@@ -110,6 +131,67 @@ py::array as_contiguous(py::array& array, ov::element::Type type) {
    }
}

+py::array array_from_tensor(ov::Tensor&& t) {
+    switch (t.get_element_type()) {
+    case ov::element::Type_t::f32: {
+        return py::array_t<float>(t.get_shape(), t.data<float>());
+        break;
+    }
+    case ov::element::Type_t::f64: {
+        return py::array_t<double>(t.get_shape(), t.data<double>());
+        break;
+    }
+    case ov::element::Type_t::bf16: {
+        return py::array(py::dtype("float16"), t.get_shape(), t.data<ov::bfloat16>());
+        break;
+    }
+    case ov::element::Type_t::f16: {
+        return py::array(py::dtype("float16"), t.get_shape(), t.data<ov::float16>());
+        break;
+    }
+    case ov::element::Type_t::i8: {
+        return py::array_t<int8_t>(t.get_shape(), t.data<int8_t>());
+        break;
+    }
+    case ov::element::Type_t::i16: {
+        return py::array_t<int16_t>(t.get_shape(), t.data<int16_t>());
+        break;
+    }
+    case ov::element::Type_t::i32: {
+        return py::array_t<int32_t>(t.get_shape(), t.data<int32_t>());
+        break;
+    }
+    case ov::element::Type_t::i64: {
+        return py::array_t<int64_t>(t.get_shape(), t.data<int64_t>());
+        break;
+    }
+    case ov::element::Type_t::u8: {
+        return py::array_t<uint8_t>(t.get_shape(), t.data<uint8_t>());
+        break;
+    }
+    case ov::element::Type_t::u16: {
+        return py::array_t<uint16_t>(t.get_shape(), t.data<uint16_t>());
+        break;
+    }
+    case ov::element::Type_t::u32: {
+        return py::array_t<uint32_t>(t.get_shape(), t.data<uint32_t>());
+        break;
+    }
+    case ov::element::Type_t::u64: {
+        return py::array_t<uint64_t>(t.get_shape(), t.data<uint64_t>());
+        break;
+    }
+    case ov::element::Type_t::boolean: {
+        return py::array_t<bool>(t.get_shape(), t.data<bool>());
+        break;
+    }
+    default: {
+        throw ov::Exception("Numpy array cannot be created from given OV Tensor!");
+        break;
+    }
+    }
+}
+
};  // namespace array_helpers

template <>
@@ -226,38 +308,6 @@ const ov::Tensor& cast_to_tensor(const py::handle& tensor) {
    return tensor.cast<const ov::Tensor&>();
}

-const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs) {
-    Containers::TensorNameMap result_map;
-    for (auto&& input : inputs) {
-        std::string name;
-        if (py::isinstance<py::str>(input.first)) {
-            name = input.first.cast<std::string>();
-        } else {
-            throw py::type_error("incompatible function arguments!");
-        }
-        OPENVINO_ASSERT(py::isinstance<ov::Tensor>(input.second), "Unable to cast tensor ", name, "!");
-        auto tensor = Common::cast_to_tensor(input.second);
-        result_map[name] = tensor;
-    }
-    return result_map;
-}
-
-const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) {
-    Containers::TensorIndexMap result_map;
-    for (auto&& input : inputs) {
-        int idx;
-        if (py::isinstance<py::int_>(input.first)) {
-            idx = input.first.cast<int>();
-        } else {
-            throw py::type_error("incompatible function arguments!");
-        }
-        OPENVINO_ASSERT(py::isinstance<ov::Tensor>(input.second), "Unable to cast tensor ", idx, "!");
-        auto tensor = Common::cast_to_tensor(input.second);
-        result_map[idx] = tensor;
-    }
-    return result_map;
-}
-
void set_request_tensors(ov::InferRequest& request, const py::dict& inputs) {
    if (!inputs.empty()) {
        for (auto&& input : inputs) {
@@ -293,67 +343,10 @@ uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual) {
    }
}

-py::dict outputs_to_dict(const std::vector<ov::Output<const ov::Node>>& outputs, ov::InferRequest& request) {
+py::dict outputs_to_dict(InferRequestWrapper& request) {
    py::dict res;
-    for (const auto& out : outputs) {
-        ov::Tensor t{request.get_tensor(out)};
-        switch (t.get_element_type()) {
-        case ov::element::Type_t::i8: {
-            res[py::cast(out)] = py::array_t<int8_t>(t.get_shape(), t.data<int8_t>());
-            break;
-        }
-        case ov::element::Type_t::i16: {
-            res[py::cast(out)] = py::array_t<int16_t>(t.get_shape(), t.data<int16_t>());
-            break;
-        }
-        case ov::element::Type_t::i32: {
-            res[py::cast(out)] = py::array_t<int32_t>(t.get_shape(), t.data<int32_t>());
-            break;
-        }
-        case ov::element::Type_t::i64: {
-            res[py::cast(out)] = py::array_t<int64_t>(t.get_shape(), t.data<int64_t>());
-            break;
-        }
-        case ov::element::Type_t::u8: {
-            res[py::cast(out)] = py::array_t<uint8_t>(t.get_shape(), t.data<uint8_t>());
-            break;
-        }
-        case ov::element::Type_t::u16: {
-            res[py::cast(out)] = py::array_t<uint16_t>(t.get_shape(), t.data<uint16_t>());
-            break;
-        }
-        case ov::element::Type_t::u32: {
-            res[py::cast(out)] = py::array_t<uint32_t>(t.get_shape(), t.data<uint32_t>());
-            break;
-        }
-        case ov::element::Type_t::u64: {
-            res[py::cast(out)] = py::array_t<uint64_t>(t.get_shape(), t.data<uint64_t>());
-            break;
-        }
-        case ov::element::Type_t::bf16: {
-            res[py::cast(out)] = py::array(py::dtype("float16"), t.get_shape(), t.data<ov::bfloat16>());
-            break;
-        }
-        case ov::element::Type_t::f16: {
-            res[py::cast(out)] = py::array(py::dtype("float16"), t.get_shape(), t.data<ov::float16>());
-            break;
-        }
-        case ov::element::Type_t::f32: {
-            res[py::cast(out)] = py::array_t<float>(t.get_shape(), t.data<float>());
-            break;
-        }
-        case ov::element::Type_t::f64: {
-            res[py::cast(out)] = py::array_t<double>(t.get_shape(), t.data<double>());
-            break;
-        }
-        case ov::element::Type_t::boolean: {
-            res[py::cast(out)] = py::array_t<bool>(t.get_shape(), t.data<bool>());
-            break;
-        }
-        default: {
-            break;
-        }
-        }
+    for (const auto& out : request.m_outputs) {
+        res[py::cast(out)] = array_helpers::array_from_tensor(request.m_request.get_tensor(out));
    }
    return res;
}
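
Of the two dict-casting helpers, only the int-keyed `cast_to_tensor_index_map` survives, moved into `Common::containers`; the str-keyed variant is dropped along with the opaque pybind11 container bindings. `array_from_tensor` factors out the per-element-type tensor-to-numpy conversion that the old `outputs_to_dict` switch inlined. From Python this stays behavior-compatible: `set_input_tensors`/`set_output_tensors` still accept int-keyed dicts (see the infer_request.cpp hunks below), and each output array keeps the tensor's dtype. A sketch, assuming a "CPU" device is available (the toy model is illustrative, not from the diff):

```python
import numpy as np
import openvino.runtime.opset10 as ops
from openvino.runtime import Core, Model, Tensor

# Placeholder two-input model for this sketch.
a = ops.parameter([2, 2], np.float32, name="a")
b = ops.parameter([2, 2], np.float32, name="b")
model = Model(ops.add(a, b), [a, b])

request = Core().compile_model(model, "CPU").create_infer_request()

# int-keyed dict -> Common::containers::cast_to_tensor_index_map on the C++ side
request.set_input_tensors({
    0: Tensor(np.ones([2, 2], dtype=np.float32)),
    1: Tensor(np.full([2, 2], 2.0, dtype=np.float32)),
})
result = request.infer()

out = result[0]
assert out.dtype == np.float32  # dtype preserved by array_from_tensor
assert np.all(out == 3.0)
```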

View File

@@ -20,14 +20,20 @@
#include "openvino/runtime/infer_request.hpp"
#include "openvino/runtime/tensor.hpp"
#include "openvino/pass/serialize.hpp"
-#include "pyopenvino/core/containers.hpp"
#include "pyopenvino/graph/any.hpp"
#include "pyopenvino/graph/ops/constant.hpp"
+#include "pyopenvino/core/infer_request.hpp"

namespace py = pybind11;

namespace Common {

+namespace containers {
+using TensorIndexMap = std::map<size_t, ov::Tensor>;
+
+const TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs);
+};  // namespace containers
+
namespace values {

// Minimum amount of bits for common numpy types. Used to perform checks against OV types.
@@ -52,6 +58,8 @@ std::vector<size_t> get_strides(const py::array& array);

py::array as_contiguous(py::array& array, ov::element::Type type);

+py::array array_from_tensor(ov::Tensor&& t);
+
};  // namespace array_helpers

template <typename T>
@@ -80,15 +88,11 @@ ov::PartialShape partial_shape_from_list(const py::list& shape);

const ov::Tensor& cast_to_tensor(const py::handle& tensor);

-const Containers::TensorNameMap cast_to_tensor_name_map(const py::dict& inputs);
-
-const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs);
-
void set_request_tensors(ov::InferRequest& request, const py::dict& inputs);

uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual);

-py::dict outputs_to_dict(const std::vector<ov::Output<const ov::Node>>& outputs, ov::InferRequest& request);
+py::dict outputs_to_dict(InferRequestWrapper& request);

ov::pass::Serialize::Version convert_to_version(const std::string& version);

View File

@@ -9,13 +9,9 @@
#include "common.hpp"
#include "pyopenvino/core/compiled_model.hpp"
-#include "pyopenvino/core/containers.hpp"
#include "pyopenvino/core/infer_request.hpp"
#include "pyopenvino/utils/utils.hpp"

-PYBIND11_MAKE_OPAQUE(Containers::TensorIndexMap);
-PYBIND11_MAKE_OPAQUE(Containers::TensorNameMap);
-
namespace py = pybind11;

void regclass_CompiledModel(py::module m) {

View File

@@ -1,23 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "pyopenvino/core/containers.hpp"
-
-#include <pybind11/stl_bind.h>
-
-PYBIND11_MAKE_OPAQUE(Containers::TensorIndexMap);
-PYBIND11_MAKE_OPAQUE(Containers::TensorNameMap);
-
-namespace py = pybind11;
-
-namespace Containers {
-
-void regclass_TensorIndexMap(py::module m) {
-    py::bind_map<TensorIndexMap>(m, "TensorIndexMap");
-}
-
-void regclass_TensorNameMap(py::module m) {
-    py::bind_map<TensorNameMap>(m, "TensorNameMap");
-}
-}  // namespace Containers

View File

@@ -1,23 +0,0 @@
-// Copyright (C) 2018-2023 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <string>
-#include <map>
-#include <vector>
-
-#include <pybind11/pybind11.h>
-
-#include <openvino/runtime/tensor.hpp>
-
-namespace py = pybind11;
-
-namespace Containers {
-    using TensorIndexMap = std::map<size_t, ov::Tensor>;
-    using TensorNameMap = std::map<std::string, ov::Tensor>;
-
-    void regclass_TensorIndexMap(py::module m);
-    void regclass_TensorNameMap(py::module m);
-}

View File

@@ -11,12 +11,8 @@
#include <string>

#include "pyopenvino/core/common.hpp"
-#include "pyopenvino/core/containers.hpp"
#include "pyopenvino/utils/utils.hpp"

-PYBIND11_MAKE_OPAQUE(Containers::TensorIndexMap);
-PYBIND11_MAKE_OPAQUE(Containers::TensorNameMap);
-
namespace py = pybind11;

inline py::dict run_sync_infer(InferRequestWrapper& self) {
@@ -26,7 +22,7 @@ inline py::dict run_sync_infer(InferRequestWrapper& self) {
        self.m_request.infer();
        *self.m_end_time = Time::now();
    }
-    return Common::outputs_to_dict(self.m_outputs, self.m_request);
+    return Common::outputs_to_dict(self);
}

void regclass_InferRequest(py::module m) {
@@ -103,7 +99,7 @@ void regclass_InferRequest(py::module m) {
    cls.def(
        "set_output_tensors",
        [](InferRequestWrapper& self, const py::dict& outputs) {
-            auto outputs_map = Common::cast_to_tensor_index_map(outputs);
+            auto outputs_map = Common::containers::cast_to_tensor_index_map(outputs);
            for (auto&& output : outputs_map) {
                self.m_request.set_output_tensor(output.first, output.second);
            }
@@ -120,7 +116,7 @@ void regclass_InferRequest(py::module m) {
    cls.def(
        "set_input_tensors",
        [](InferRequestWrapper& self, const py::dict& inputs) {
-            auto inputs_map = Common::cast_to_tensor_index_map(inputs);
+            auto inputs_map = Common::containers::cast_to_tensor_index_map(inputs);
            for (auto&& input : inputs_map) {
                self.m_request.set_input_tensor(input.first, input.second);
            }
@@ -719,7 +715,7 @@ void regclass_InferRequest(py::module m) {
    cls.def_property_readonly(
        "results",
        [](InferRequestWrapper& self) {
-            return Common::outputs_to_dict(self.m_outputs, self.m_request);
+            return Common::outputs_to_dict(self);
        },
        R"(
            Gets all output tensors of this InferRequest.
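
Both call sites that previously passed `(self.m_outputs, self.m_request)` now pass the wrapper itself, so `run_sync_infer` and the `results` binding share one conversion path, and the Python-side `results` property added in ie_api.py wraps that dict into an `OVDict`. A sketch of reading results after an asynchronous run (same toy model as the earlier sketches; the device string is a placeholder):

```python
import numpy as np
import openvino.runtime.opset10 as ops
from openvino.runtime import Core, Model

param = ops.parameter([1, 20], np.float32, name="data_0")
model = Model(ops.abs(param), [param])
model.output(0).tensor.names = {"output_0"}

request = Core().compile_model(model, "CPU").create_infer_request()
request.start_async(np.random.random([1, 20]).astype(np.float32))
request.wait()

# `results` now returns an OVDict, so all three key types work here too:
out = request.results["output_0"]
assert np.array_equal(out, request.results[0])
assert np.array_equal(out, request.results[request.model_outputs[0]])
```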

View File

@@ -24,7 +24,6 @@
#endif

#include "pyopenvino/core/async_infer_queue.hpp"
#include "pyopenvino/core/compiled_model.hpp"
-#include "pyopenvino/core/containers.hpp"
#include "pyopenvino/core/core.hpp"
#include "pyopenvino/core/extension.hpp"
#include "pyopenvino/core/infer_request.hpp"
@@ -210,9 +209,6 @@ PYBIND11_MODULE(_pyopenvino, m) {
    regclass_Core(m);
    regclass_Tensor(m);

-    // Registering specific types of containers
-    Containers::regclass_TensorIndexMap(m);
-    Containers::regclass_TensorNameMap(m);
-
    regclass_CompiledModel(m);
    regclass_InferRequest(m);

View File

@@ -0,0 +1,249 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Mapping
+
+import numpy as np
+import pytest
+
+import openvino.runtime.opset10 as ops
+from openvino.runtime import Core, ConstOutput, CompiledModel, InferRequest, Model
+from openvino.runtime.ie_api import OVDict
+
+
+def _get_ovdict(
+    device,
+    input_shape=None,
+    data_type=np.float32,
+    input_names=None,
+    output_names=None,
+    multi_output=False,
+    direct_infer=False,
+    split_num=5,
+):
+    # Create model
+    # If model is multi-output (multi_output=True), input_shape must match
+    # requirements of the split operation.
+    # TODO OpenSource: refactor it to be more generic
+    if input_shape is None:
+        input_shape = [1, 20]
+    if input_names is None:
+        input_names = ["data_0"]
+    if output_names is None:
+        output_names = ["output_0"]
+    if multi_output:
+        assert isinstance(output_names, (list, tuple))
+        assert len(output_names) > 1
+        assert len(output_names) == split_num
+    param = ops.parameter(input_shape, data_type, name=input_names[0])
+    model = Model(
+        ops.split(param, 1, split_num) if multi_output else ops.abs(param), [param],
+    )
+    # Manually name outputs
+    for i in range(len(output_names)):
+        model.output(i).tensor.names = {output_names[i]}
+    # Compile model
+    core = Core()
+    compiled_model = core.compile_model(model, device)
+    # Create test data
+    input_data = np.random.random(input_shape).astype(data_type)
+    # Two ways of inferring
+    if direct_infer:
+        result = compiled_model(input_data)
+        assert result is not None
+        return result, compiled_model
+
+    request = compiled_model.create_infer_request()
+    result = request.infer(input_data)
+    assert result is not None
+    return result, request
+
+
+def _check_keys(keys, outs):
+    outs_iter = iter(outs)
+    for key in keys:
+        assert isinstance(key, ConstOutput)
+        assert key == next(outs_iter)
+    return True
+
+
+def _check_values(result):
+    for value in result.values():
+        assert isinstance(value, np.ndarray)
+    return True
+
+
+def _check_items(result, outs, output_names):
+    i = 0
+    for key, value in result.items():
+        assert isinstance(key, ConstOutput)
+        assert isinstance(value, np.ndarray)
+        # Check values
+        assert np.equal(result[outs[i]], result[key]).all()
+        assert np.equal(result[outs[i]], result[i]).all()
+        assert np.equal(result[outs[i]], result[output_names[i]]).all()
+        i += 1
+    return True
+
+
+def _check_dict(result, obj, output_names=None):
+    if output_names is None:
+        output_names = ["output_0"]
+
+    outs = obj.model_outputs if isinstance(obj, InferRequest) else obj.outputs
+    assert len(outs) == len(result)
+    assert len(outs) == len(output_names)
+    # Check for __iter__
+    assert _check_keys(result, outs)
+    # Check for keys function
+    assert _check_keys(result.keys(), outs)
+    assert _check_values(result)
+    assert _check_items(result, outs, output_names)
+    assert result.names() == output_names
+    return True
+
+
+@pytest.mark.parametrize("is_direct", [True, False])
+def test_ovdict_assign(device, is_direct):
+    result, _ = _get_ovdict(device, multi_output=False, direct_infer=is_direct)
+
+    with pytest.raises(TypeError) as e:
+        result["some_name"] = 99
+    assert "'OVDict' object does not support item assignment" in str(e.value)
+
+
+@pytest.mark.parametrize("is_direct", [True, False])
+def test_ovdict_single_output_basic(device, is_direct):
+    result, obj = _get_ovdict(device, multi_output=False, direct_infer=is_direct)
+
+    assert isinstance(result, OVDict)
+    if isinstance(obj, (InferRequest, CompiledModel)):
+        assert _check_dict(result, obj)
+    else:
+        raise TypeError("Unknown `obj` type!")
+
+
+@pytest.mark.parametrize("is_direct", [True, False])
+def test_ovdict_single_output_noname(device, is_direct):
+    result, obj = _get_ovdict(
+        device,
+        multi_output=False,
+        direct_infer=is_direct,
+        output_names=[],
+    )
+
+    assert isinstance(result, OVDict)
+
+    outs = obj.model_outputs if isinstance(obj, InferRequest) else obj.outputs
+
+    assert isinstance(result[outs[0]], np.ndarray)
+    assert isinstance(result[0], np.ndarray)
+
+    with pytest.raises(RuntimeError) as e0:
+        _ = result["some_name"]
+    assert "Attempt to get a name for a Tensor without names" in str(e0.value)
+
+    with pytest.raises(RuntimeError) as e1:
+        _ = result.names()
+    assert "Attempt to get a name for a Tensor without names" in str(e1.value)
+
+
+@pytest.mark.parametrize("is_direct", [True, False])
+def test_ovdict_single_output_wrongname(device, is_direct):
+    result, obj = _get_ovdict(
+        device,
+        multi_output=False,
+        direct_infer=is_direct,
+        output_names=["output_21"],
+    )
+
+    assert isinstance(result, OVDict)
+
+    outs = obj.model_outputs if isinstance(obj, InferRequest) else obj.outputs
+
+    assert isinstance(result[outs[0]], np.ndarray)
+    assert isinstance(result[0], np.ndarray)
+
+    with pytest.raises(KeyError) as e:
+        _ = result["output_37"]
+    assert "output_37" in str(e.value)
+
+    with pytest.raises(KeyError) as e:
+        _ = result[6]
+    assert "6" in str(e.value)
+
+
+@pytest.mark.parametrize("is_direct", [True, False])
+@pytest.mark.parametrize("use_function", [True, False])
+def test_ovdict_single_output_dict(device, is_direct, use_function):
+    result, obj = _get_ovdict(
+        device,
+        multi_output=False,
+        direct_infer=is_direct,
+    )
+
+    assert isinstance(result, OVDict)
+
+    outs = obj.model_outputs if isinstance(obj, InferRequest) else obj.outputs
+    native_dict = result.to_dict() if use_function else dict(result)
+
+    assert issubclass(type(native_dict), dict)
+    assert not isinstance(native_dict, OVDict)
+    assert isinstance(native_dict[outs[0]], np.ndarray)
+
+    with pytest.raises(KeyError) as e:
+        _ = native_dict["output_0"]
+    assert "output_0" in str(e.value)
+
+    with pytest.raises(KeyError) as e:
+        _ = native_dict[0]
+    assert "0" in str(e.value)
+
+
+@pytest.mark.parametrize("is_direct", [True, False])
+def test_ovdict_multi_output_basic(device, is_direct):
+    output_names = ["output_0", "output_1", "output_2", "output_3", "output_4"]
+    result, obj = _get_ovdict(
+        device,
+        multi_output=True,
+        direct_infer=is_direct,
+        output_names=output_names,
+    )
+
+    assert isinstance(result, OVDict)
+    if isinstance(obj, (InferRequest, CompiledModel)):
+        assert _check_dict(result, obj, output_names)
+    else:
+        raise TypeError("Unknown `obj` type!")
+
+
+@pytest.mark.parametrize("is_direct", [True, False])
+@pytest.mark.parametrize("use_function", [True, False])
+def test_ovdict_multi_output_tuple0(device, is_direct, use_function):
+    output_names = ["output_0", "output_1"]
+    result, obj = _get_ovdict(
+        device,
+        input_shape=(1, 10),
+        multi_output=True,
+        direct_infer=is_direct,
+        split_num=2,
+        output_names=output_names,
+    )
+
+    out0, out1 = None, None
+    if use_function:
+        assert isinstance(result.to_tuple(), tuple)
+        out0, out1 = result.to_tuple()
+    else:
+        out0, out1 = result.values()
+
+    assert out0 is not None
+    assert out1 is not None
+
+    assert isinstance(out0, np.ndarray)
+    assert isinstance(out1, np.ndarray)
+
+    outs = obj.model_outputs if isinstance(obj, InferRequest) else obj.outputs
+
+    assert np.equal(result[outs[0]], out0).all()
+    assert np.equal(result[outs[1]], out1).all()