[PYTHON API] Tensor.data property for low precisions + packing (#11131)
* rebase old branch with master * Fix doc style * fix test * update tests * Add missed param * Rewrite docstring for tensor and refactor set_input_tensors test * update python exclusives * keep compatibility * remove notes about slices * fix code style * Fix code style
This commit is contained in:
parent
701d75eafa
commit
6eaa15745a
@ -28,12 +28,6 @@ Python API allows passing data as tensors. `Tensor` object holds a copy of the d
|
||||
|
||||
@snippet docs/snippets/ov_python_exclusives.py tensor_shared_mode
|
||||
|
||||
### Slices of array's memory
|
||||
|
||||
One of the `Tensor` class constructors allows to share the slice of array's memory. When `shape` is specified in the constructor that has the numpy array as first argument, it triggers the special shared memory mode.
|
||||
|
||||
@snippet docs/snippets/ov_python_exclusives.py tensor_slice_mode
|
||||
|
||||
## Running inference
|
||||
|
||||
Python API supports extra calling methods to synchronous and asynchronous modes for inference.
|
||||
@ -76,6 +70,14 @@ The callback of `AsyncInferQueue` is uniform for every job. When executed, GIL i
|
||||
|
||||
@snippet docs/snippets/ov_python_exclusives.py asyncinferqueue_set_callback
|
||||
|
||||
### Working with u1, u4 and i4 element types
|
||||
|
||||
Since openvino supports low precision element types there are few ways how to handle them in python.
|
||||
To create an input tensor with such element types you may need to pack your data in new numpy array which byte size matches original input size:
|
||||
@snippet docs/snippets/ov_python_exclusives.py packing_data
|
||||
|
||||
To extract low precision values from tensor into numpy array you can use next helper:
|
||||
@snippet docs/snippets/ov_python_exclusives.py unpacking
|
||||
|
||||
### Releasing the GIL
|
||||
|
||||
|
@ -50,17 +50,6 @@ shared_tensor.data[0][2] = 0.6
|
||||
assert data_to_share[0][2] == 0.6
|
||||
#! [tensor_shared_mode]
|
||||
|
||||
#! [tensor_slice_mode]
|
||||
data_to_share = np.ones(shape=(2,8))
|
||||
|
||||
# Specify slice of memory and the shape
|
||||
shared_tensor = ov.Tensor(data_to_share[1][:] , shape=ov.Shape([8]))
|
||||
|
||||
# Editing of the numpy array affects Tensor's data
|
||||
data_to_share[1][:] = 2
|
||||
assert np.array_equal(shared_tensor.data, data_to_share[1][:])
|
||||
#! [tensor_slice_mode]
|
||||
|
||||
infer_request = compiled.create_infer_request()
|
||||
data = np.random.randint(-5, 3 + 1, size=(8))
|
||||
|
||||
@ -132,6 +121,23 @@ infer_queue.wait_all()
|
||||
assert all(data_done)
|
||||
#! [asyncinferqueue_set_callback]
|
||||
|
||||
unt8_data = np.ones([100])
|
||||
|
||||
#! [packing_data]
|
||||
from openvino.helpers import pack_data
|
||||
|
||||
packed_buffer = pack_data(unt8_data, ov.Type.u4)
|
||||
# Create tensor with shape in element types
|
||||
t = ov.Tensor(packed_buffer, [1, 128], ov.Type.u4)
|
||||
#! [packing_data]
|
||||
|
||||
#! [unpacking]
|
||||
from openvino.helpers import unpack_data
|
||||
|
||||
unpacked_data = unpack_data(t.data, t.element_type, t.shape)
|
||||
assert np.array_equal(unpacked_data , unt8_data)
|
||||
#! [unpacking]
|
||||
|
||||
#! [releasing_gil]
|
||||
import openvino.runtime as ov
|
||||
import cv2 as cv
|
||||
|
6
src/bindings/python/src/openvino/helpers/__init__.py
Normal file
6
src/bindings/python/src/openvino/helpers/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
from openvino.helpers.packing import pack_data, unpack_data
|
91
src/bindings/python/src/openvino/helpers/packing.py
Normal file
91
src/bindings/python/src/openvino/helpers/packing.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
from openvino.runtime import Type, Shape
|
||||
|
||||
|
||||
def pack_data(array: np.ndarray, type: Type) -> np.ndarray:
|
||||
"""Represent array values as u1,u4 or i4 openvino element type and pack them into uint8 numpy array.
|
||||
|
||||
If the number of elements in array is odd we pad them with zero value to be able to fit the bit
|
||||
sequence into the uint8 array.
|
||||
|
||||
Example: two uint8 values - [7, 8] can be represented as uint4 values and be packed into one int8
|
||||
value - [120], because [7, 8] bit representation is [0111, 1000] will be viewed
|
||||
as [01111000], which is bit representation of [120].
|
||||
|
||||
:param array: numpy array with values to pack.
|
||||
:type array: numpy array
|
||||
:param type: Type to interpret the array values. Type must be u1, u4 or i4.
|
||||
:type type: openvino.runtime.Type
|
||||
"""
|
||||
assert type in [Type.u1, Type.u4, Type.i4], "Packing algorithm for the" \
|
||||
"data types stored in 1, 2 or 4 bits"
|
||||
|
||||
minimum_regular_dtype = np.int8 if type == Type.i4 else np.uint8
|
||||
casted_to_regular_type = array.astype(dtype=minimum_regular_dtype, casting="unsafe")
|
||||
if not np.array_equal(casted_to_regular_type, array):
|
||||
raise RuntimeError(f'The conversion of array "{array}" to dtype'
|
||||
f' "{casted_to_regular_type}" results in rounding')
|
||||
|
||||
data_size = casted_to_regular_type.size
|
||||
num_bits = type.bitwidth
|
||||
|
||||
assert num_bits < 8 and 8 % num_bits == 0, "Packing algorithm for the" \
|
||||
"data types stored in 1, 2 or 4 bits"
|
||||
num_values_fitting_into_uint8 = 8 // num_bits
|
||||
pad = (-data_size) % num_values_fitting_into_uint8
|
||||
|
||||
flattened = casted_to_regular_type.flatten()
|
||||
padded = np.concatenate((flattened, np.zeros([pad], dtype=minimum_regular_dtype)))
|
||||
assert padded.size % num_values_fitting_into_uint8 == 0
|
||||
|
||||
bit_order_little = (padded[:, None] & (1 << np.arange(num_bits)) > 0).astype(minimum_regular_dtype)
|
||||
bit_order_big = np.flip(bit_order_little, axis=1)
|
||||
bit_order_big_flattened = bit_order_big.flatten()
|
||||
|
||||
return np.packbits(bit_order_big_flattened)
|
||||
|
||||
|
||||
def unpack_data(array: np.ndarray, type: Type, shape: Union[list, Shape]) -> np.ndarray:
|
||||
"""Extract openvino element type values from array into new uint8/int8 array given shape.
|
||||
|
||||
Example: uint8 value [120] can be represented as two u4 values and be unpacked into [7, 8]
|
||||
because [120] bit representation is [01111000] will be viewed as [0111, 1000],
|
||||
which is bit representation of [7, 8].
|
||||
|
||||
:param array: numpy array to unpack.
|
||||
:type array: numpy array
|
||||
:param type: Type to extract from array values. Type must be u1, u4 or i4.
|
||||
:type type: openvino.runtime.Type
|
||||
:param shape: the new shape for the unpacked array.
|
||||
:type shape: Union[list, openvino.runtime.Shape]
|
||||
"""
|
||||
assert type in [Type.u1, Type.u4, Type.i4], "Unpacking algorithm for the" \
|
||||
"data types stored in 1, 2 or 4 bits"
|
||||
unpacked = np.unpackbits(array.view(np.uint8))
|
||||
shape = list(shape)
|
||||
if type.bitwidth == 1:
|
||||
return np.resize(unpacked, shape)
|
||||
else:
|
||||
unpacked = unpacked.reshape(-1, type.bitwidth)
|
||||
padding_shape = (unpacked.shape[0], 8 - type.bitwidth)
|
||||
padding = np.ndarray(padding_shape, np.uint8)
|
||||
if type == Type.i4:
|
||||
for axis, bits in enumerate(unpacked):
|
||||
if bits[0] == 1:
|
||||
padding[axis] = np.ones((padding_shape[1],), np.uint8)
|
||||
else:
|
||||
padding[axis] = np.zeros((padding_shape[1],), np.uint8)
|
||||
else:
|
||||
padding = np.zeros(padding_shape, np.uint8)
|
||||
padded = np.concatenate((padding, unpacked), 1)
|
||||
packed = np.packbits(padded, 1)
|
||||
if type == Type.i4:
|
||||
return np.resize(packed, shape).astype(dtype=np.int8)
|
||||
else:
|
||||
return np.resize(packed, shape)
|
@ -27,6 +27,8 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
|
||||
{ov::element::u64, py::dtype("uint64")},
|
||||
{ov::element::boolean, py::dtype("bool")},
|
||||
{ov::element::u1, py::dtype("uint8")},
|
||||
{ov::element::u4, py::dtype("uint8")},
|
||||
{ov::element::i4, py::dtype("int8")},
|
||||
};
|
||||
return ov_type_to_dtype_mapping;
|
||||
}
|
||||
@ -49,12 +51,12 @@ const std::map<std::string, ov::element::Type>& dtype_to_ov_type() {
|
||||
return dtype_to_ov_type_mapping;
|
||||
}
|
||||
|
||||
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape) {
|
||||
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const ov::element::Type& type) {
|
||||
bool is_contiguous = C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS);
|
||||
auto type = Common::dtype_to_ov_type().at(py::str(array.dtype()));
|
||||
auto element_type = (type == ov::element::undefined) ? Common::dtype_to_ov_type().at(py::str(array.dtype())) : type;
|
||||
|
||||
if (is_contiguous) {
|
||||
return ov::Tensor(type, shape, const_cast<void*>(array.data(0)), {});
|
||||
return ov::Tensor(element_type, shape, const_cast<void*>(array.data(0)), {});
|
||||
} else {
|
||||
throw ov::Exception("Tensor with shared memory must be C contiguous!");
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype();
|
||||
|
||||
const std::map<std::string, ov::element::Type>& dtype_to_ov_type();
|
||||
|
||||
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape);
|
||||
ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const ov::element::Type& ov_type);
|
||||
|
||||
ov::Tensor tensor_from_numpy(py::array& array, bool shared_memory);
|
||||
|
||||
|
@ -35,28 +35,30 @@ void regclass_Tensor(py::module m) {
|
||||
:type shared_memory: bool
|
||||
)");
|
||||
|
||||
cls.def(py::init([](py::array& array, const ov::Shape& shape) {
|
||||
return Common::tensor_from_pointer(array, shape);
|
||||
cls.def(py::init([](py::array& array, const ov::Shape& shape, const ov::element::Type& ov_type) {
|
||||
return Common::tensor_from_pointer(array, shape, ov_type);
|
||||
}),
|
||||
py::arg("array"),
|
||||
py::arg("shape"),
|
||||
py::arg("type") = ov::element::undefined,
|
||||
R"(
|
||||
Another Tensor's special constructor.
|
||||
|
||||
It takes an array or slice of it, and shape that will be
|
||||
selected, starting from the first element of the given array/slice.
|
||||
Please use it only in advanced cases if necessary!
|
||||
Represents array in the memory with given shape and element type.
|
||||
It's recommended to use this constructor only for wrapping array's
|
||||
memory with the specific openvino element type parameter.
|
||||
|
||||
:param array: Underlaying methods will retrieve pointer on first element
|
||||
from it, which is simulating `host_ptr` from C++ API.
|
||||
Tensor memory is being shared with a host,
|
||||
that means the responsibility of keeping host memory is
|
||||
:param array: C_CONTIGUOUS numpy array which will be wrapped in
|
||||
openvino.runtime.Tensor with given parameters (shape
|
||||
and element_type). Array's memory is being shared with
|
||||
a host, that means the responsibility of keeping host memory is
|
||||
on the side of a user. Any action performed on the host
|
||||
memory will be reflected on this Tensor's memory!
|
||||
Data is required to be C_CONTIGUOUS.
|
||||
:type array: numpy.array
|
||||
:param shape: Shape of the new tensor.
|
||||
:type shape: openvino.runtime.Shape
|
||||
:param type: Element type
|
||||
:type type: openvino.runtime.Type
|
||||
|
||||
:Example:
|
||||
.. code-block:: python
|
||||
@ -64,15 +66,43 @@ void regclass_Tensor(py::module m) {
|
||||
import openvino.runtime as ov
|
||||
import numpy as np
|
||||
|
||||
arr = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
arr = np.array(shape=(100), dtype=np.uint8)
|
||||
t = ov.Tensor(arr, ov.Shape([100, 8]), ov.Type.u1)
|
||||
)");
|
||||
|
||||
t = ov.Tensor(arr[1][0:1], ov.Shape([3]))
|
||||
cls.def(py::init([](py::array& array, const std::vector<size_t> shape, const ov::element::Type& ov_type) {
|
||||
return Common::tensor_from_pointer(array, shape, ov_type);
|
||||
}),
|
||||
py::arg("array"),
|
||||
py::arg("shape"),
|
||||
py::arg("type") = ov::element::undefined,
|
||||
R"(
|
||||
Another Tensor's special constructor.
|
||||
|
||||
t.data[0] = 9
|
||||
Represents array in the memory with given shape and element type.
|
||||
It's recommended to use this constructor only for wrapping array's
|
||||
memory with the specific openvino element type parameter.
|
||||
|
||||
print(arr)
|
||||
>>> [[1 2 3]
|
||||
>>> [9 5 6]]
|
||||
:param array: C_CONTIGUOUS numpy array which will be wrapped in
|
||||
openvino.runtime.Tensor with given parameters (shape
|
||||
and element_type). Array's memory is being shared with
|
||||
a host, that means the responsibility of keeping host memory is
|
||||
on the side of a user. Any action performed on the host
|
||||
memory will be reflected on this Tensor's memory!
|
||||
:type array: numpy.array
|
||||
:param shape: Shape of the new tensor.
|
||||
:type shape: list or tuple
|
||||
:param type: Element type.
|
||||
:type type: openvino.runtime.Type
|
||||
|
||||
:Example:
|
||||
.. code-block:: python
|
||||
|
||||
import openvino.runtime as ov
|
||||
import numpy as np
|
||||
|
||||
arr = np.array(shape=(100), dtype=np.uint8)
|
||||
t = ov.Tensor(arr, [100, 8], ov.Type.u1)
|
||||
)");
|
||||
|
||||
cls.def(py::init<const ov::element::Type, const ov::Shape>(), py::arg("type"), py::arg("shape"));
|
||||
@ -177,15 +207,20 @@ void regclass_Tensor(py::module m) {
|
||||
cls.def_property_readonly(
|
||||
"data",
|
||||
[](ov::Tensor& self) {
|
||||
return py::array(Common::ov_type_to_dtype().at(self.get_element_type()),
|
||||
self.get_shape(),
|
||||
self.get_strides(),
|
||||
self.data(),
|
||||
py::cast(self));
|
||||
auto ov_type = self.get_element_type();
|
||||
auto dtype = Common::ov_type_to_dtype().at(ov_type);
|
||||
if (ov_type.bitwidth() < 8) {
|
||||
return py::array(dtype, self.get_byte_size(), self.data(), py::cast(self));
|
||||
}
|
||||
return py::array(dtype, self.get_shape(), self.get_strides(), self.data(), py::cast(self));
|
||||
},
|
||||
R"(
|
||||
Access to Tensor's data.
|
||||
|
||||
Returns numpy array with corresponding shape and dtype.
|
||||
For tensors with openvino specific element type, such as u1, u4 or i4
|
||||
it returns linear array, with uint8 / int8 numpy dtype.
|
||||
|
||||
:rtype: numpy.array
|
||||
)");
|
||||
|
||||
|
@ -19,11 +19,13 @@ void regclass_graph_Type(py::module m) {
|
||||
type.attr("f16") = ov::element::f16;
|
||||
type.attr("f32") = ov::element::f32;
|
||||
type.attr("f64") = ov::element::f64;
|
||||
type.attr("i4") = ov::element::i4;
|
||||
type.attr("i8") = ov::element::i8;
|
||||
type.attr("i16") = ov::element::i16;
|
||||
type.attr("i32") = ov::element::i32;
|
||||
type.attr("i64") = ov::element::i64;
|
||||
type.attr("u1") = ov::element::u1;
|
||||
type.attr("u4") = ov::element::u4;
|
||||
type.attr("u8") = ov::element::u8;
|
||||
type.attr("u16") = ov::element::u16;
|
||||
type.attr("u32") = ov::element::u32;
|
||||
|
@ -164,15 +164,15 @@ def test_set_tensors(device):
|
||||
@pytest.mark.dynamic_library
|
||||
@pytest.mark.template_extension
|
||||
def test_batched_tensors(device):
|
||||
batch = 4
|
||||
one_shape = Shape([1, 2, 2, 2])
|
||||
batch_shape = Shape([batch, 2, 2, 2])
|
||||
one_shape_size = np.prod(one_shape)
|
||||
|
||||
core = Core()
|
||||
|
||||
# TODO: remove when plugins will support set_input_tensors
|
||||
core.register_plugin("openvino_template_plugin", "TEMPLATE")
|
||||
|
||||
batch = 4
|
||||
one_shape = [1, 2, 2, 2]
|
||||
one_shape_size = np.prod(one_shape)
|
||||
batch_shape = [batch, 2, 2, 2]
|
||||
|
||||
data1 = ops.parameter(batch_shape, np.float32)
|
||||
data1.set_friendly_name("input0")
|
||||
data1.get_output_tensor(0).set_names({"tensor_input0"})
|
||||
@ -191,21 +191,21 @@ def test_batched_tensors(device):
|
||||
|
||||
compiled = core.compile_model(model, "TEMPLATE")
|
||||
|
||||
buffer = np.zeros([one_shape_size * batch * 2], dtype=np.float32)
|
||||
|
||||
req = compiled.create_infer_request()
|
||||
|
||||
# Allocate 8 chunks, set 'user tensors' to 0, 2, 4, 6 chunks
|
||||
buffer = np.zeros([batch * 2, *batch_shape[1:]], dtype=np.float32)
|
||||
|
||||
tensors = []
|
||||
for i in range(batch):
|
||||
# non contiguous memory (i*2)
|
||||
tensors.append(Tensor(np.expand_dims(buffer[i * 2], 0), shared_memory=True))
|
||||
|
||||
for i in range(0, batch):
|
||||
_start = i * one_shape_size * 2
|
||||
# Use of special constructor for Tensor.
|
||||
# It creates a Tensor from pointer, thus it requires only
|
||||
# one element from original buffer, and shape to "crop".
|
||||
tensor = Tensor(buffer[_start:(_start + 1)], one_shape)
|
||||
tensors.append(tensor)
|
||||
req.set_input_tensors(tensors)
|
||||
|
||||
req.set_input_tensors(tensors) # using list overload!
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
req.get_tensor("tensor_input0")
|
||||
assert "get_tensor shall not be used together with batched set_tensors/set_input_tensors" in str(e.value)
|
||||
|
||||
actual_tensor = req.get_tensor("tensor_output0")
|
||||
actual = actual_tensor.data
|
||||
|
@ -9,6 +9,7 @@ import numpy as np
|
||||
|
||||
import openvino.runtime as ov
|
||||
from openvino.runtime import Tensor
|
||||
from openvino.helpers import pack_data, unpack_data
|
||||
|
||||
import pytest
|
||||
|
||||
@ -29,7 +30,9 @@ from ..conftest import read_image
|
||||
(ov.Type.i64, np.int64),
|
||||
(ov.Type.u64, np.uint64),
|
||||
(ov.Type.boolean, np.bool),
|
||||
# (ov.Type.u1, np.uint8),
|
||||
(ov.Type.u1, np.uint8),
|
||||
(ov.Type.u4, np.uint8),
|
||||
(ov.Type.i4, np.int8),
|
||||
])
|
||||
def test_init_with_ngraph(ov_type, numpy_dtype):
|
||||
ov_tensors = []
|
||||
@ -106,13 +109,13 @@ def test_init_with_numpy_shared_memory(ov_type, numpy_dtype):
|
||||
assert np.array_equal(ov_tensor.data, arr)
|
||||
assert ov_tensor.size == arr.size
|
||||
assert ov_tensor.byte_size == arr.nbytes
|
||||
assert tuple(ov_tensor.strides) == arr.strides
|
||||
|
||||
assert tuple(ov_tensor.get_shape()) == shape
|
||||
assert ov_tensor.get_element_type() == ov_type
|
||||
assert ov_tensor.data.dtype == numpy_dtype
|
||||
assert ov_tensor.data.shape == shape
|
||||
assert ov_tensor.get_size() == arr.size
|
||||
assert ov_tensor.get_byte_size() == arr.nbytes
|
||||
assert tuple(ov_tensor.get_strides()) == arr.strides
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ov_type, numpy_dtype", [
|
||||
@ -145,7 +148,7 @@ def test_init_with_numpy_copy_memory(ov_type, numpy_dtype):
|
||||
|
||||
|
||||
def test_init_with_numpy_fail():
|
||||
arr = read_image()
|
||||
arr = np.asfortranarray(read_image())
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
_ = Tensor(array=arr, shared_memory=True)
|
||||
assert "Tensor with shared memory must be C contiguous" in str(e.value)
|
||||
@ -175,7 +178,6 @@ def test_init_with_roi_tensor():
|
||||
(ov.Type.i64, np.int64),
|
||||
(ov.Type.u64, np.uint64),
|
||||
(ov.Type.boolean, np.bool),
|
||||
# (ov.Type.u1, np.uint8),
|
||||
])
|
||||
def test_write_to_buffer(ov_type, numpy_dtype):
|
||||
ov_tensor = Tensor(ov_type, ov.Shape([1, 3, 32, 32]))
|
||||
@ -198,7 +200,6 @@ def test_write_to_buffer(ov_type, numpy_dtype):
|
||||
(ov.Type.i64, np.int64),
|
||||
(ov.Type.u64, np.uint64),
|
||||
(ov.Type.boolean, np.bool),
|
||||
# (ov.Type.u1, np.uint8),
|
||||
])
|
||||
def test_set_shape(ov_type, numpy_dtype):
|
||||
shape = ov.Shape([1, 3, 32, 32])
|
||||
@ -268,3 +269,108 @@ def test_cannot_set_shape_incorrect_dims():
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
ov_tensor.shape = [3, 28, 28]
|
||||
assert "Dims and format are inconsistent" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ov_type", [
|
||||
(ov.Type.u1),
|
||||
(ov.Type.u4),
|
||||
(ov.Type.i4),
|
||||
])
|
||||
def test_cannot_create_roi_from_packed_tensor(ov_type):
|
||||
ov_tensor = Tensor(ov_type, [1, 3, 48, 48])
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
Tensor(ov_tensor, [0, 0, 24, 24], [1, 3, 48, 48])
|
||||
assert "ROI Tensor for types with bitwidths less then 8 bit is not implemented" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ov_type", [
|
||||
(ov.Type.u1),
|
||||
(ov.Type.u4),
|
||||
(ov.Type.i4),
|
||||
])
|
||||
def test_cannot_get_strides_for_packed_tensor(ov_type):
|
||||
ov_tensor = Tensor(ov_type, [1, 3, 48, 48])
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
ov_tensor.get_strides()
|
||||
assert "Could not get strides for types with bitwidths less then 8 bit." in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [
|
||||
(np.uint8),
|
||||
(np.int8),
|
||||
(np.uint16),
|
||||
(np.uint32),
|
||||
(np.uint64),
|
||||
])
|
||||
@pytest.mark.parametrize("ov_type", [
|
||||
(ov.Type.u1),
|
||||
(ov.Type.u4),
|
||||
(ov.Type.i4),
|
||||
])
|
||||
def test_init_with_packed_buffer(dtype, ov_type):
|
||||
shape = [1, 3, 32, 32]
|
||||
fit = np.dtype(dtype).itemsize * 8 / ov_type.bitwidth
|
||||
assert np.prod(shape) % fit == 0
|
||||
size = int(np.prod(shape) // fit)
|
||||
buffer = np.random.normal(size=size).astype(dtype)
|
||||
ov_tensor = Tensor(buffer, shape, ov_type)
|
||||
assert ov_tensor.data.nbytes == ov_tensor.byte_size
|
||||
assert np.array_equal(ov_tensor.data.view(dtype), buffer)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [
|
||||
([1, 3, 28, 28]),
|
||||
([1, 3, 27, 27]),
|
||||
])
|
||||
@pytest.mark.parametrize("low, high, ov_type, dtype", [
|
||||
(0, 2, ov.Type.u1, np.uint8),
|
||||
(0, 16, ov.Type.u4, np.uint8),
|
||||
(-8, 7, ov.Type.i4, np.int8),
|
||||
])
|
||||
def test_packing(shape, low, high, ov_type, dtype):
|
||||
ov_tensor = Tensor(ov_type, shape)
|
||||
data = np.random.uniform(low, high, shape).astype(dtype)
|
||||
packed_data = pack_data(data, ov_tensor.element_type)
|
||||
ov_tensor.data[:] = packed_data
|
||||
unpacked = unpack_data(ov_tensor.data, ov_tensor.element_type, ov_tensor.shape)
|
||||
assert np.array_equal(unpacked, data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [
|
||||
(np.uint8),
|
||||
(np.int8),
|
||||
(np.int16),
|
||||
(np.uint16),
|
||||
(np.int32),
|
||||
(np.uint32),
|
||||
(np.int64),
|
||||
(np.uint64),
|
||||
(np.float16),
|
||||
(np.float32),
|
||||
(np.float64),
|
||||
])
|
||||
@pytest.mark.parametrize("element_type", [
|
||||
(ov.Type.u8),
|
||||
(ov.Type.i8),
|
||||
(ov.Type.i16),
|
||||
(ov.Type.u16),
|
||||
(ov.Type.i32),
|
||||
(ov.Type.u32),
|
||||
(ov.Type.i64),
|
||||
(ov.Type.u64),
|
||||
# (ov.Type.f16),
|
||||
# (ov.Type.f32),
|
||||
# (ov.Type.f64),
|
||||
])
|
||||
def test_viewed_tensor(dtype, element_type):
|
||||
buffer = np.random.normal(size=(2, 16)).astype(dtype)
|
||||
fit = (dtype().nbytes * 8) / element_type.bitwidth
|
||||
t = Tensor(buffer, (buffer.shape[0], int(buffer.shape[1] * fit)), element_type)
|
||||
assert np.array_equal(t.data, buffer.view(ov.utils.types.get_dtype(element_type)))
|
||||
|
||||
|
||||
def test_viewed_tensor_default_type():
|
||||
buffer = np.random.normal(size=(2, 16))
|
||||
new_shape = (4, 8)
|
||||
t = Tensor(buffer, new_shape)
|
||||
assert np.array_equal(t.data, buffer.reshape(new_shape))
|
||||
|
Loading…
Reference in New Issue
Block a user