[PyOV] String type support (#21532)

* WIP working Tensor with automatic casting

* Update infer functions and minor fixes

* Impl for multi-dim arrays with fixes for strides

* Fix strides for S kind and refactor tests

* Added str_data and bytes_data to Tensor, cleaning up the solution

* Add test of warning while using data property

* Allow lists as inputs for single Tensors, refactor infer code to return OVDict with decoded strings, some refactoring, tests

* Replace string with another invalid input

* Added bytes_str and string_str properties, clean up common part of bindings, added test cases

* Improve string element type to be created from numpy/python counterparts, refactor of common code, small improvements

* Add tests for types

* Remove print

* Small fix for tensors from pointers

* Small fix for asserts

* Add tests for data dispatcher

* Fix comments

* fix tests

* Fix edge-case for scalar-like values and unlock tests for data dispatcher
This commit is contained in:
Jan Iwaszkiewicz 2023-12-15 21:14:55 +01:00 committed by GitHub
parent 7a311635bd
commit eff9ba76ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 832 additions and 57 deletions

View File

@ -262,7 +262,7 @@ class CompiledModel(CompiledModelBase):
"""
return InferRequest(super().create_infer_request())
def infer_new_request(self, inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None) -> OVDict:
def infer_new_request(self, inputs: Any = None) -> OVDict:
"""Infers specified input(s) in synchronous mode.
Blocks all methods of CompiledModel while request is running.
@ -287,7 +287,7 @@ class CompiledModel(CompiledModelBase):
function throws error.
:param inputs: Data to be set on input tensors.
:type inputs: Union[Dict[keys, values], List[values], Tuple[values], Tensor, numpy.ndarray], optional
:type inputs: Any, optional
:return: Dictionary of results from output tensors with port/int/str keys.
:rtype: OVDict
"""
@ -297,7 +297,7 @@ class CompiledModel(CompiledModelBase):
def __call__(
self,
inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None,
inputs: Any = None,
share_inputs: bool = True,
share_outputs: bool = False,
*,
@ -332,7 +332,7 @@ class CompiledModel(CompiledModelBase):
function throws error.
:param inputs: Data to be set on input tensors.
:type inputs: Union[Dict[keys, values], List[values], Tuple[values], Tensor, numpy.ndarray], optional
:type inputs: Any, optional
:param share_inputs: Enables `share_inputs` mode. Controls memory usage on inference's inputs.
If set to `False` inputs the data dispatcher will safely copy data

View File

@ -10,7 +10,7 @@ import numpy as np
from openvino._pyopenvino import ConstOutput, Tensor, Type
from openvino.runtime.utils.data_helpers.wrappers import _InferRequestWrapper, OVDict
ContainerTypes = Union[dict, list, tuple]
ContainerTypes = Union[dict, list, tuple, OVDict]
ScalarTypes = Union[np.number, int, float]
ValidKeys = Union[str, int, ConstOutput]
@ -31,7 +31,7 @@ def get_request_tensor(
@singledispatch
def value_to_tensor(
value: Union[Tensor, np.ndarray, ScalarTypes],
value: Union[Tensor, np.ndarray, ScalarTypes, str],
request: Optional[_InferRequestWrapper] = None,
is_shared: bool = False,
key: Optional[ValidKeys] = None,
@ -59,6 +59,11 @@ def _(
tensor = get_request_tensor(request, key)
tensor_type = tensor.get_element_type()
tensor_dtype = tensor_type.to_dtype()
# String edge-case, always copy.
# Scalars are also handled by C++.
if tensor_type == Type.string:
return Tensor(value, shared_memory=False)
# Scalars edge-case:
if value.ndim == 0:
tensor_shape = tuple(tensor.shape)
if tensor_dtype == value.dtype and tensor_shape == value.shape:
@ -82,21 +87,34 @@ def _(
return Tensor(value, shared_memory=is_shared)
@value_to_tensor.register(np.number)
@value_to_tensor.register(int)
@value_to_tensor.register(float)
@value_to_tensor.register(list)
def _(
value: ScalarTypes,
value: list,
request: _InferRequestWrapper,
is_shared: bool = False,
key: Optional[ValidKeys] = None,
) -> Tensor:
# np.number/int/float edge-case, copy will occur in both scenarios.
return Tensor(value)
@value_to_tensor.register(np.number)
@value_to_tensor.register(int)
@value_to_tensor.register(float)
@value_to_tensor.register(str)
@value_to_tensor.register(bytes)
def _(
value: Union[ScalarTypes, str, bytes],
request: _InferRequestWrapper,
is_shared: bool = False,
key: Optional[ValidKeys] = None,
) -> Tensor:
# np.number/int/float/str/bytes edge-case, copy will occur in both scenarios.
tensor_type = get_request_tensor(request, key).get_element_type()
tensor_dtype = tensor_type.to_dtype()
tmp = np.array(value)
# String edge-case -- it converts the data inside of Tensor class.
# If types are mismatched, convert.
if tensor_dtype != tmp.dtype:
if tensor_type != Type.string and tensor_dtype != tmp.dtype:
return Tensor(tmp.astype(tensor_dtype), shared_memory=False)
return Tensor(tmp, shared_memory=False)
@ -204,8 +222,10 @@ def _(
@create_shared.register(np.number)
@create_shared.register(int)
@create_shared.register(float)
@create_shared.register(str)
@create_shared.register(bytes)
def _(
inputs: Union[Tensor, ScalarTypes],
inputs: Union[Tensor, ScalarTypes, str, bytes],
request: _InferRequestWrapper,
) -> Tensor:
return value_to_tensor(inputs, request=request, is_shared=True)
@ -256,7 +276,10 @@ def _(
if tuple(tensor.shape) != inputs.shape:
tensor.shape = inputs.shape
# When copying, type should be up/down-casted automatically.
tensor.data[:] = inputs[:]
if tensor.element_type == Type.string:
tensor.bytes_data = inputs
else:
tensor.data[:] = inputs[:]
else:
# If shape is "empty", assume this is a scalar value
set_request_tensor(
@ -269,8 +292,9 @@ def _(
@update_tensor.register(np.number) # type: ignore
@update_tensor.register(float)
@update_tensor.register(int)
@update_tensor.register(str)
def _(
inputs: Union[np.number, float, int],
inputs: Union[ScalarTypes, str],
request: _InferRequestWrapper,
key: Optional[ValidKeys] = None,
) -> None:
@ -286,6 +310,7 @@ def update_inputs(inputs: dict, request: _InferRequestWrapper) -> dict:
It creates copy of Tensors or copy data to already allocated Tensors on device
if the item is of type `np.ndarray`, `np.number`, `int`, `float` or has numpy __array__ attribute.
If value is of type `list`, create a Tensor based on it, copy will occur in the Tensor constructor.
"""
# Create new temporary dictionary.
# new_inputs will be used to transfer data to inference calls,
@ -296,8 +321,10 @@ def update_inputs(inputs: dict, request: _InferRequestWrapper) -> dict:
raise TypeError(f"Incompatible key type for input: {key}")
# Copy numpy arrays to already allocated Tensors.
# If value object has __array__ attribute, load it to Tensor using np.array
if isinstance(value, (np.ndarray, np.number, int, float)) or hasattr(value, "__array__"):
if isinstance(value, (np.ndarray, np.number, int, float, str)) or hasattr(value, "__array__"):
update_tensor(value, request, key)
elif isinstance(value, list):
new_inputs[key] = Tensor(value)
# If value is of Tensor type, put it into temporary dictionary.
elif isinstance(value, Tensor):
new_inputs[key] = value
@ -309,7 +336,7 @@ def update_inputs(inputs: dict, request: _InferRequestWrapper) -> dict:
@singledispatch
def create_copied(
inputs: Union[ContainerTypes, OVDict, np.ndarray, ScalarTypes],
inputs: Union[ContainerTypes, np.ndarray, ScalarTypes, str, bytes],
request: _InferRequestWrapper,
) -> Union[dict, None]:
# Check the special case of the array-interface
@ -325,7 +352,7 @@ def create_copied(
@create_copied.register(tuple)
@create_copied.register(OVDict)
def _(
inputs: Union[ContainerTypes, OVDict],
inputs: ContainerTypes,
request: _InferRequestWrapper,
) -> dict:
return update_inputs(normalize_arrays(inputs, is_shared=False), request)
@ -344,8 +371,10 @@ def _(
@create_copied.register(np.number)
@create_copied.register(int)
@create_copied.register(float)
@create_copied.register(str)
@create_copied.register(bytes)
def _(
inputs: Union[Tensor, ScalarTypes],
inputs: Union[Tensor, ScalarTypes, str, bytes],
request: _InferRequestWrapper,
) -> Tensor:
return value_to_tensor(inputs, request=request, is_shared=False)
@ -356,7 +385,7 @@ def _(
def _data_dispatch(
request: _InferRequestWrapper,
inputs: Union[ContainerTypes, OVDict, Tensor, np.ndarray, ScalarTypes] = None,
inputs: Union[ContainerTypes, Tensor, np.ndarray, ScalarTypes, str] = None,
is_shared: bool = False,
) -> Union[dict, Tensor]:
if inputs is None:

View File

@ -36,6 +36,10 @@ openvino_to_numpy_types_map = [
(Type.u32, np.uint32),
(Type.u64, np.uint64),
(Type.bf16, np.uint16),
(Type.string, str),
(Type.string, np.str_),
(Type.string, bytes),
(Type.string, np.bytes_),
]
openvino_to_numpy_types_str_map = [
@ -52,6 +56,10 @@ openvino_to_numpy_types_str_map = [
("u16", np.uint16),
("u32", np.uint32),
("u64", np.uint64),
("string", str),
("string", np.str_),
("string", bytes),
("string", np.bytes_),
]

View File

@ -14,6 +14,9 @@
#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
namespace Common {
namespace type_helpers {
const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
static const std::map<ov::element::Type, py::dtype> ov_type_to_dtype_mapping = {
{ov::element::f16, py::dtype("float16")},
@ -33,10 +36,15 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
{ov::element::u4, py::dtype("uint8")},
{ov::element::nf4, py::dtype("uint8")},
{ov::element::i4, py::dtype("int8")},
{ov::element::string, py::dtype("bytes_")},
};
return ov_type_to_dtype_mapping;
}
py::dtype get_dtype(const ov::element::Type& ov_type) {
return ov_type_to_dtype().at(ov_type);
}
const std::map<std::string, ov::element::Type>& dtype_to_ov_type() {
static const std::map<std::string, ov::element::Type> dtype_to_ov_type_mapping = {
{"float16", ov::element::f16},
@ -51,10 +59,36 @@ const std::map<std::string, ov::element::Type>& dtype_to_ov_type() {
{"uint32", ov::element::u32},
{"uint64", ov::element::u64},
{"bool", ov::element::boolean},
{"bytes_", ov::element::string},
{"str_", ov::element::string},
{"bytes", ov::element::string},
{"str", ov::element::string},
};
return dtype_to_ov_type_mapping;
}
ov::element::Type get_ov_type(const py::array& array) {
// More about character codes:
// https://numpy.org/doc/stable/reference/arrays.scalars.html
char ctype = array.dtype().kind();
if (ctype == 'U' || ctype == 'S') {
return ov::element::string;
}
return dtype_to_ov_type().at(py::str(array.dtype()));
}
ov::element::Type get_ov_type(py::dtype& dtype) {
// More about character codes:
// https://numpy.org/doc/stable/reference/arrays.scalars.html
char ctype = dtype.kind();
if (ctype == 'U' || ctype == 'S') {
return ov::element::string;
}
return dtype_to_ov_type().at(py::str(dtype));
}
}; // namespace type_helpers
namespace containers {
const TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) {
TensorIndexMap result_map;
@ -76,16 +110,110 @@ const TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs) {
}
}; // namespace containers
namespace string_helpers {
py::array bytes_array_from_tensor(ov::Tensor&& t) {
if (t.get_element_type() != ov::element::string) {
OPENVINO_THROW("Tensor's type must be a string!");
}
auto data = t.data<std::string>();
auto max_element = std::max_element(data, data + t.get_size(), [](const std::string& x, const std::string& y) {
return x.length() < y.length();
});
auto max_stride = max_element->length();
auto dtype = py::dtype("|S" + std::to_string(max_stride));
// Adjusting strides to follow the numpy convention:
py::array array;
auto new_strides = t.get_strides();
if (new_strides.size() == 0) {
array = py::array(dtype, t.get_shape(), {});
} else {
auto element_stride = new_strides[new_strides.size() - 1];
for (size_t i = 0; i < new_strides.size(); ++i) {
new_strides[i] = (new_strides[i] / element_stride) * max_stride;
}
array = py::array(dtype, t.get_shape(), new_strides);
}
// Create an empty array and populate it with utf-8 encoded strings:
auto ptr = array.data();
for (size_t i = 0; i < t.get_size(); ++i) {
auto start = &data[i][0];
auto length = data[i].length();
auto end = std::copy(start, start + length, (char*)ptr + i * max_stride);
std::fill_n(end, max_stride - length, 0);
}
return array;
}
py::array string_array_from_tensor(ov::Tensor&& t) {
if (t.get_element_type() != ov::element::string) {
OPENVINO_THROW("Tensor's type must be a string!");
}
// Approach that is compact and faster than np.char.decode(tensor.data):
auto data = t.data<std::string>();
py::list _list;
for (size_t i = 0; i < t.get_size(); ++i) {
PyObject* _unicode_obj = PyUnicode_DecodeUTF8(&data[i][0], data[i].length(), "strict");
_list.append(_unicode_obj);
Py_XDECREF(_unicode_obj);
}
// Adjusting shape to follow the numpy convention:
py::array array(_list);
array.resize(t.get_shape());
return array;
}
void fill_tensor_from_bytes(ov::Tensor& tensor, py::array& array) {
if (tensor.get_size() != static_cast<size_t>(array.size())) {
OPENVINO_THROW("Passed array must have the same size (number of elements) as the Tensor!");
}
py::buffer_info buf = array.request();
auto data = tensor.data<std::string>();
for (size_t i = 0; i < tensor.get_size(); ++i) {
const char* ptr = reinterpret_cast<const char*>(buf.ptr) + (i * buf.itemsize);
data[i] = std::string(ptr, buf.ndim == 0 ? buf.itemsize : buf.strides[0]);
}
}
void fill_tensor_from_strings(ov::Tensor& tensor, py::array& array) {
if (tensor.get_size() != static_cast<size_t>(array.size())) {
OPENVINO_THROW("Passed array must have the same size (number of elements) as the Tensor!");
}
py::buffer_info buf = array.request();
auto data = tensor.data<std::string>();
for (size_t i = 0; i < tensor.get_size(); ++i) {
char* ptr = reinterpret_cast<char*>(buf.ptr) + (i * buf.itemsize);
// TODO: check other unicode kinds? 2BYTE and 1BYTE?
PyObject* _unicode_obj =
PyUnicode_FromKindAndData(PyUnicode_4BYTE_KIND, reinterpret_cast<void*>(ptr), buf.itemsize / 4);
PyObject* _utf8_obj = PyUnicode_AsUTF8String(_unicode_obj);
const char* _tmp_str = PyBytes_AsString(_utf8_obj);
data[i] = std::string(_tmp_str);
Py_XDECREF(_unicode_obj);
Py_XDECREF(_utf8_obj);
}
}
void fill_string_tensor_data(ov::Tensor& tensor, py::array& array) {
// More about character codes:
// https://numpy.org/doc/stable/reference/arrays.scalars.html
if (array.dtype().kind() == 'S') {
fill_tensor_from_bytes(tensor, array);
} else if (array.dtype().kind() == 'U') {
fill_tensor_from_strings(tensor, array);
} else {
OPENVINO_THROW("Unknown string kind passed to fill the Tensor's data!");
}
}
}; // namespace string_helpers
namespace array_helpers {
bool is_contiguous(const py::array& array) {
return C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS);
}
ov::element::Type get_ov_type(const py::array& array) {
return Common::dtype_to_ov_type().at(py::str(array.dtype()));
}
std::vector<size_t> get_shape(const py::array& array) {
return std::vector<size_t>(array.shape(), array.shape() + array.ndim());
}
@ -134,9 +262,18 @@ py::array as_contiguous(py::array& array, ov::element::Type type) {
}
py::array array_from_tensor(ov::Tensor&& t, bool is_shared) {
// Special case for string data type.
if (t.get_element_type() == ov::element::string) {
PyErr_WarnEx(PyExc_RuntimeWarning,
"Data of string type will be copied! Please use dedicated properties "
"`str_data` and `bytes_data` to avoid confusion while accessing "
"Tensor's contents.",
1);
return string_helpers::bytes_array_from_tensor(std::move(t));
}
// Get actual dtype from OpenVINO type:
auto ov_type = t.get_element_type();
auto dtype = Common::ov_type_to_dtype().at(ov_type);
auto dtype = Common::type_helpers::get_dtype(ov_type);
// Return the array as a view:
if (is_shared) {
if (ov_type.bitwidth() < Common::values::min_bitwidth) {
@ -157,16 +294,16 @@ template <>
ov::op::v0::Constant create_copied(py::array& array) {
// Do not copy data from the array, only return empty tensor based on type.
if (array.size() == 0) {
return ov::op::v0::Constant(array_helpers::get_ov_type(array), array_helpers::get_shape(array));
return ov::op::v0::Constant(type_helpers::get_ov_type(array), array_helpers::get_shape(array));
}
// Convert to contiguous array if not already in C-style.
if (!array_helpers::is_contiguous(array)) {
array = array_helpers::as_contiguous(array, array_helpers::get_ov_type(array));
array = array_helpers::as_contiguous(array, type_helpers::get_ov_type(array));
}
// Create actual Constant and a constructor is copying data.
// If ndim is equal to 0, creates scalar Constant.
// If size is equal to 0, creates empty Constant.
return ov::op::v0::Constant(array_helpers::get_ov_type(array),
return ov::op::v0::Constant(type_helpers::get_ov_type(array),
array_helpers::get_shape(array),
array.ndim() == 0 ? array.data() : array.data(0));
}
@ -188,7 +325,7 @@ ov::op::v0::Constant create_shared(py::array& array) {
static_cast<char*>((array.ndim() == 0 || array.size() == 0) ? array.mutable_data() : array.mutable_data(0)),
array.ndim() == 0 ? array.itemsize() : array.nbytes(),
array);
return ov::op::v0::Constant(array_helpers::get_ov_type(array), array_helpers::get_shape(array), memory);
return ov::op::v0::Constant(type_helpers::get_ov_type(array), array_helpers::get_shape(array), memory);
}
// If passed array is not C-style, throw an error.
OPENVINO_THROW("SHARED MEMORY MODE FOR THIS CONSTANT IS NOT APPLICABLE! Passed numpy array must be C contiguous.");
@ -202,7 +339,7 @@ ov::op::v0::Constant create_shared(ov::Tensor& tensor) {
template <>
ov::Tensor create_copied(py::array& array) {
// Create actual Tensor.
auto tensor = ov::Tensor(array_helpers::get_ov_type(array), array_helpers::get_shape(array));
auto tensor = ov::Tensor(type_helpers::get_ov_type(array), array_helpers::get_shape(array));
// If size of an array is equal to 0, the array is empty.
// Alternative could be `array.nbytes()`.
// Do not copy data from it, only return empty tensor based on type.
@ -211,7 +348,13 @@ ov::Tensor create_copied(py::array& array) {
}
// Convert to contiguous array if not already in C-style.
if (!array_helpers::is_contiguous(array)) {
array = array_helpers::as_contiguous(array, array_helpers::get_ov_type(array));
array = array_helpers::as_contiguous(array, type_helpers::get_ov_type(array));
}
// Special case with string data type of kind "S"(bytes_) or "U"(str_).
// pass through check for other string types like bytes and str.
if (type_helpers::get_ov_type(array) == ov::element::string) {
string_helpers::fill_string_tensor_data(tensor, array);
return tensor;
}
// If ndim of py::array is 0, array is a numpy scalar. That results in size to be equal to 0.
std::memcpy(tensor.data(),
@ -222,12 +365,15 @@ ov::Tensor create_copied(py::array& array) {
template <>
ov::Tensor create_shared(py::array& array) {
if (type_helpers::get_ov_type(array) == ov::element::string) {
OPENVINO_THROW("SHARED MEMORY MODE FOR THIS TENSOR IS NOT APPLICABLE! String types can be only copied.");
}
// Check if passed array has C-style contiguous memory layout.
// If memory is going to be shared it needs to be contiguous before passing to the constructor.
if (array_helpers::is_contiguous(array)) {
// If ndim of py::array is 0, array is a numpy scalar.
// If size of an array is equal to 0, the array is empty.
return ov::Tensor(array_helpers::get_ov_type(array),
return ov::Tensor(type_helpers::get_ov_type(array),
array_helpers::get_shape(array),
(array.ndim() == 0 || array.size() == 0) ? array.mutable_data() : array.mutable_data(0));
}
@ -236,7 +382,11 @@ ov::Tensor create_shared(py::array& array) {
}
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const ov::element::Type& type) {
auto element_type = (type == ov::element::undefined) ? Common::dtype_to_ov_type().at(py::str(array.dtype())) : type;
if (type_helpers::get_ov_type(array) == ov::element::string) {
OPENVINO_THROW("SHARED MEMORY MODE FOR THIS TENSOR IS NOT APPLICABLE! String types can be only copied.");
}
auto element_type = (type == ov::element::undefined) ? Common::type_helpers::get_ov_type(array) : type;
if (array_helpers::is_contiguous(array)) {
return ov::Tensor(element_type, shape, const_cast<void*>(array.data(0)), {});
@ -245,7 +395,11 @@ ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const o
}
ov::Tensor tensor_from_pointer(py::array& array, const ov::Output<const ov::Node>& port) {
auto array_type = array_helpers::get_ov_type(array);
if (type_helpers::get_ov_type(array) == ov::element::string) {
OPENVINO_THROW("SHARED MEMORY MODE FOR THIS TENSOR IS NOT APPLICABLE! String types can be only copied.");
}
auto array_type = type_helpers::get_ov_type(array);
auto array_shape_size = ov::shape_size(array_helpers::get_shape(array));
auto port_element_type = port.get_element_type();
auto port_shape_size = ov::shape_size(port.get_partial_shape().is_dynamic() ? ov::Shape{0} : port.get_shape());
@ -348,7 +502,15 @@ uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual) {
py::dict outputs_to_dict(InferRequestWrapper& request, bool share_outputs) {
py::dict res;
for (const auto& out : request.m_outputs) {
res[py::cast(out)] = array_helpers::array_from_tensor(request.m_request.get_tensor(out), share_outputs);
auto t = request.m_request.get_tensor(out);
if (t.get_element_type() == ov::element::string) {
if (share_outputs) {
PyErr_WarnEx(PyExc_RuntimeWarning, "Result of a string type will be copied to OVDict!", 1);
}
res[py::cast(out)] = string_helpers::string_array_from_tensor(std::move(t));
} else {
res[py::cast(out)] = array_helpers::array_from_tensor(std::move(t), share_outputs);
}
}
return res;
}

View File

@ -40,17 +40,40 @@ constexpr size_t min_bitwidth = sizeof(int8_t) * CHAR_BIT;
}; // namespace values
// Helpers for dtypes and OpenVINO types
namespace type_helpers {
const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype();
py::dtype get_dtype(const ov::element::Type& ov_type);
const std::map<std::string, ov::element::Type>& dtype_to_ov_type();
ov::element::Type get_ov_type(const py::array& array);
ov::element::Type get_ov_type(py::dtype& dtype);
}
// Helpers for string types and numpy arrays of strings
namespace string_helpers {
py::array bytes_array_from_tensor(ov::Tensor&& t);
py::array string_array_from_tensor(ov::Tensor&& t);
void fill_tensor_from_bytes(ov::Tensor& tensor, py::array& array);
void fill_tensor_from_strings(ov::Tensor& tensor, py::array& array);
void fill_string_tensor_data(ov::Tensor& tensor, py::array& array);
}; // namespace string_helpers
// Helpers for numpy arrays
namespace array_helpers {
bool is_contiguous(const py::array& array);
ov::element::Type get_ov_type(const py::array& array);
std::vector<size_t> get_shape(const py::array& array);
std::vector<size_t> get_strides(const py::array& array);

View File

@ -15,11 +15,6 @@
namespace py = pybind11;
inline std::string to_string(py::handle handle) {
auto encodedString = PyUnicode_AsUTF8String(handle.ptr());
return PyBytes_AsString(encodedString);
}
void regclass_Core(py::module m) {
py::class_<ov::Core, std::shared_ptr<ov::Core>> cls(m, "Core");
cls.doc() =

View File

@ -24,7 +24,7 @@ void regclass_Tensor(py::module m) {
R"(
Tensor's special constructor.
:param array: Array to create tensor from.
:param array: Array to create the tensor from.
:type array: numpy.array
:param shared_memory: If `True`, this Tensor memory is being shared with a host,
that means the responsibility of keeping host memory is
@ -105,30 +105,48 @@ void regclass_Tensor(py::module m) {
t = ov.Tensor(arr, [100, 8], ov.Type.u1)
)");
// It may clash in future with overloads like <ov::Coordinate, ov::Coordinate>
cls.def(py::init([](py::list& list) {
auto array = py::array(list);
return Common::object_from_data<ov::Tensor>(array, false);
}),
py::arg("list"),
R"(
Tensor's special constructor.
Creates a Tensor from a given Python list.
Warning: It is always a copy of list's data!
:param array: List to create the tensor from.
:type array: List[int, float, str]
)");
cls.def(py::init<const ov::element::Type, const ov::Shape>(), py::arg("type"), py::arg("shape"));
cls.def(py::init<const ov::element::Type, const std::vector<size_t>>(), py::arg("type"), py::arg("shape"));
cls.def(py::init([](py::dtype& np_dtype, std::vector<size_t>& shape) {
return ov::Tensor(Common::dtype_to_ov_type().at(py::str(np_dtype)), shape);
return ov::Tensor(Common::type_helpers::get_ov_type(np_dtype), shape);
}),
py::arg("type"),
py::arg("shape"));
cls.def(py::init([](py::object& np_literal, std::vector<size_t>& shape) {
return ov::Tensor(Common::dtype_to_ov_type().at(py::str(py::dtype::from_args(np_literal))), shape);
auto dtype = py::dtype::from_args(np_literal);
return ov::Tensor(Common::type_helpers::get_ov_type(dtype), shape);
}),
py::arg("type"),
py::arg("shape"));
cls.def(py::init([](py::dtype& np_dtype, const ov::Shape& shape) {
return ov::Tensor(Common::dtype_to_ov_type().at(py::str(np_dtype)), shape);
return ov::Tensor(Common::type_helpers::get_ov_type(np_dtype), shape);
}),
py::arg("type"),
py::arg("shape"));
cls.def(py::init([](py::object& np_literal, const ov::Shape& shape) {
return ov::Tensor(Common::dtype_to_ov_type().at(py::str(py::dtype::from_args(np_literal))), shape);
auto dtype = py::dtype::from_args(np_literal);
return ov::Tensor(Common::type_helpers::get_ov_type(dtype), shape);
}),
py::arg("type"),
py::arg("shape"));
@ -271,12 +289,75 @@ void regclass_Tensor(py::module m) {
Access to Tensor's data.
Returns numpy array with corresponding shape and dtype.
For tensors with openvino specific element type, such as u1, u4 or i4
For tensors with OpenVINO specific element type, such as u1, u4 or i4
it returns linear array, with uint8 / int8 numpy dtype.
For tensors with string element type, returns a numpy array of bytes
without any decoding.
To change the underlaying data use `str_data`/`bytes_data` properties
or the `copy_from` function.
Warning: Data of string type is always a copy of underlaying memory!
:rtype: numpy.array
)");
cls.def_property(
"bytes_data",
[](ov::Tensor& self) {
return Common::string_helpers::bytes_array_from_tensor(std::forward<ov::Tensor>(self));
},
[](ov::Tensor& self, py::object& other) {
if (py::isinstance<py::array>(other)) {
auto array = other.cast<py::array>();
Common::string_helpers::fill_string_tensor_data(self, array);
} else if (py::isinstance<py::list>(other)) {
auto array = py::array(other.cast<py::list>());
Common::string_helpers::fill_string_tensor_data(self, array);
} else {
OPENVINO_THROW("Invalid data to fill String Tensor!");
}
return;
},
R"(
Access to Tensor's data with string Type in `np.bytes_` dtype.
Getter returns a numpy array with corresponding shape and dtype.
Warning: Data of string type is always a copy of underlaying memory!
Setter fills underlaying Tensor's memory by copying strings from `other`.
`other` must have the same size (number of elements) as the Tensor.
Tensor's shape is not changed by performing this operation!
)");
cls.def_property(
"str_data",
[](ov::Tensor& self) {
return Common::string_helpers::string_array_from_tensor(std::forward<ov::Tensor>(self));
},
[](ov::Tensor& self, py::object& other) {
if (py::isinstance<py::array>(other)) {
auto array = other.cast<py::array>();
Common::string_helpers::fill_string_tensor_data(self, array);
} else if (py::isinstance<py::list>(other)) {
auto array = py::array(other.cast<py::list>());
Common::string_helpers::fill_string_tensor_data(self, array);
} else {
OPENVINO_THROW("Invalid data to fill String Tensor!");
}
return;
},
R"(
Access to Tensor's data with string Type in `np.str_` dtype.
Getter returns a numpy array with corresponding shape and dtype.
Warning: Data of string type is always a copy of underlaying memory!
Setter fills underlaying Tensor's memory by copying strings from `other`.
`other` must have the same size (number of elements) as the Tensor.
Tensor's shape is not changed by performing this operation!
)");
cls.def("get_shape",
&ov::Tensor::get_shape,
R"(
@ -310,6 +391,47 @@ void regclass_Tensor(py::module m) {
Copy tensor's data to a destination tensor. The destination tensor should have the same element type and shape.
)");
cls.def(
"copy_from",
[](ov::Tensor& self, ov::Tensor& source) {
return source.copy_to(self);
},
py::arg("source"),
R"(
Copy source tensor's data to this tensor. Tensors should have the same element type and shape.
)");
cls.def(
"copy_from",
[](ov::Tensor& self, py::array& source) {
auto _source = Common::object_from_data<ov::Tensor>(source, false);
if (self.get_shape() != _source.get_shape()) {
self.set_shape(_source.get_shape());
}
return _source.copy_to(self);
},
py::arg("source"),
R"(
Copy the source to this tensor. This tensor and the source should have the same element type.
Shape will be adjusted if there is a mismatch.
)");
cls.def(
"copy_from",
[](ov::Tensor& self, py::list& source) {
auto array = py::array(source);
auto _source = Common::object_from_data<ov::Tensor>(array, false);
if (self.get_shape() != _source.get_shape()) {
self.set_shape(_source.get_shape());
}
return _source.copy_to(self);
},
py::arg("source"),
R"(
Copy the source to this tensor. This tensor and the source should have the same element type.
Shape will be adjusted if there is a mismatch.
)");
cls.def("is_continuous",
&ov::Tensor::is_continuous,
R"(

View File

@ -203,7 +203,7 @@ void regclass_graph_op_Constant(py::module m) {
"get_data",
[](ov::op::v0::Constant& self) {
auto ov_type = self.get_element_type();
auto dtype = Common::ov_type_to_dtype().at(ov_type);
auto dtype = Common::type_helpers::get_dtype(ov_type);
if (ov_type.bitwidth() < Common::values::min_bitwidth) {
return py::array(dtype, self.get_byte_size(), self.get_data_ptr());
}
@ -223,7 +223,7 @@ void regclass_graph_op_Constant(py::module m) {
"data",
[](ov::op::v0::Constant& self) {
auto ov_type = self.get_element_type();
auto dtype = Common::ov_type_to_dtype().at(ov_type);
auto dtype = Common::type_helpers::get_dtype(ov_type);
if (ov_type.bitwidth() < Common::values::min_bitwidth) {
return py::array(dtype, self.get_byte_size(), self.get_data_ptr(), py::cast(self));
}

View File

@ -18,7 +18,8 @@ void regclass_graph_Type(py::module m) {
type.doc() = "openvino.runtime.Type wraps ov::element::Type";
type.def(py::init([](py::object& np_literal) {
return Common::dtype_to_ov_type().at(py::str(py::dtype::from_args(np_literal)));
auto dtype = py::dtype::from_args(np_literal);
return Common::type_helpers::get_ov_type(dtype);
}),
py::arg("dtype"),
R"(
@ -49,6 +50,7 @@ void regclass_graph_Type(py::module m) {
type.attr("u64") = ov::element::u64;
type.attr("bf16") = ov::element::bf16;
type.attr("nf4") = ov::element::nf4;
type.attr("string") = ov::element::string;
type.def("__hash__", &ov::element::Type::hash);
type.def("__repr__", [](const ov::element::Type& self) {
@ -120,7 +122,7 @@ void regclass_graph_Type(py::module m) {
type.def(
"to_dtype",
[](ov::element::Type& self) {
return Common::ov_type_to_dtype().at(self);
return Common::type_helpers::get_dtype(self);
},
R"(
Convert Type to numpy dtype.

View File

@ -612,7 +612,10 @@ def test_inputs_tuple_not_replaced(device, share_inputs):
def test_invalid_inputs(device, share_inputs):
request, _, _ = create_simple_request_and_inputs(device)
inputs = "some_input"
class InvalidInput():
pass
inputs = InvalidInput()
with pytest.raises(TypeError) as e:
request.infer(inputs, share_inputs=share_inputs)

View File

@ -560,3 +560,23 @@ def test_init_from_empty_array(shared_flag, init_value):
assert tensor.element_type.to_dtype() == init_value.dtype
assert tensor.byte_size == init_value.nbytes
assert np.array_equal(tensor.data, init_value)
@pytest.mark.parametrize(
"init_value",
[
([1.0, 2.0, 3.0]),
([21, 37, 42]),
([[10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]]),
([[2.2, 6.5], [0.2, 6.7]]),
],
)
def test_init_from_list(init_value):
tensor = ov.Tensor(init_value)
assert np.array_equal(tensor.data, init_value)
# Convert to numpy to perform all checks. Memory is not shared,
# so it does not matter if data is stored in numpy format.
_init_value = np.array(init_value)
assert tuple(tensor.shape) == _init_value.shape
assert tensor.element_type.to_dtype() == _init_value.dtype
assert tensor.byte_size == _init_value.nbytes

View File

@ -0,0 +1,256 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import openvino as ov
import pytest
from enum import Enum
class DataGetter(Enum):
BYTES = 1
STRINGS = 2
def _check_tensor_string(tensor_data, test_data):
assert tensor_data.shape == test_data.shape
assert tensor_data.strides == test_data.strides
assert np.array_equal(tensor_data, test_data)
assert not (np.shares_memory(tensor_data, test_data))
def check_bytes_based(tensor, string_data, to_flat=False):
tensor_data = tensor.bytes_data
encoded_data = string_data if string_data.dtype.kind == "S" else np.char.encode(string_data)
assert tensor_data.dtype.kind == "S"
_check_tensor_string(tensor_data.flatten() if to_flat else tensor_data, encoded_data.flatten() if to_flat else encoded_data)
def check_string_based(tensor, string_data, to_flat=False):
tensor_data = tensor.str_data
decoded_data = string_data if string_data.dtype.kind == "U" else np.char.decode(string_data)
assert tensor_data.dtype.kind == "U"
_check_tensor_string(tensor_data.flatten() if to_flat else tensor_data, decoded_data.flatten() if to_flat else decoded_data)
def test_string_tensor_shared_memory_fails():
data = np.array(["You", "shall", "not", "pass!"])
with pytest.raises(RuntimeError) as e:
_ = ov.Tensor(data, shared_memory=True)
assert "SHARED MEMORY MODE FOR THIS TENSOR IS NOT APPLICABLE! String types can be only copied." in str(e.value)
def test_string_tensor_data_warning():
data = np.array(["You", "shall", "not", "pass!"])
tensor = ov.Tensor(data, shared_memory=False)
with pytest.warns(RuntimeWarning) as w:
_ = tensor.data
assert "Data of string type will be copied! Please use dedicated properties" in str(w[0].message)
@pytest.mark.parametrize(
("init_type"),
[
(ov.Type.string),
(str),
(bytes),
(np.str_),
(np.bytes_),
],
)
def test_empty_string_tensor(init_type):
tensor = ov.Tensor(type=init_type, shape=ov.Shape([2, 2]))
assert tensor.element_type == ov.Type.string
@pytest.mark.parametrize(
("string_data"),
[
([bytes("text", encoding="utf-8"), bytes("openvino", encoding="utf-8")]),
([[b"xyz"], [b"abc"], [b"this is my last"]]),
(["text", "abc", "openvino"]),
(["text", "больше текста", "jeszcze więcej słów", "효과가 있었어"]),
([["text"], ["abc"], ["openvino"]]),
([["jeszcze więcej słów", "효과가 있었어"]]),
],
)
def test_init_with_list(string_data):
tensor = ov.Tensor(string_data)
assert tensor.element_type == ov.Type.string
# Convert to numpy to perform all checks. Memory is not shared,
# so it does not matter if data is stored in numpy format.
_string_data = np.array(string_data)
# Encoded:
check_bytes_based(tensor, _string_data)
# Decoded:
check_string_based(tensor, _string_data)
@pytest.mark.parametrize(
("string_data"),
[
(np.array(["text", "abc", "openvino"]).astype("S")), # "|S"
(np.array([["xyz"], ["abc"]]).astype(np.bytes_)), # "|S"
(np.array(["text", "abc", "openvino"])), # "<U"
(np.array(["text", "больше текста", "jeszcze więcej słów", "효과가 있었어"])), # "<U"
(np.array([["text"], ["abc"], ["openvino"]])), # "<U"
(np.array([["jeszcze więcej słów", "효과가 있었어"]])), # "<U"
],
)
def test_init_with_numpy(string_data):
tensor = ov.Tensor(string_data, shared_memory=False)
assert tensor.element_type == ov.Type.string
# Encoded:
check_bytes_based(tensor, string_data)
# Decoded:
check_string_based(tensor, string_data)
@pytest.mark.parametrize(
("init_type"),
[
(ov.Type.string),
(str),
(bytes),
(np.str_),
(np.bytes_),
],
)
@pytest.mark.parametrize(
("init_shape"),
[
(ov.Shape()),
(ov.Shape([])),
(ov.Shape([5])),
(ov.Shape([1, 1])),
(ov.Shape([2, 4, 5])),
],
)
@pytest.mark.parametrize(
("string_data"),
[
(np.array(["text", "abc", "openvino"]).astype(np.bytes_)), # "|S8"
(np.array([["text!"], ["abc?"]]).astype("S")), # "|S8"
(np.array(["text", "abc", "openvino"])), # "<U", depending on platform
(np.array([["text"], ["abc"], ["openvino"]])), # "<U", depending on platform
(np.array([["text", "больше текста"], ["jeszcze więcej słów", "효과가 있었어"]])), # "<U"
(np.array([["#text@", "больше текста"]])), # "<U"
],
)
def test_empty_tensor_copy_from(init_type, init_shape, string_data):
tensor = ov.Tensor(init_type, init_shape)
assert tensor.element_type == ov.Type.string
tensor.copy_from(string_data)
# Encoded:
check_bytes_based(tensor, string_data)
# Decoded:
check_string_based(tensor, string_data)
@pytest.mark.parametrize(
("init_shape"),
[
(ov.Shape()),
(ov.Shape([])),
(ov.Shape([1])),
(ov.Shape([8])),
(ov.Shape([4, 4])),
],
)
@pytest.mark.parametrize(
("string_data"),
[
(np.array(["text", "abc", "openvino"]).astype(np.bytes_)), # "|S"
(np.array([["text!"], ["abc?"]]).astype("S")), # "|S"
(np.array(["text", "abc", "openvino"])), # "<U"
(np.array([["text"], ["abc"], ["openvino"]])), # "<U"
(np.array([["text", "больше текста"], ["jeszcze więcej słów", "효과가 있었어"]])), # "<U"
(np.array([["#text@", "больше текста"]])), # "<U"
([bytes("text", encoding="utf-8"), bytes("openvino", encoding="utf-8")]),
([[b"xyz"], [b"abc"], [b"this is my last"]]),
(["text", "abc", "openvino"]),
(["text", "больше текста", "jeszcze więcej słów", "효과가 있었어"]),
([["text"], ["abc"], ["openvino"]]),
([["jeszcze więcej słów", "효과가 있었어"]]),
],
)
def test_populate_fails_size_check(init_shape, string_data):
tensor = ov.Tensor(ov.Type.string, init_shape)
assert tensor.element_type == ov.Type.string
with pytest.raises(RuntimeError) as e:
tensor.bytes_data = string_data
assert "Passed array must have the same size (number of elements) as the Tensor!" in str(e.value)
with pytest.raises(RuntimeError) as e:
tensor.str_data = string_data
assert "Passed array must have the same size (number of elements) as the Tensor!" in str(e.value)
@pytest.mark.parametrize(
("string_data"),
[
(np.array([0.6, 2.1, 3.7, 7.8])),
([1, 2, 3, 4, 5, 6, 7, 8, 9]),
],
)
def test_populate_fails_type_check(string_data):
tensor = ov.Tensor(ov.Type.string, ov.Shape([1]))
assert tensor.element_type == ov.Type.string
with pytest.raises(RuntimeError) as e:
tensor.bytes_data = string_data
assert "Unknown string kind passed to fill the Tensor's data!" in str(e.value)
with pytest.raises(RuntimeError) as e:
tensor.str_data = string_data
assert "Unknown string kind passed to fill the Tensor's data!" in str(e.value)
@pytest.mark.parametrize(
("init_type"),
[
(ov.Type.string),
(str),
(bytes),
(np.str_),
(np.bytes_),
],
)
@pytest.mark.parametrize(
("init_shape", "string_data"),
[
(ov.Shape([3]), np.array(["text", "abc", "openvino"]).astype(np.bytes_)),
(ov.Shape([3]), np.array(["text", "больше текста", "jeszcze więcej słów"])),
(ov.Shape([3]), [b"xyz", b"abc", b"this is my last"]),
(ov.Shape([3]), ["text", "abc", "openvino"]),
(ov.Shape([3]), ["text", "больше текста", "jeszcze więcej słów"]),
(ov.Shape([2, 2]), np.array(["text", "abc", "openvino", "different"]).astype(np.bytes_)),
(ov.Shape([2, 2]), np.array(["text", "больше текста", "jeszcze więcej słów", "abcdefg"])),
(ov.Shape([2, 2]), [b"xyz", b"abc", b"this is my last", b"this is my final"]),
(ov.Shape([2, 2]), [["text", "abc"], ["openvino", "abcdefg"]]),
(ov.Shape([2, 2]), ["text", "больше текста", "jeszcze więcej słów", "śćżó"]),
],
)
@pytest.mark.parametrize(
("data_getter"),
[
(DataGetter.BYTES),
(DataGetter.STRINGS),
],
)
def test_empty_tensor_populate(init_type, init_shape, string_data, data_getter):
tensor = ov.Tensor(init_type, init_shape)
assert tensor.element_type == ov.Type.string
if data_getter == DataGetter.BYTES:
tensor.bytes_data = string_data
elif data_getter == DataGetter.STRINGS:
tensor.str_data = string_data
else:
raise AttributeError("Unknown DataGetter passed!")
_string_data = np.array(string_data) if isinstance(string_data, list) else string_data
# Need to flatten the numpy array as Tensor can have different shape.
# It only checks if strings are filling the data correctly.
# Encoded:
check_bytes_based(tensor, _string_data, to_flat=True)
# Decoded:
check_string_based(tensor, _string_data, to_flat=True)

View File

@ -21,9 +21,20 @@ from openvino import Type
("uint32", np.uint32, Type.u32),
("uint64", np.uint64, Type.u64),
("bool", bool, Type.boolean),
("bytes_", np.bytes_, Type.string),
("str_", np.str_, Type.string),
("bytes", bytes, Type.string),
("str", str, Type.string),
("|S", np.dtype("|S"), Type.string),
("|U", np.dtype("|U"), Type.string),
])
def test_dtype_ovtype_conversion(dtype_string, dtype, ovtype):
assert ovtype.to_dtype() == dtype
if hasattr(dtype, "kind"):
assert ovtype.to_dtype() == np.bytes_
elif issubclass(dtype, (str, np.str_)):
assert ovtype.to_dtype() == np.bytes_
else:
assert ovtype.to_dtype() == dtype
assert Type(dtype_string) == ovtype
assert Type(dtype) == ovtype

View File

@ -4,11 +4,13 @@
import os
import pytest
from copy import deepcopy
import numpy as np
from tests.utils.helpers import generate_relu_compiled_model
from openvino import Type, Shape, Tensor
from openvino import Core, Model, Type, Shape, Tensor
import openvino.runtime.opset13 as ops
from openvino.runtime.utils.data_helpers import _data_dispatch
is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
@ -125,7 +127,7 @@ def test_ndarray_shared_dispatcher_casting(device, input_shape):
assert isinstance(result, Tensor)
assert result.get_shape() == Shape(test_data.shape)
assert result.get_element_type() == infer_request.inputs[0].get_element_type()
assert result.get_element_type() == infer_request.input_tensors[0].get_element_type()
assert np.array_equal(result.data, test_data)
test_data[0] = 2.0
@ -161,3 +163,145 @@ def test_ndarray_copied_dispatcher(device, input_shape):
test_data[0] = 2.0
assert not np.array_equal(infer_request.input_tensors[0].data, test_data)
@pytest.mark.parametrize(
("input_data"),
[
np.array(["śćżóąę", "data_dispatcher_test"]),
np.array(["abcdef", "data_dispatcher_test"]).astype("S"),
],
)
@pytest.mark.parametrize("data_type", [Type.string, str, bytes, np.str_, np.bytes_])
@pytest.mark.parametrize("input_shape", [[2], [2, 1]])
@pytest.mark.parametrize("is_shared", [True, False])
def test_string_array_dispatcher(device, input_data, data_type, input_shape, is_shared):
# Copy data so it won't be overriden by next testcase:
test_data = np.copy(input_data).reshape(input_shape)
param = ops.parameter(input_shape, data_type, name="data")
res = ops.result(param)
model = Model([res], [param], "test_model")
core = Core()
compiled_model = core.compile_model(model, device)
infer_request = compiled_model.create_infer_request()
result = _data_dispatch(infer_request, test_data, is_shared)
if is_shared:
assert isinstance(result, Tensor)
assert result.element_type == Type.string
assert result.shape == Shape(input_shape)
if test_data.dtype.kind == "U":
assert np.array_equal(result.bytes_data, np.char.encode(test_data))
assert np.array_equal(result.str_data, test_data)
else:
assert np.array_equal(result.bytes_data, test_data)
assert np.array_equal(result.str_data, np.char.decode(test_data))
assert not np.shares_memory(result.bytes_data, test_data)
assert not np.shares_memory(result.str_data, test_data)
else:
assert result == {}
if test_data.dtype.kind == "U":
assert np.array_equal(infer_request.input_tensors[0].bytes_data, np.char.encode(test_data))
assert np.array_equal(infer_request.input_tensors[0].str_data, test_data)
else:
assert np.array_equal(infer_request.input_tensors[0].bytes_data, test_data)
assert np.array_equal(infer_request.input_tensors[0].str_data, np.char.decode(test_data))
assert not np.shares_memory(infer_request.input_tensors[0].bytes_data, test_data)
assert not np.shares_memory(infer_request.input_tensors[0].str_data, test_data)
# Override value to confirm:
test_data[0] = "different string"
if test_data.dtype.kind == "U":
assert not np.array_equal(infer_request.input_tensors[0].bytes_data, np.char.encode(test_data))
assert not np.array_equal(infer_request.input_tensors[0].str_data, test_data)
else:
assert not np.array_equal(infer_request.input_tensors[0].bytes_data, test_data)
assert not np.array_equal(infer_request.input_tensors[0].str_data, np.char.decode(test_data))
@pytest.mark.parametrize(
("input_data", "input_shape"),
[
([["śćżóąę", "data_dispatcher_test"]], [2]),
([[b"abcdef", b"data_dispatcher_test"]], [2]),
([[bytes("abc", encoding="utf-8"), bytes("zzzz", encoding="utf-8")]], [2]),
([[["śćżóąę", "data_dispatcher_test"]]], [1, 2]),
([[["śćżóąę"], ["data_dispatcher_test"]]], [2, 1]),
],
)
@pytest.mark.parametrize("data_type", [Type.string, str, bytes, np.str_, np.bytes_])
@pytest.mark.parametrize("is_shared", [True, False])
def test_string_list_dispatcher(device, input_data, input_shape, data_type, is_shared):
# Copy data so it won't be overriden by next testcase:
test_data = deepcopy(input_data)
param = ops.parameter(input_shape, data_type, name="data")
res = ops.result(param)
model = Model([res], [param], "test_model")
core = Core()
compiled_model = core.compile_model(model, device)
infer_request = compiled_model.create_infer_request()
result_dict = _data_dispatch(infer_request, test_data, is_shared)
# Dispatcher will always return new Tensors from any lists.
# For copied approach it will be based of the list and ov.Tensor class
# is responsible for copying list over to C++ memory.
result = result_dict[0]
assert isinstance(result, Tensor)
assert result.element_type == Type.string
assert result.shape == Shape(input_shape)
# Convert input_data into numpy array to test properties
test_data_np = np.array(input_data).reshape(input_shape)
if test_data_np.dtype.kind == "U":
assert np.array_equal(result.bytes_data, np.char.encode(test_data_np))
assert np.array_equal(result.str_data, test_data_np)
else:
assert np.array_equal(result.bytes_data, test_data_np)
assert np.array_equal(result.str_data, np.char.decode(test_data_np))
@pytest.mark.parametrize(
("input_data"),
[
"śćżóąę",
"test dispatcher",
bytes("zzzz", encoding="utf-8"),
b"aaaaaaa",
"😁😁",
],
)
@pytest.mark.parametrize("data_type", [Type.string, str, bytes, np.str_, np.bytes_])
@pytest.mark.parametrize("is_shared", [True, False])
def test_string_scalar_dispatcher(device, input_data, data_type, is_shared):
test_data = input_data
param = ops.parameter([1], data_type, name="data")
res = ops.result(param)
model = Model([res], [param], "test_model")
core = Core()
compiled_model = core.compile_model(model, device)
infer_request = compiled_model.create_infer_request()
result = _data_dispatch(infer_request, test_data, is_shared)
# Result will always be a Tensor:
assert isinstance(result, Tensor)
assert result.element_type == Type.string
assert result.shape == Shape([])
if isinstance(test_data, str):
assert np.array_equal(result.bytes_data, np.char.encode(test_data))
assert np.array_equal(result.str_data, test_data)
else:
assert np.array_equal(result.bytes_data, test_data)
assert np.array_equal(result.str_data, np.char.decode(test_data))
assert not np.shares_memory(result.bytes_data, test_data)
assert not np.shares_memory(result.str_data, test_data)