From 6eaa15745a6404c4ad4a869261bd99d8ebfc6645 Mon Sep 17 00:00:00 2001
From: Alexey Lebedev
Date: Fri, 1 Apr 2022 12:04:04 +0300
Subject: [PATCH] [PYTHON API] Tensor.data property for low precisions + packing (#11131)

* rebase old branch with master
* Fix doc style
* fix test
* update tests
* Add missed param
* Rewrite docstring for tensor and refactor set_input_tensors test
* update python exclusives
* keep compatibility
* remove notes about slices
* fix code style
* Fix code style
---
 docs/OV_Runtime_UG/Python_API_exclusives.md  |  14 ++-
 docs/snippets/ov_python_exclusives.py        |  28 +++--
 .../python/src/openvino/helpers/__init__.py  |   6 +
 .../python/src/openvino/helpers/packing.py   |  91 ++++++++++++++
 .../python/src/pyopenvino/core/common.cpp    |   8 +-
 .../python/src/pyopenvino/core/common.hpp    |   2 +-
 .../python/src/pyopenvino/core/tensor.cpp    |  77 ++++++++----
 .../pyopenvino/graph/types/element_type.cpp  |   2 +
 .../test_infer_request.py                    |  32 ++---
 .../test_inference_engine/test_tensor.py     | 118 +++++++++++++++++-
 10 files changed, 314 insertions(+), 64 deletions(-)
 create mode 100644 src/bindings/python/src/openvino/helpers/__init__.py
 create mode 100644 src/bindings/python/src/openvino/helpers/packing.py

diff --git a/docs/OV_Runtime_UG/Python_API_exclusives.md b/docs/OV_Runtime_UG/Python_API_exclusives.md
index 21e134ee6e4..203b2a6fe19 100644
--- a/docs/OV_Runtime_UG/Python_API_exclusives.md
+++ b/docs/OV_Runtime_UG/Python_API_exclusives.md
@@ -28,12 +28,6 @@ Python API allows passing data as tensors. `Tensor` object holds a copy of the d
 
 @snippet docs/snippets/ov_python_exclusives.py tensor_shared_mode
 
-### Slices of array's memory
-
-One of the `Tensor` class constructors allows to share the slice of array's memory. When `shape` is specified in the constructor that has the numpy array as first argument, it triggers the special shared memory mode.
-
-@snippet docs/snippets/ov_python_exclusives.py tensor_slice_mode
-
 ## Running inference
 
 Python API supports extra calling methods to synchronous and asynchronous modes for inference.
@@ -76,6 +70,14 @@ The callback of `AsyncInferQueue` is uniform for every job. When executed, GIL i
 
 @snippet docs/snippets/ov_python_exclusives.py asyncinferqueue_set_callback
 
+### Working with u1, u4 and i4 element types
+
+Since OpenVINO supports low-precision element types, there are a few ways to handle them in Python.
+To create an input tensor with such an element type, you may need to pack your data into a new numpy array whose byte size matches the original input size:
+@snippet docs/snippets/ov_python_exclusives.py packing_data
+
+To extract low-precision values from a tensor into a numpy array, you can use the following helper:
+@snippet docs/snippets/ov_python_exclusives.py unpacking
 
 ### Releasing the GIL
 
diff --git a/docs/snippets/ov_python_exclusives.py b/docs/snippets/ov_python_exclusives.py
index 6481ba6d5e9..1af21080ce8 100644
--- a/docs/snippets/ov_python_exclusives.py
+++ b/docs/snippets/ov_python_exclusives.py
@@ -50,17 +50,6 @@ shared_tensor.data[0][2] = 0.6
 assert data_to_share[0][2] == 0.6
 #! [tensor_shared_mode]
 
-#! [tensor_slice_mode]
-data_to_share = np.ones(shape=(2,8))
-
-# Specify slice of memory and the shape
-shared_tensor = ov.Tensor(data_to_share[1][:] , shape=ov.Shape([8]))
-
-# Editing of the numpy array affects Tensor's data
-data_to_share[1][:] = 2
-assert np.array_equal(shared_tensor.data, data_to_share[1][:])
-#! [tensor_slice_mode]
-
 
 infer_request = compiled.create_infer_request()
 data = np.random.randint(-5, 3 + 1, size=(8))
 
@@ -132,6 +121,23 @@ infer_queue.wait_all()
 assert all(data_done)
 #! [asyncinferqueue_set_callback]
 
+uint8_data = np.ones([1, 128], dtype=np.uint8)
+
+#! [packing_data]
+from openvino.helpers import pack_data
+
+packed_buffer = pack_data(uint8_data, ov.Type.u4)
+# Create a tensor whose shape is given in u4 elements
+t = ov.Tensor(packed_buffer, [1, 128], ov.Type.u4)
+#! [packing_data]
+
+#! [unpacking]
+from openvino.helpers import unpack_data
+
+unpacked_data = unpack_data(t.data, t.element_type, t.shape)
+assert np.array_equal(unpacked_data, uint8_data)
+#! [unpacking]
+
 #! [releasing_gil]
 import openvino.runtime as ov
 import cv2 as cv
diff --git a/src/bindings/python/src/openvino/helpers/__init__.py b/src/bindings/python/src/openvino/helpers/__init__.py
new file mode 100644
index 00000000000..69500b97be1
--- /dev/null
+++ b/src/bindings/python/src/openvino/helpers/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# flake8: noqa
+
+from openvino.helpers.packing import pack_data, unpack_data
diff --git a/src/bindings/python/src/openvino/helpers/packing.py b/src/bindings/python/src/openvino/helpers/packing.py
new file mode 100644
index 00000000000..bd86e1c9d44
--- /dev/null
+++ b/src/bindings/python/src/openvino/helpers/packing.py
@@ -0,0 +1,91 @@
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# flake8: noqa
+
+import numpy as np
+from typing import Union
+from openvino.runtime import Type, Shape
+
+
+def pack_data(array: np.ndarray, type: Type) -> np.ndarray:
+    """Represent array values as the u1, u4 or i4 openvino element type and pack them into a uint8 numpy array.
+
+    If the number of values does not fill a whole number of bytes, the bit sequence is padded with
+    zeros so that it fits into the uint8 array.
+
+    Example: two uint8 values [7, 8] can be represented as u4 values and packed into one uint8
+    value [120], because the bit representation of [7, 8] is [0111, 1000], which is viewed
+    as [01111000], the bit representation of [120].
+
+    :param array: numpy array with values to pack.
+    :type array: numpy array
+    :param type: Type to interpret the array values. Type must be u1, u4 or i4.
+    :type type: openvino.runtime.Type
+    """
+    assert type in [Type.u1, Type.u4, Type.i4], "Packing algorithm is supported" \
+                                                " only for u1, u4 and i4 element types"
+
+    minimum_regular_dtype = np.int8 if type == Type.i4 else np.uint8
+    casted_to_regular_type = array.astype(dtype=minimum_regular_dtype, casting="unsafe")
+    if not np.array_equal(casted_to_regular_type, array):
+        raise RuntimeError(f'The conversion of array "{array}" to dtype'
+                           f' "{minimum_regular_dtype}" results in rounding')
+
+    data_size = casted_to_regular_type.size
+    num_bits = type.bitwidth
+
+    assert num_bits < 8 and 8 % num_bits == 0, "Packing algorithm is supported" \
+                                               " only for types stored in 1, 2 or 4 bits"
+    num_values_fitting_into_uint8 = 8 // num_bits
+    pad = (-data_size) % num_values_fitting_into_uint8
+
+    flattened = casted_to_regular_type.flatten()
+    padded = np.concatenate((flattened, np.zeros([pad], dtype=minimum_regular_dtype)))
+    assert padded.size % num_values_fitting_into_uint8 == 0
+
+    bit_order_little = (padded[:, None] & (1 << np.arange(num_bits)) > 0).astype(minimum_regular_dtype)
+    bit_order_big = np.flip(bit_order_little, axis=1)
+    bit_order_big_flattened = bit_order_big.flatten()
+
+    return np.packbits(bit_order_big_flattened)
+
+
+def unpack_data(array: np.ndarray, type: Type, shape: Union[list, Shape]) -> np.ndarray:
+    """Extract openvino element type values from the array into a new uint8/int8 array with the given shape.
+
+    Example: one uint8 value [120] can be represented as two u4 values and unpacked into [7, 8],
+    because the bit representation of [120] is [01111000], which is viewed as [0111, 1000],
+    the bit representation of [7, 8].
+
+    :param array: numpy array to unpack.
+    :type array: numpy array
+    :param type: Type to extract from array values. Type must be u1, u4 or i4.
+    :type type: openvino.runtime.Type
+    :param shape: The new shape for the unpacked array.
+    :type shape: Union[list, openvino.runtime.Shape]
+    """
+    assert type in [Type.u1, Type.u4, Type.i4], "Unpacking algorithm is supported" \
+                                                " only for u1, u4 and i4 element types"
+    unpacked = np.unpackbits(array.view(np.uint8))
+    shape = list(shape)
+    if type.bitwidth == 1:
+        return np.resize(unpacked, shape)
+    else:
+        unpacked = unpacked.reshape(-1, type.bitwidth)
+        padding_shape = (unpacked.shape[0], 8 - type.bitwidth)
+        padding = np.ndarray(padding_shape, np.uint8)
+        if type == Type.i4:
+            for axis, bits in enumerate(unpacked):
+                if bits[0] == 1:
+                    padding[axis] = np.ones((padding_shape[1],), np.uint8)
+                else:
+                    padding[axis] = np.zeros((padding_shape[1],), np.uint8)
+        else:
+            padding = np.zeros(padding_shape, np.uint8)
+        padded = np.concatenate((padding, unpacked), 1)
+        packed = np.packbits(padded, 1)
+        if type == Type.i4:
+            return np.resize(packed, shape).astype(dtype=np.int8)
+        else:
+            return np.resize(packed, shape)
diff --git a/src/bindings/python/src/pyopenvino/core/common.cpp b/src/bindings/python/src/pyopenvino/core/common.cpp
index 3540d9fef44..8c103ba3cee 100644
--- a/src/bindings/python/src/pyopenvino/core/common.cpp
+++ b/src/bindings/python/src/pyopenvino/core/common.cpp
@@ -27,6 +27,8 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype() {
         {ov::element::u64, py::dtype("uint64")},
         {ov::element::boolean, py::dtype("bool")},
         {ov::element::u1, py::dtype("uint8")},
+        {ov::element::u4, py::dtype("uint8")},
+        {ov::element::i4, py::dtype("int8")},
     };
     return ov_type_to_dtype_mapping;
 }
@@ -49,12 +51,12 @@ const std::map<std::string, ov::element::Type>& dtype_to_ov_type() {
     return dtype_to_ov_type_mapping;
 }
 
-ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape) {
+ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const ov::element::Type& type) {
     bool is_contiguous = C_CONTIGUOUS == (array.flags() & C_CONTIGUOUS);
-    auto type = Common::dtype_to_ov_type().at(py::str(array.dtype()));
+    auto element_type = (type == ov::element::undefined) ? Common::dtype_to_ov_type().at(py::str(array.dtype())) : type;
     if (is_contiguous) {
-        return ov::Tensor(type, shape, const_cast<void*>(array.data(0)), {});
+        return ov::Tensor(element_type, shape, const_cast<void*>(array.data(0)), {});
     } else {
         throw ov::Exception("Tensor with shared memory must be C contiguous!");
     }
diff --git a/src/bindings/python/src/pyopenvino/core/common.hpp b/src/bindings/python/src/pyopenvino/core/common.hpp
index 36c0326fd12..3db05c5b4a8 100644
--- a/src/bindings/python/src/pyopenvino/core/common.hpp
+++ b/src/bindings/python/src/pyopenvino/core/common.hpp
@@ -28,7 +28,7 @@ const std::map<ov::element::Type, py::dtype>& ov_type_to_dtype();
 
 const std::map<std::string, ov::element::Type>& dtype_to_ov_type();
 
-ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape);
+ov::Tensor tensor_from_pointer(py::array& array, const ov::Shape& shape, const ov::element::Type& ov_type);
 
 ov::Tensor tensor_from_numpy(py::array& array, bool shared_memory);
 
diff --git a/src/bindings/python/src/pyopenvino/core/tensor.cpp b/src/bindings/python/src/pyopenvino/core/tensor.cpp
index 7d3d77098f4..ce3586c85da 100644
--- a/src/bindings/python/src/pyopenvino/core/tensor.cpp
+++ b/src/bindings/python/src/pyopenvino/core/tensor.cpp
@@ -35,28 +35,30 @@ void regclass_Tensor(py::module m) {
                 :type shared_memory: bool
             )");
 
-    cls.def(py::init([](py::array& array, const ov::Shape& shape) {
-                return Common::tensor_from_pointer(array, shape);
+    cls.def(py::init([](py::array& array, const ov::Shape& shape, const ov::element::Type& ov_type) {
+                return Common::tensor_from_pointer(array, shape, ov_type);
             }),
            py::arg("array"),
            py::arg("shape"),
+           py::arg("type") = ov::element::undefined,
            R"(
                 Another Tensor's special constructor.
 
-                It takes an array or slice of it, and shape that will be
-                selected, starting from the first element of the given array/slice.
-                Please use it only in advanced cases if necessary!
+                Wraps an existing array in memory with the given shape and element type.
+                It is recommended to use this constructor only when the array's memory
+                needs to be interpreted with a specific openvino element type.
 
-                :param array: Underlaying methods will retrieve pointer on first element
-                              from it, which is simulating `host_ptr` from C++ API.
-                              Tensor memory is being shared with a host,
-                              that means the responsibility of keeping host memory is
+                :param array: C_CONTIGUOUS numpy array which will be wrapped in
+                              openvino.runtime.Tensor with the given shape and
+                              element type. The array's memory is shared with the
+                              host, which means that the responsibility of keeping the host memory is
                               on the side of a user. Any action performed on the host
                               memory will be reflected on this Tensor's memory!
-                              Data is required to be C_CONTIGUOUS.
                 :type array: numpy.array
                 :param shape: Shape of the new tensor.
                 :type shape: openvino.runtime.Shape
+                :param type: Element type.
+                :type type: openvino.runtime.Type
 
                 :Example:
                 .. code-block:: python
@@ -64,15 +66,43 @@ void regclass_Tensor(py::module m) {
                     import openvino.runtime as ov
                     import numpy as np
 
-                    arr = np.array([[1, 2, 3], [4, 5, 6]])
+                    arr = np.zeros(shape=(100,), dtype=np.uint8)
+                    t = ov.Tensor(arr, ov.Shape([100, 8]), ov.Type.u1)
+            )");
 
-                    t = ov.Tensor(arr[1][0:1], ov.Shape([3]))
+    cls.def(py::init([](py::array& array, const std::vector<size_t> shape, const ov::element::Type& ov_type) {
+                return Common::tensor_from_pointer(array, shape, ov_type);
+            }),
+           py::arg("array"),
+           py::arg("shape"),
+           py::arg("type") = ov::element::undefined,
+           R"(
+                Another Tensor's special constructor.
 
-                    t.data[0] = 9
+                Wraps an existing array in memory with the given shape and element type.
+                It is recommended to use this constructor only when the array's memory
+                needs to be interpreted with a specific openvino element type.
 
-                    print(arr)
-                    >>> [[1 2 3]
-                    >>>  [9 5 6]]
+                :param array: C_CONTIGUOUS numpy array which will be wrapped in
+                              openvino.runtime.Tensor with the given shape and
+                              element type. The array's memory is shared with the
+                              host, which means that the responsibility of keeping the host memory is
+                              on the side of a user. Any action performed on the host
+                              memory will be reflected on this Tensor's memory!
+                :type array: numpy.array
+                :param shape: Shape of the new tensor.
+                :type shape: list or tuple
+                :param type: Element type.
+                :type type: openvino.runtime.Type
+
+                :Example:
+                .. code-block:: python
+
+                    import openvino.runtime as ov
+                    import numpy as np
+
+                    arr = np.zeros(shape=(100,), dtype=np.uint8)
+                    t = ov.Tensor(arr, [100, 8], ov.Type.u1)
             )");
 
     cls.def(py::init<const ov::element::Type&, const ov::Shape&>(), py::arg("type"), py::arg("shape"));
@@ -177,15 +207,20 @@ void regclass_Tensor(py::module m) {
     cls.def_property_readonly(
         "data",
         [](ov::Tensor& self) {
-            return py::array(Common::ov_type_to_dtype().at(self.get_element_type()),
-                             self.get_shape(),
-                             self.get_strides(),
-                             self.data(),
-                             py::cast(self));
+            auto ov_type = self.get_element_type();
+            auto dtype = Common::ov_type_to_dtype().at(ov_type);
+            if (ov_type.bitwidth() < 8) {
+                return py::array(dtype, self.get_byte_size(), self.data(), py::cast(self));
+            }
+            return py::array(dtype, self.get_shape(), self.get_strides(), self.data(), py::cast(self));
         },
         R"(
             Access to Tensor's data.
 
+            Returns a numpy array with the corresponding shape and dtype.
+            For tensors with openvino-specific element types, such as u1, u4 or i4,
+            it returns a linear array with uint8 / int8 numpy dtype.
+
             :rtype: numpy.array
         )");
 
diff --git a/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp b/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp
index b41d5c000d5..f295d865d99 100644
--- a/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp
+++ b/src/bindings/python/src/pyopenvino/graph/types/element_type.cpp
@@ -19,11 +19,13 @@ void regclass_graph_Type(py::module m) {
     type.attr("f16") = ov::element::f16;
     type.attr("f32") = ov::element::f32;
     type.attr("f64") = ov::element::f64;
+    type.attr("i4") = ov::element::i4;
     type.attr("i8") = ov::element::i8;
     type.attr("i16") = ov::element::i16;
     type.attr("i32") = ov::element::i32;
     type.attr("i64") = ov::element::i64;
     type.attr("u1") = ov::element::u1;
+    type.attr("u4") = ov::element::u4;
     type.attr("u8") = ov::element::u8;
     type.attr("u16") = ov::element::u16;
     type.attr("u32") = ov::element::u32;
diff --git a/src/bindings/python/tests/test_inference_engine/test_infer_request.py b/src/bindings/python/tests/test_inference_engine/test_infer_request.py
index 4119637c04e..746d1b17100 100644
--- a/src/bindings/python/tests/test_inference_engine/test_infer_request.py
+++ b/src/bindings/python/tests/test_inference_engine/test_infer_request.py
@@ -164,15 +164,15 @@ def test_set_tensors(device):
 @pytest.mark.dynamic_library
 @pytest.mark.template_extension
 def test_batched_tensors(device):
-    batch = 4
-    one_shape = Shape([1, 2, 2, 2])
-    batch_shape = Shape([batch, 2, 2, 2])
-    one_shape_size = np.prod(one_shape)
-    core = Core()
-
+    core = Core()
     # TODO: remove when plugins will support set_input_tensors
     core.register_plugin("openvino_template_plugin", "TEMPLATE")
+    batch = 4
+    one_shape = [1, 2, 2, 2]
+    one_shape_size = np.prod(one_shape)
+    batch_shape = [batch, 2, 2, 2]
+
     data1 = ops.parameter(batch_shape, np.float32)
data1.set_friendly_name("input0") data1.get_output_tensor(0).set_names({"tensor_input0"}) @@ -191,21 +191,21 @@ def test_batched_tensors(device): compiled = core.compile_model(model, "TEMPLATE") - buffer = np.zeros([one_shape_size * batch * 2], dtype=np.float32) - req = compiled.create_infer_request() + # Allocate 8 chunks, set 'user tensors' to 0, 2, 4, 6 chunks + buffer = np.zeros([batch * 2, *batch_shape[1:]], dtype=np.float32) + tensors = [] + for i in range(batch): + # non contiguous memory (i*2) + tensors.append(Tensor(np.expand_dims(buffer[i * 2], 0), shared_memory=True)) - for i in range(0, batch): - _start = i * one_shape_size * 2 - # Use of special constructor for Tensor. - # It creates a Tensor from pointer, thus it requires only - # one element from original buffer, and shape to "crop". - tensor = Tensor(buffer[_start:(_start + 1)], one_shape) - tensors.append(tensor) + req.set_input_tensors(tensors) - req.set_input_tensors(tensors) # using list overload! + with pytest.raises(RuntimeError) as e: + req.get_tensor("tensor_input0") + assert "get_tensor shall not be used together with batched set_tensors/set_input_tensors" in str(e.value) actual_tensor = req.get_tensor("tensor_output0") actual = actual_tensor.data diff --git a/src/bindings/python/tests/test_inference_engine/test_tensor.py b/src/bindings/python/tests/test_inference_engine/test_tensor.py index b37b7e213a2..b4deae12e6d 100644 --- a/src/bindings/python/tests/test_inference_engine/test_tensor.py +++ b/src/bindings/python/tests/test_inference_engine/test_tensor.py @@ -9,6 +9,7 @@ import numpy as np import openvino.runtime as ov from openvino.runtime import Tensor +from openvino.helpers import pack_data, unpack_data import pytest @@ -29,7 +30,9 @@ from ..conftest import read_image (ov.Type.i64, np.int64), (ov.Type.u64, np.uint64), (ov.Type.boolean, np.bool), - # (ov.Type.u1, np.uint8), + (ov.Type.u1, np.uint8), + (ov.Type.u4, np.uint8), + (ov.Type.i4, np.int8), ]) def test_init_with_ngraph(ov_type, numpy_dtype): ov_tensors = [] @@ -106,13 +109,13 @@ def test_init_with_numpy_shared_memory(ov_type, numpy_dtype): assert np.array_equal(ov_tensor.data, arr) assert ov_tensor.size == arr.size assert ov_tensor.byte_size == arr.nbytes + assert tuple(ov_tensor.strides) == arr.strides assert tuple(ov_tensor.get_shape()) == shape assert ov_tensor.get_element_type() == ov_type - assert ov_tensor.data.dtype == numpy_dtype - assert ov_tensor.data.shape == shape assert ov_tensor.get_size() == arr.size assert ov_tensor.get_byte_size() == arr.nbytes + assert tuple(ov_tensor.get_strides()) == arr.strides @pytest.mark.parametrize("ov_type, numpy_dtype", [ @@ -145,7 +148,7 @@ def test_init_with_numpy_copy_memory(ov_type, numpy_dtype): def test_init_with_numpy_fail(): - arr = read_image() + arr = np.asfortranarray(read_image()) with pytest.raises(RuntimeError) as e: _ = Tensor(array=arr, shared_memory=True) assert "Tensor with shared memory must be C contiguous" in str(e.value) @@ -175,7 +178,6 @@ def test_init_with_roi_tensor(): (ov.Type.i64, np.int64), (ov.Type.u64, np.uint64), (ov.Type.boolean, np.bool), - # (ov.Type.u1, np.uint8), ]) def test_write_to_buffer(ov_type, numpy_dtype): ov_tensor = Tensor(ov_type, ov.Shape([1, 3, 32, 32])) @@ -198,7 +200,6 @@ def test_write_to_buffer(ov_type, numpy_dtype): (ov.Type.i64, np.int64), (ov.Type.u64, np.uint64), (ov.Type.boolean, np.bool), - # (ov.Type.u1, np.uint8), ]) def test_set_shape(ov_type, numpy_dtype): shape = ov.Shape([1, 3, 32, 32]) @@ -268,3 +269,108 @@ def 
test_cannot_set_shape_incorrect_dims(): with pytest.raises(RuntimeError) as e: ov_tensor.shape = [3, 28, 28] assert "Dims and format are inconsistent" in str(e.value) + + +@pytest.mark.parametrize("ov_type", [ + (ov.Type.u1), + (ov.Type.u4), + (ov.Type.i4), +]) +def test_cannot_create_roi_from_packed_tensor(ov_type): + ov_tensor = Tensor(ov_type, [1, 3, 48, 48]) + with pytest.raises(RuntimeError) as e: + Tensor(ov_tensor, [0, 0, 24, 24], [1, 3, 48, 48]) + assert "ROI Tensor for types with bitwidths less then 8 bit is not implemented" in str(e.value) + + +@pytest.mark.parametrize("ov_type", [ + (ov.Type.u1), + (ov.Type.u4), + (ov.Type.i4), +]) +def test_cannot_get_strides_for_packed_tensor(ov_type): + ov_tensor = Tensor(ov_type, [1, 3, 48, 48]) + with pytest.raises(RuntimeError) as e: + ov_tensor.get_strides() + assert "Could not get strides for types with bitwidths less then 8 bit." in str(e.value) + + +@pytest.mark.parametrize("dtype", [ + (np.uint8), + (np.int8), + (np.uint16), + (np.uint32), + (np.uint64), +]) +@pytest.mark.parametrize("ov_type", [ + (ov.Type.u1), + (ov.Type.u4), + (ov.Type.i4), +]) +def test_init_with_packed_buffer(dtype, ov_type): + shape = [1, 3, 32, 32] + fit = np.dtype(dtype).itemsize * 8 / ov_type.bitwidth + assert np.prod(shape) % fit == 0 + size = int(np.prod(shape) // fit) + buffer = np.random.normal(size=size).astype(dtype) + ov_tensor = Tensor(buffer, shape, ov_type) + assert ov_tensor.data.nbytes == ov_tensor.byte_size + assert np.array_equal(ov_tensor.data.view(dtype), buffer) + + +@pytest.mark.parametrize("shape", [ + ([1, 3, 28, 28]), + ([1, 3, 27, 27]), +]) +@pytest.mark.parametrize("low, high, ov_type, dtype", [ + (0, 2, ov.Type.u1, np.uint8), + (0, 16, ov.Type.u4, np.uint8), + (-8, 7, ov.Type.i4, np.int8), +]) +def test_packing(shape, low, high, ov_type, dtype): + ov_tensor = Tensor(ov_type, shape) + data = np.random.uniform(low, high, shape).astype(dtype) + packed_data = pack_data(data, ov_tensor.element_type) + ov_tensor.data[:] = packed_data + unpacked = unpack_data(ov_tensor.data, ov_tensor.element_type, ov_tensor.shape) + assert np.array_equal(unpacked, data) + + +@pytest.mark.parametrize("dtype", [ + (np.uint8), + (np.int8), + (np.int16), + (np.uint16), + (np.int32), + (np.uint32), + (np.int64), + (np.uint64), + (np.float16), + (np.float32), + (np.float64), +]) +@pytest.mark.parametrize("element_type", [ + (ov.Type.u8), + (ov.Type.i8), + (ov.Type.i16), + (ov.Type.u16), + (ov.Type.i32), + (ov.Type.u32), + (ov.Type.i64), + (ov.Type.u64), + # (ov.Type.f16), + # (ov.Type.f32), + # (ov.Type.f64), +]) +def test_viewed_tensor(dtype, element_type): + buffer = np.random.normal(size=(2, 16)).astype(dtype) + fit = (dtype().nbytes * 8) / element_type.bitwidth + t = Tensor(buffer, (buffer.shape[0], int(buffer.shape[1] * fit)), element_type) + assert np.array_equal(t.data, buffer.view(ov.utils.types.get_dtype(element_type))) + + +def test_viewed_tensor_default_type(): + buffer = np.random.normal(size=(2, 16)) + new_shape = (4, 8) + t = Tensor(buffer, new_shape) + assert np.array_equal(t.data, buffer.reshape(new_shape))
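
A minimal round-trip sketch of the packing helpers and the extended Tensor constructor introduced by this patch; it is not part of the patch itself and assumes an environment where the patch is applied:

    import numpy as np
    import openvino.runtime as ov
    from openvino.helpers import pack_data, unpack_data

    # Values in the u4 range [0, 15]; two of them fit into one byte after packing.
    values = np.random.randint(0, 16, size=(1, 128)).astype(np.uint8)

    # Pack the 4-bit values into a uint8 buffer.
    packed = pack_data(values, ov.Type.u4)
    assert packed.nbytes == values.size // 2

    # Wrap the packed buffer as a u4 tensor; the shape is given in u4 elements.
    tensor = ov.Tensor(packed, values.shape, ov.Type.u4)

    # For low-precision types, Tensor.data is a flat uint8 view over the packed bytes.
    assert tensor.data.dtype == np.uint8
    assert tensor.data.nbytes == tensor.byte_size

    # Unpack back to one value per element and compare with the original data.
    restored = unpack_data(tensor.data, tensor.element_type, tensor.shape)
    assert np.array_equal(restored, values)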
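
The worked example from the pack_data/unpack_data docstrings, checked explicitly (again a sketch under the same assumption, not part of the patch): two u4 values [7, 8] pack into the single byte 0b01111000 (120), and i4 values are sign-extended when unpacked:

    import numpy as np
    import openvino.runtime as ov
    from openvino.helpers import pack_data, unpack_data

    # u4: 7 -> 0111, 8 -> 1000, concatenated MSB-first -> 01111000 == 120.
    packed_u4 = pack_data(np.array([7, 8], dtype=np.uint8), ov.Type.u4)
    assert packed_u4.tolist() == [120]
    assert unpack_data(packed_u4, ov.Type.u4, [2]).tolist() == [7, 8]

    # i4: the top bit of each nibble is sign-extended, so -1 (1111) comes back as -1, not 15.
    packed_i4 = pack_data(np.array([-1, 7], dtype=np.int8), ov.Type.i4)
    assert unpack_data(packed_i4, ov.Type.i4, [2]).tolist() == [-1, 7]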