[PYTHON API] Tensor.data property for low precisions + packing (#11131)

* rebase old branch with master

* Fix doc style

* fix test

* update tests

* Add missed param

* Rewrite docstring for tensor and refactor set_input_tensors test

* update python exclusives

* keep compatibility

* remove notes about slices

* fix code style

* Fix code style
This commit is contained in:
Alexey Lebedev 2022-04-01 12:04:04 +03:00 committed by GitHub
parent 701d75eafa
commit 6eaa15745a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 314 additions and 64 deletions

View File

@ -28,12 +28,6 @@ Python API allows passing data as tensors. `Tensor` object holds a copy of the d
@snippet docs/snippets/ov_python_exclusives.py tensor_shared_mode
### Slices of array's memory
One of the `Tensor` class constructors allows to share the slice of array's memory. When `shape` is specified in the constructor that has the numpy array as first argument, it triggers the special shared memory mode.
@snippet docs/snippets/ov_python_exclusives.py tensor_slice_mode
## Running inference
Python API supports extra calling methods to synchronous and asynchronous modes for inference.
@ -76,6 +70,14 @@ The callback of `AsyncInferQueue` is uniform for every job. When executed, GIL i
@snippet docs/snippets/ov_python_exclusives.py asyncinferqueue_set_callback
### Working with u1, u4 and i4 element types
Since openvino supports low precision element types there are few ways how to handle them in python.
To create an input tensor with such element types you may need to pack your data in new numpy array which byte size matches original input size:
@snippet docs/snippets/ov_python_exclusives.py packing_data
To extract low precision values from tensor into numpy array you can use next helper:
@snippet docs/snippets/ov_python_exclusives.py unpacking
### Releasing the GIL

View File

@ -50,17 +50,6 @@ shared_tensor.data[0][2] = 0.6
assert data_to_share[0][2] == 0.6
#! [tensor_shared_mode]
#! [tensor_slice_mode]
data_to_share = np.ones(shape=(2,8))
# Specify slice of memory and the shape
shared_tensor = ov.Tensor(data_to_share[1][:] , shape=ov.Shape([8]))
# Editing of the numpy array affects Tensor's data
data_to_share[1][:] = 2
assert np.array_equal(shared_tensor.data, data_to_share[1][:])
#! [tensor_slice_mode]
infer_request = compiled.create_infer_request()
data = np.random.randint(-5, 3 + 1, size=(8))
@ -132,6 +121,23 @@ infer_queue.wait_all()
assert all(data_done)
#! [asyncinferqueue_set_callback]
unt8_data = np.ones([100])
#! [packing_data]
from openvino.helpers import pack_data
packed_buffer = pack_data(unt8_data, ov.Type.u4)
# Create tensor with shape in element types
t = ov.Tensor(packed_buffer, [1, 128], ov.Type.u4)
#! [packing_data]
#! [unpacking]
from openvino.helpers import unpack_data
unpacked_data = unpack_data(t.data, t.element_type, t.shape)
assert np.array_equal(unpacked_data , unt8_data)
#! [unpacking]
#! [releasing_gil]
import openvino.runtime as ov
import cv2 as cv

View File

@ -0,0 +1,6 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
from openvino.helpers.packing import pack_data, unpack_data

View File

@ -0,0 +1,91 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
import numpy as np
from typing import Union
from openvino.runtime import Type, Shape
def pack_data(array: np.ndarray, type: Type) -> np.ndarray:
"""Represent array values as u1,u4 or i4 openvino element type and pack them into uint8 numpy array.
If the number of elements in array is odd we pad them with zero value to be able to fit the bit
sequence into the uint8 array.
Example: two uint8 values - [7, 8] can be represented as uint4 values and be packed into one int8
value - [120], because [7, 8] bit representation is [0111, 1000] will be viewed
as [01111000], which is bit representation of [120].
:param array: numpy array with values to pack.
:type array: numpy array
:param type: Type to interpret the array values. Type must be u1, u4 or i4.
:type type: openvino.runtime.Type
"""
assert type in [Type.u1, Type.u4, Type.i4], "Packing algorithm for the" \
"data types stored in 1, 2 or 4 bits"
minimum_regular_dtype = np.int8 if type == Type.i4 else np.uint8
casted_to_regular_type = array.astype(dtype=minimum_regular_dtype, casting="unsafe")
if not np.array_equal(casted_to_regular_type, array):
raise RuntimeError(f'The conversion of array "{array}" to dtype'
f' "{casted_to_regular_type}" results in rounding')
data_size = casted_to_regular_type.size
num_bits = type.bitwidth
assert num_bits < 8 and 8 % num_bits == 0, "Packing algorithm for the" \
"data types stored in 1, 2 or 4 bits"
num_values_fitting_into_uint8 = 8 // num_bits
pad = (-data_size) % num_values_fitting_into_uint8
flattened = casted_to_regular_type.flatten()
padded = np.concatenate((flattened, np.zeros([pad], dtype=minimum_regular_dtype)))
assert padded.size % num_values_fitting_into_uint8 == 0
bit_order_little = (padded[:, None] & (1 << np.arange(num_bits)) > 0).astype(minimum_regular_dtype)
bit_order_big = np.flip(bit_order_little, axis=1)
bit_order_big_flattened = bit_order_big.flatten()
return np.packbits(bit_order_big_flattened)
def unpack_data(array: np.ndarray, type: Type, shape: Union[list, Shape]) -> np.ndarray:
"""Extract openvino element type values from array into new uint8/int8 array given shape.
Example: uint8 value [120] can be represented as two u4 values and be unpacked into [7, 8]
because [120] bit representation is [01111000] will be viewed as [0111, 1000],
which is bit representation of [7, 8].
:param array: numpy array to unpack.
:type array: numpy array
:param type: Type to extract from array values. Type must be u1, u4 or i4.
:type type: openvino.runtime.Type
:param shape: the new shape for the unpacked array.
:type shape: Union[list, openvino.runtime.Shape]
"""
assert type in [Type.u1, Type.u4, Type.i4], "Unpacking algorithm for the" \
"data types stored in 1, 2 or 4 bits"
unpacked = np.unpackbits(array.view(np.uint8))
shape = list(shape)
if type.bitwidth == 1:
return np.resize(unpacked, shape)
else:
unpacked = unpacked.reshape(-1, type.bitwidth)
padding_shape = (unpacked.shape[0], 8 - type.bitwidth)
padding = np.ndarray(padding_shape, np.uint8)
if type == Type.i4:
for axis, bits in enumerate(unpacked):
if bits[0] == 1:
padding[axis] = np.ones((padding_shape[1],), np.uint8)
else:
padding[axis] = np.zeros((padding_shape[1],), np.uint8)
else:
padding = np.zeros(padding_shape, np.uint8)
padded = np.concatenate((padding, unpacked), 1)
packed = np.packbits(padded, 1)
if type == Type.i4:
return np.resize(packed, shape).astype(dtype=np.int8)
else:
return np.resize(packed, shape)

View File

@ -27,6 +27,8 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
{ov::element::u64, py::dtype("uint64")},
{ov::element::boolean, py::dtype("bool")},
{ov::element::u1, py::dtype("uint8")},
{ov::element::u4, py::dtype("uint8")},
{ov::element::i4, py::dtype("int8")},
};
return ov_type_to_dtype_mapping;
}
@ -49,12 +51,12 @@ const std::map<std::string, ov::element::Type>& dtype_to_ov_type() {
return dtype_to_ov_type_mapping;
}
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape) {
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const ov::element::Type& type) {
bool is_contiguous = C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS);
auto type = Common::dtype_to_ov_type().at(py::str(array.dtype()));
auto element_type = (type == ov::element::undefined) ? Common::dtype_to_ov_type().at(py::str(array.dtype())) : type;
if (is_contiguous) {
return ov::Tensor(type, shape, const_cast<void*>(array.data(0)), {});
return ov::Tensor(element_type, shape, const_cast<void*>(array.data(0)), {});
} else {
throw ov::Exception("Tensor with shared memory must be C contiguous!");
}

View File

@ -28,7 +28,7 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype();
const std::map<std::string, ov::element::Type>& dtype_to_ov_type();
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape);
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const ov::element::Type& ov_type);
ov::Tensor tensor_from_numpy(py::array& array, bool shared_memory);

View File

@ -35,28 +35,30 @@ void regclass_Tensor(py::module m) {
:type shared_memory: bool
)");
cls.def(py::init([](py::array& array, const ov::Shape& shape) {
return Common::tensor_from_pointer(array, shape);
cls.def(py::init([](py::array& array, const ov::Shape& shape, const ov::element::Type& ov_type) {
return Common::tensor_from_pointer(array, shape, ov_type);
}),
py::arg("array"),
py::arg("shape"),
py::arg("type") = ov::element::undefined,
R"(
Another Tensor's special constructor.
It takes an array or slice of it, and shape that will be
selected, starting from the first element of the given array/slice.
Please use it only in advanced cases if necessary!
Represents array in the memory with given shape and element type.
It's recommended to use this constructor only for wrapping array's
memory with the specific openvino element type parameter.
:param array: Underlaying methods will retrieve pointer on first element
from it, which is simulating `host_ptr` from C++ API.
Tensor memory is being shared with a host,
that means the responsibility of keeping host memory is
:param array: C_CONTIGUOUS numpy array which will be wrapped in
openvino.runtime.Tensor with given parameters (shape
and element_type). Array's memory is being shared with
a host, that means the responsibility of keeping host memory is
on the side of a user. Any action performed on the host
memory will be reflected on this Tensor's memory!
Data is required to be C_CONTIGUOUS.
:type array: numpy.array
:param shape: Shape of the new tensor.
:type shape: openvino.runtime.Shape
:param type: Element type
:type type: openvino.runtime.Type
:Example:
.. code-block:: python
@ -64,15 +66,43 @@ void regclass_Tensor(py::module m) {
import openvino.runtime as ov
import numpy as np
arr = np.array([[1, 2, 3], [4, 5, 6]])
arr = np.array(shape=(100), dtype=np.uint8)
t = ov.Tensor(arr, ov.Shape([100, 8]), ov.Type.u1)
)");
t = ov.Tensor(arr[1][0:1], ov.Shape([3]))
cls.def(py::init([](py::array& array, const std::vector<size_t> shape, const ov::element::Type& ov_type) {
return Common::tensor_from_pointer(array, shape, ov_type);
}),
py::arg("array"),
py::arg("shape"),
py::arg("type") = ov::element::undefined,
R"(
Another Tensor's special constructor.
t.data[0] = 9
Represents array in the memory with given shape and element type.
It's recommended to use this constructor only for wrapping array's
memory with the specific openvino element type parameter.
print(arr)
>>> [[1 2 3]
>>> [9 5 6]]
:param array: C_CONTIGUOUS numpy array which will be wrapped in
openvino.runtime.Tensor with given parameters (shape
and element_type). Array's memory is being shared with
a host, that means the responsibility of keeping host memory is
on the side of a user. Any action performed on the host
memory will be reflected on this Tensor's memory!
:type array: numpy.array
:param shape: Shape of the new tensor.
:type shape: list or tuple
:param type: Element type.
:type type: openvino.runtime.Type
:Example:
.. code-block:: python
import openvino.runtime as ov
import numpy as np
arr = np.array(shape=(100), dtype=np.uint8)
t = ov.Tensor(arr, [100, 8], ov.Type.u1)
)");
cls.def(py::init<const ov::element::Type, const ov::Shape>(), py::arg("type"), py::arg("shape"));
@ -177,15 +207,20 @@ void regclass_Tensor(py::module m) {
cls.def_property_readonly(
"data",
[](ov::Tensor& self) {
return py::array(Common::ov_type_to_dtype().at(self.get_element_type()),
self.get_shape(),
self.get_strides(),
self.data(),
py::cast(self));
auto ov_type = self.get_element_type();
auto dtype = Common::ov_type_to_dtype().at(ov_type);
if (ov_type.bitwidth() < 8) {
return py::array(dtype, self.get_byte_size(), self.data(), py::cast(self));
}
return py::array(dtype, self.get_shape(), self.get_strides(), self.data(), py::cast(self));
},
R"(
Access to Tensor's data.
Returns numpy array with corresponding shape and dtype.
For tensors with openvino specific element type, such as u1, u4 or i4
it returns linear array, with uint8 / int8 numpy dtype.
:rtype: numpy.array
)");

View File

@ -19,11 +19,13 @@ void regclass_graph_Type(py::module m) {
type.attr("f16") = ov::element::f16;
type.attr("f32") = ov::element::f32;
type.attr("f64") = ov::element::f64;
type.attr("i4") = ov::element::i4;
type.attr("i8") = ov::element::i8;
type.attr("i16") = ov::element::i16;
type.attr("i32") = ov::element::i32;
type.attr("i64") = ov::element::i64;
type.attr("u1") = ov::element::u1;
type.attr("u4") = ov::element::u4;
type.attr("u8") = ov::element::u8;
type.attr("u16") = ov::element::u16;
type.attr("u32") = ov::element::u32;

View File

@ -164,15 +164,15 @@ def test_set_tensors(device):
@pytest.mark.dynamic_library
@pytest.mark.template_extension
def test_batched_tensors(device):
batch = 4
one_shape = Shape([1, 2, 2, 2])
batch_shape = Shape([batch, 2, 2, 2])
one_shape_size = np.prod(one_shape)
core = Core()
# TODO: remove when plugins will support set_input_tensors
core.register_plugin("openvino_template_plugin", "TEMPLATE")
batch = 4
one_shape = [1, 2, 2, 2]
one_shape_size = np.prod(one_shape)
batch_shape = [batch, 2, 2, 2]
data1 = ops.parameter(batch_shape, np.float32)
data1.set_friendly_name("input0")
data1.get_output_tensor(0).set_names({"tensor_input0"})
@ -191,21 +191,21 @@ def test_batched_tensors(device):
compiled = core.compile_model(model, "TEMPLATE")
buffer = np.zeros([one_shape_size * batch * 2], dtype=np.float32)
req = compiled.create_infer_request()
# Allocate 8 chunks, set 'user tensors' to 0, 2, 4, 6 chunks
buffer = np.zeros([batch * 2, *batch_shape[1:]], dtype=np.float32)
tensors = []
for i in range(batch):
# non contiguous memory (i*2)
tensors.append(Tensor(np.expand_dims(buffer[i * 2], 0), shared_memory=True))
for i in range(0, batch):
_start = i * one_shape_size * 2
# Use of special constructor for Tensor.
# It creates a Tensor from pointer, thus it requires only
# one element from original buffer, and shape to "crop".
tensor = Tensor(buffer[_start:(_start + 1)], one_shape)
tensors.append(tensor)
req.set_input_tensors(tensors)
req.set_input_tensors(tensors) # using list overload!
with pytest.raises(RuntimeError) as e:
req.get_tensor("tensor_input0")
assert "get_tensor shall not be used together with batched set_tensors/set_input_tensors" in str(e.value)
actual_tensor = req.get_tensor("tensor_output0")
actual = actual_tensor.data

View File

@ -9,6 +9,7 @@ import numpy as np
import openvino.runtime as ov
from openvino.runtime import Tensor
from openvino.helpers import pack_data, unpack_data
import pytest
@ -29,7 +30,9 @@ from ..conftest import read_image
(ov.Type.i64, np.int64),
(ov.Type.u64, np.uint64),
(ov.Type.boolean, np.bool),
# (ov.Type.u1, np.uint8),
(ov.Type.u1, np.uint8),
(ov.Type.u4, np.uint8),
(ov.Type.i4, np.int8),
])
def test_init_with_ngraph(ov_type, numpy_dtype):
ov_tensors = []
@ -106,13 +109,13 @@ def test_init_with_numpy_shared_memory(ov_type, numpy_dtype):
assert np.array_equal(ov_tensor.data, arr)
assert ov_tensor.size == arr.size
assert ov_tensor.byte_size == arr.nbytes
assert tuple(ov_tensor.strides) == arr.strides
assert tuple(ov_tensor.get_shape()) == shape
assert ov_tensor.get_element_type() == ov_type
assert ov_tensor.data.dtype == numpy_dtype
assert ov_tensor.data.shape == shape
assert ov_tensor.get_size() == arr.size
assert ov_tensor.get_byte_size() == arr.nbytes
assert tuple(ov_tensor.get_strides()) == arr.strides
@pytest.mark.parametrize("ov_type, numpy_dtype", [
@ -145,7 +148,7 @@ def test_init_with_numpy_copy_memory(ov_type, numpy_dtype):
def test_init_with_numpy_fail():
arr = read_image()
arr = np.asfortranarray(read_image())
with pytest.raises(RuntimeError) as e:
_ = Tensor(array=arr, shared_memory=True)
assert "Tensor with shared memory must be C contiguous" in str(e.value)
@ -175,7 +178,6 @@ def test_init_with_roi_tensor():
(ov.Type.i64, np.int64),
(ov.Type.u64, np.uint64),
(ov.Type.boolean, np.bool),
# (ov.Type.u1, np.uint8),
])
def test_write_to_buffer(ov_type, numpy_dtype):
ov_tensor = Tensor(ov_type, ov.Shape([1, 3, 32, 32]))
@ -198,7 +200,6 @@ def test_write_to_buffer(ov_type, numpy_dtype):
(ov.Type.i64, np.int64),
(ov.Type.u64, np.uint64),
(ov.Type.boolean, np.bool),
# (ov.Type.u1, np.uint8),
])
def test_set_shape(ov_type, numpy_dtype):
shape = ov.Shape([1, 3, 32, 32])
@ -268,3 +269,108 @@ def test_cannot_set_shape_incorrect_dims():
with pytest.raises(RuntimeError) as e:
ov_tensor.shape = [3, 28, 28]
assert "Dims and format are inconsistent" in str(e.value)
@pytest.mark.parametrize("ov_type", [
(ov.Type.u1),
(ov.Type.u4),
(ov.Type.i4),
])
def test_cannot_create_roi_from_packed_tensor(ov_type):
ov_tensor = Tensor(ov_type, [1, 3, 48, 48])
with pytest.raises(RuntimeError) as e:
Tensor(ov_tensor, [0, 0, 24, 24], [1, 3, 48, 48])
assert "ROI Tensor for types with bitwidths less then 8 bit is not implemented" in str(e.value)
@pytest.mark.parametrize("ov_type", [
(ov.Type.u1),
(ov.Type.u4),
(ov.Type.i4),
])
def test_cannot_get_strides_for_packed_tensor(ov_type):
ov_tensor = Tensor(ov_type, [1, 3, 48, 48])
with pytest.raises(RuntimeError) as e:
ov_tensor.get_strides()
assert "Could not get strides for types with bitwidths less then 8 bit." in str(e.value)
@pytest.mark.parametrize("dtype", [
(np.uint8),
(np.int8),
(np.uint16),
(np.uint32),
(np.uint64),
])
@pytest.mark.parametrize("ov_type", [
(ov.Type.u1),
(ov.Type.u4),
(ov.Type.i4),
])
def test_init_with_packed_buffer(dtype, ov_type):
shape = [1, 3, 32, 32]
fit = np.dtype(dtype).itemsize * 8 / ov_type.bitwidth
assert np.prod(shape) % fit == 0
size = int(np.prod(shape) // fit)
buffer = np.random.normal(size=size).astype(dtype)
ov_tensor = Tensor(buffer, shape, ov_type)
assert ov_tensor.data.nbytes == ov_tensor.byte_size
assert np.array_equal(ov_tensor.data.view(dtype), buffer)
@pytest.mark.parametrize("shape", [
([1, 3, 28, 28]),
([1, 3, 27, 27]),
])
@pytest.mark.parametrize("low, high, ov_type, dtype", [
(0, 2, ov.Type.u1, np.uint8),
(0, 16, ov.Type.u4, np.uint8),
(-8, 7, ov.Type.i4, np.int8),
])
def test_packing(shape, low, high, ov_type, dtype):
ov_tensor = Tensor(ov_type, shape)
data = np.random.uniform(low, high, shape).astype(dtype)
packed_data = pack_data(data, ov_tensor.element_type)
ov_tensor.data[:] = packed_data
unpacked = unpack_data(ov_tensor.data, ov_tensor.element_type, ov_tensor.shape)
assert np.array_equal(unpacked, data)
@pytest.mark.parametrize("dtype", [
(np.uint8),
(np.int8),
(np.int16),
(np.uint16),
(np.int32),
(np.uint32),
(np.int64),
(np.uint64),
(np.float16),
(np.float32),
(np.float64),
])
@pytest.mark.parametrize("element_type", [
(ov.Type.u8),
(ov.Type.i8),
(ov.Type.i16),
(ov.Type.u16),
(ov.Type.i32),
(ov.Type.u32),
(ov.Type.i64),
(ov.Type.u64),
# (ov.Type.f16),
# (ov.Type.f32),
# (ov.Type.f64),
])
def test_viewed_tensor(dtype, element_type):
buffer = np.random.normal(size=(2, 16)).astype(dtype)
fit = (dtype().nbytes * 8) / element_type.bitwidth
t = Tensor(buffer, (buffer.shape[0], int(buffer.shape[1] * fit)), element_type)
assert np.array_equal(t.data, buffer.view(ov.utils.types.get_dtype(element_type)))
def test_viewed_tensor_default_type():
buffer = np.random.normal(size=(2, 16))
new_shape = (4, 8)
t = Tensor(buffer, new_shape)
assert np.array_equal(t.data, buffer.reshape(new_shape))