From 59052d846e421608dca8ab0871cade46d6cce0e0 Mon Sep 17 00:00:00 2001
From: Jan Iwaszkiewicz
Date: Tue, 11 Jul 2023 19:05:22 +0200
Subject: [PATCH] [PyOV] Memory flow control with share_inputs and
 share_outputs (#18275)

* Added ReturnPolicy and updated common array helpers

* Clean up

* Remove ReturnPolicy initial

* Add share_inputs and share_outputs

* Tests and minor fixes

* Fix docstrings

* Fix whitespace

* Fix typing
---
 .../python/src/openvino/runtime/ie_api.py     | 135 ++++++++--
 .../python/src/pyopenvino/core/common.cpp     |  75 ++----
 .../python/src/pyopenvino/core/common.hpp     |   4 +-
 .../src/pyopenvino/core/infer_request.cpp     |  16 +-
 .../python/src/pyopenvino/core/tensor.cpp     |   7 +-
 .../tests/test_runtime/test_infer_request.py  | 246 ++++++++++++------
 6 files changed, 305 insertions(+), 178 deletions(-)

diff --git a/src/bindings/python/src/openvino/runtime/ie_api.py b/src/bindings/python/src/openvino/runtime/ie_api.py
index f634617c733..8ad7b9cb2c8 100644
--- a/src/bindings/python/src/openvino/runtime/ie_api.py
+++ b/src/bindings/python/src/openvino/runtime/ie_api.py
@@ -4,6 +4,7 @@
 
 from typing import Any, Iterable, Union, Optional
 from pathlib import Path
+import warnings
 
 import numpy as np
 
@@ -22,6 +23,19 @@ from openvino.runtime.utils.data_helpers import (
 )
 
 
+def _deprecated_memory_arg(shared_memory: bool, share_inputs: bool) -> bool:
+    if shared_memory is not None:
+        warnings.warn(
+            "`shared_memory` is deprecated and will be removed in 2024.0. "
+            "Value of `shared_memory` is going to override `share_inputs` value. "
+            "Please use only `share_inputs` explicitly.",
+            FutureWarning,
+            stacklevel=3,
+        )
+        return shared_memory
+    return share_inputs
+
+
 class Model(ModelBase):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         if args and not kwargs:
@@ -40,7 +54,14 @@ class Model(ModelBase):
 class InferRequest(_InferRequestWrapper):
     """InferRequest class represents infer request which can be run in asynchronous or synchronous manners."""
 
-    def infer(self, inputs: Any = None, shared_memory: bool = False) -> OVDict:
+    def infer(
+        self,
+        inputs: Any = None,
+        share_inputs: bool = False,
+        share_outputs: bool = False,
+        *,
+        shared_memory: Any = None,
+    ) -> OVDict:
         """Infers specified input(s) in synchronous mode.
 
         Blocks all methods of InferRequest while request is running.
@@ -63,7 +84,7 @@ class InferRequest(_InferRequestWrapper):
 
         :param inputs: Data to be set on input tensors.
         :type inputs: Any, optional
-        :param shared_memory: Enables `shared_memory` mode.
+        :param share_inputs: Enables `share_inputs` mode.
 
             Controls memory usage on inference's inputs.
 
            If set to `False` inputs the data dispatcher will safely copy data
            to existing Tensors (including up- or down-casting according to data type,
@@ -78,10 +99,32 @@ class InferRequest(_InferRequestWrapper):
             * inputs that should be in `BF16` data type
             * scalar inputs (i.e. `np.float_`/`int`/`float`)
             Keeps Tensor inputs "as-is".
+
             Note: Use with extra care, shared data can be modified during runtime!
-            Note: Using `shared_memory` may result in the extra memory overhead.
+            Note: Using `share_inputs` may result in extra memory overhead.
 
             Default value: False
+        :type share_inputs: bool, optional
+        :param share_outputs: Enables `share_outputs` mode. Controls memory usage on inference's outputs.
+
+            If set to `False` outputs will safely copy data to numpy arrays.
+
+            If set to `True` the data will be returned in the form of views of output Tensors.
+            This mode still returns the data in the format of numpy arrays,
+            but the lifetime of the data is connected to OpenVINO objects.
+
+            Note: Use with extra care, shared data can be modified or lost during runtime!
+
+            Default value: False
+        :type share_outputs: bool, optional
+        :param shared_memory: Deprecated. Works like `share_inputs` mode.
+
+            If not specified, the function uses the `share_inputs` value.
+
+            Note: Will be removed in the 2024.0 release!
+            Note: This is a keyword-only argument.
+
+            Default value: None
         :type shared_memory: bool, optional
         :return: Dictionary of results from output tensors with port/int/str keys.
         :rtype: OVDict
@@ -89,14 +132,16 @@ class InferRequest(_InferRequestWrapper):
         return OVDict(super().infer(_data_dispatch(
             self,
             inputs,
-            is_shared=shared_memory,
-        )))
+            is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
+        ), share_outputs=share_outputs))
 
     def start_async(
         self,
         inputs: Any = None,
         userdata: Any = None,
-        shared_memory: bool = False,
+        share_inputs: bool = False,
+        *,
+        shared_memory: Any = None,
     ) -> None:
         """Starts inference of specified input(s) in asynchronous mode.
 
@@ -123,7 +168,7 @@ class InferRequest(_InferRequestWrapper):
         :type inputs: Any, optional
         :param userdata: Any data that will be passed inside the callback.
         :type userdata: Any
-        :param shared_memory: Enables `shared_memory` mode.
+        :param share_inputs: Enables `share_inputs` mode.
 
             Controls memory usage on inference's inputs.
 
            If set to `False` inputs the data dispatcher will safely copy data
            to existing Tensors (including up- or down-casting according to data type,
@@ -138,17 +183,27 @@ class InferRequest(_InferRequestWrapper):
             * inputs that should be in `BF16` data type
             * scalar inputs (i.e. `np.float_`/`int`/`float`)
             Keeps Tensor inputs "as-is".
+
             Note: Use with extra care, shared data can be modified during runtime!
-            Note: Using `shared_memory` may result in extra memory overhead.
+            Note: Using `share_inputs` may result in extra memory overhead.
 
             Default value: False
+        :type share_inputs: bool, optional
+        :param shared_memory: Deprecated. Works like `share_inputs` mode.
+
+            If not specified, the function uses the `share_inputs` value.
+
+            Note: Will be removed in the 2024.0 release!
+            Note: This is a keyword-only argument.
+
+            Default value: None
         :type shared_memory: bool, optional
         """
         super().start_async(
             _data_dispatch(
                 self,
                 inputs,
-                is_shared=shared_memory,
+                is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
             ),
             userdata,
         )
@@ -229,9 +284,14 @@ class CompiledModel(CompiledModelBase):
         # overloaded functions of InferRequest class
         return self.create_infer_request().infer(inputs)
 
-    def __call__(self,
-                 inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None,
-                 shared_memory: bool = True) -> OVDict:
+    def __call__(
+        self,
+        inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None,
+        share_inputs: bool = True,
+        share_outputs: bool = False,
+        *,
+        shared_memory: Any = None,
+    ) -> OVDict:
         """Callable infer wrapper for CompiledModel.
 
         Infers specified input(s) in synchronous mode.
@@ -262,7 +322,7 @@ class CompiledModel(CompiledModelBase):
 
         :param inputs: Data to be set on input tensors.
         :type inputs: Union[Dict[keys, values], List[values], Tuple[values], Tensor, numpy.ndarray], optional
-        :param shared_memory: Enables `shared_memory` mode.
+        :param share_inputs: Enables `share_inputs` mode.
 
            Controls memory usage on inference's inputs.
            If set to `False` inputs the data dispatcher will safely copy data
            to existing Tensors (including up- or down-casting according to data type,
@@ -277,12 +337,33 @@ class CompiledModel(CompiledModelBase):
             * inputs that should be in `BF16` data type
             * scalar inputs (i.e. `np.float_`/`int`/`float`)
             Keeps Tensor inputs "as-is".
+
             Note: Use with extra care, shared data can be modified during runtime!
-            Note: Using `shared_memory` may result in extra memory overhead.
+            Note: Using `share_inputs` may result in extra memory overhead.
 
             Default value: True
-        :type shared_memory: bool, optional
+        :type share_inputs: bool, optional
+        :param share_outputs: Enables `share_outputs` mode. Controls memory usage on inference's outputs.
+
+            If set to `False` outputs will safely copy data to numpy arrays.
+
+            If set to `True` the data will be returned in the form of views of output Tensors.
+            This mode still returns the data in the format of numpy arrays,
+            but the lifetime of the data is connected to OpenVINO objects.
+
+            Note: Use with extra care, shared data can be modified or lost during runtime!
+
+            Default value: False
+        :type share_outputs: bool, optional
+        :param shared_memory: Deprecated. Works like `share_inputs` mode.
+
+            If not specified, the function uses the `share_inputs` value.
+
+            Note: Will be removed in the 2024.0 release!
+            Note: This is a keyword-only argument.
+
+            Default value: None
+        :type shared_memory: bool, optional
         :return: Dictionary of results from output tensors with port/int/str as keys.
         :rtype: OVDict
         """
@@ -291,7 +372,8 @@ class CompiledModel(CompiledModelBase):
 
         return self._infer_request.infer(
             inputs,
-            shared_memory=shared_memory,
+            share_inputs=_deprecated_memory_arg(shared_memory, share_inputs),
+            share_outputs=share_outputs,
         )
 
 
@@ -333,7 +415,9 @@ class AsyncInferQueue(AsyncInferQueueBase):
         self,
         inputs: Any = None,
         userdata: Any = None,
-        shared_memory: bool = False,
+        share_inputs: bool = False,
+        *,
+        shared_memory: Any = None,
     ) -> None:
         """Run asynchronous inference using the next available InferRequest from the pool.
 
@@ -356,7 +440,7 @@ class AsyncInferQueue(AsyncInferQueueBase):
         :type inputs: Any, optional
         :param userdata: Any data that will be passed to a callback.
         :type userdata: Any, optional
-        :param shared_memory: Enables `shared_memory` mode.
+        :param share_inputs: Enables `share_inputs` mode.
 
             Controls memory usage on inference's inputs.
 
            If set to `False` inputs the data dispatcher will safely copy data
            to existing Tensors (including up- or down-casting according to data type,
@@ -371,16 +455,27 @@ class AsyncInferQueue(AsyncInferQueueBase):
             * inputs that should be in `BF16` data type
             * scalar inputs (i.e. `np.float_`/`int`/`float`)
             Keeps Tensor inputs "as-is".
+
             Note: Use with extra care, shared data can be modified during runtime!
-            Note: Using `shared_memory` may result in extra memory overhead.
+            Note: Using `share_inputs` may result in extra memory overhead.
 
             Default value: False
+        :type share_inputs: bool, optional
+        :param shared_memory: Deprecated. Works like `share_inputs` mode.
+
+            If not specified, the function uses the `share_inputs` value.
+
+            Note: Will be removed in the 2024.0 release!
+            Note: This is a keyword-only argument.
+
+            Default value: None
+        :type shared_memory: bool, optional
         """
         super().start_async(
             _data_dispatch(
                 self[self.get_idle_request_id()],
                 inputs,
-                is_shared=shared_memory,
+                is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
             ),
             userdata,
         )
diff --git a/src/bindings/python/src/pyopenvino/core/common.cpp b/src/bindings/python/src/pyopenvino/core/common.cpp
index acde6f3e6da..6bd794e20e7 100644
--- a/src/bindings/python/src/pyopenvino/core/common.cpp
+++ b/src/bindings/python/src/pyopenvino/core/common.cpp
@@ -131,65 +131,22 @@ py::array as_contiguous(py::array& array, ov::element::Type type) {
     }
 }
 
-py::array array_from_tensor(ov::Tensor&& t) {
-    switch (t.get_element_type()) {
-    case ov::element::Type_t::f32: {
-        return py::array_t<float>(t.get_shape(), t.data<float>());
-        break;
-    }
-    case ov::element::Type_t::f64: {
-        return py::array_t<double>(t.get_shape(), t.data<double>());
-        break;
-    }
-    case ov::element::Type_t::bf16: {
-        return py::array(py::dtype("float16"), t.get_shape(), t.data<ov::bfloat16>());
-        break;
-    }
-    case ov::element::Type_t::f16: {
-        return py::array(py::dtype("float16"), t.get_shape(), t.data<ov::float16>());
-        break;
-    }
-    case ov::element::Type_t::i8: {
-        return py::array_t<int8_t>(t.get_shape(), t.data<int8_t>());
-        break;
-    }
-    case ov::element::Type_t::i16: {
-        return py::array_t<int16_t>(t.get_shape(), t.data<int16_t>());
-        break;
-    }
-    case ov::element::Type_t::i32: {
-        return py::array_t<int32_t>(t.get_shape(), t.data<int32_t>());
-        break;
-    }
-    case ov::element::Type_t::i64: {
-        return py::array_t<int64_t>(t.get_shape(), t.data<int64_t>());
-        break;
-    }
-    case ov::element::Type_t::u8: {
-        return py::array_t<uint8_t>(t.get_shape(), t.data<uint8_t>());
-        break;
-    }
-    case ov::element::Type_t::u16: {
-        return py::array_t<uint16_t>(t.get_shape(), t.data<uint16_t>());
-        break;
-    }
-    case ov::element::Type_t::u32: {
-        return py::array_t<uint32_t>(t.get_shape(), t.data<uint32_t>());
-        break;
-    }
-    case ov::element::Type_t::u64: {
-        return py::array_t<uint64_t>(t.get_shape(), t.data<uint64_t>());
-        break;
-    }
-    case ov::element::Type_t::boolean: {
-        return py::array_t<bool>(t.get_shape(), t.data<bool>());
-        break;
-    }
-    default: {
-        OPENVINO_THROW("Numpy array cannot be created from given OV Tensor!");
-        break;
+py::array array_from_tensor(ov::Tensor&& t, bool is_shared) {
+    auto ov_type = t.get_element_type();
+    auto dtype = Common::ov_type_to_dtype().at(ov_type);
+
+    // Return the array as a view:
+    if (is_shared) {
+        if (ov_type.bitwidth() < Common::values::min_bitwidth) {
+            return py::array(dtype, t.get_byte_size(), t.data(), py::cast(t));
+        }
+        return py::array(dtype, t.get_shape(), t.get_strides(), t.data(), py::cast(t));
     }
+    // Return the array as a copy:
+    if (ov_type.bitwidth() < Common::values::min_bitwidth) {
+        return py::array(dtype, t.get_byte_size(), t.data());
     }
+    return py::array(dtype, t.get_shape(), t.get_strides(), t.data());
 }
 
 };  // namespace array_helpers
@@ -342,10 +299,10 @@ uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual) {
     }
 }
 
-py::dict outputs_to_dict(InferRequestWrapper& request) {
+py::dict outputs_to_dict(InferRequestWrapper& request, bool share_outputs) {
     py::dict res;
     for (const auto& out : request.m_outputs) {
-        res[py::cast(out)] = array_helpers::array_from_tensor(request.m_request.get_tensor(out));
+        res[py::cast(out)] = array_helpers::array_from_tensor(request.m_request.get_tensor(out), share_outputs);
     }
     return res;
 }
diff --git a/src/bindings/python/src/pyopenvino/core/common.hpp b/src/bindings/python/src/pyopenvino/core/common.hpp
index 3276a67a12d..ed6ad00c6fc 100644
--- a/src/bindings/python/src/pyopenvino/core/common.hpp
+++ b/src/bindings/python/src/pyopenvino/core/common.hpp
@@ -58,7 +58,7 @@
 std::vector<size_t> get_strides(const py::array& array);
 
 py::array as_contiguous(py::array& array, ov::element::Type type);
 
-py::array array_from_tensor(ov::Tensor&& t);
+py::array array_from_tensor(ov::Tensor&& t, bool is_shared);
 
 };  // namespace array_helpers
 
@@ -92,7 +92,7 @@ void set_request_tensors(ov::InferRequest& request, const py::dict& inputs);
 
 uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual);
 
-py::dict outputs_to_dict(InferRequestWrapper& request);
+py::dict outputs_to_dict(InferRequestWrapper& request, bool share_outputs);
 
 ov::pass::Serialize::Version convert_to_version(const std::string& version);
 
diff --git a/src/bindings/python/src/pyopenvino/core/infer_request.cpp b/src/bindings/python/src/pyopenvino/core/infer_request.cpp
index 4a1d21ede20..6b087ecaa20 100644
--- a/src/bindings/python/src/pyopenvino/core/infer_request.cpp
+++ b/src/bindings/python/src/pyopenvino/core/infer_request.cpp
@@ -15,14 +15,14 @@
 
 namespace py = pybind11;
 
-inline py::dict run_sync_infer(InferRequestWrapper& self) {
+inline py::object run_sync_infer(InferRequestWrapper& self, bool share_outputs) {
     {
         py::gil_scoped_release release;
         *self.m_start_time = Time::now();
         self.m_request.infer();
         *self.m_end_time = Time::now();
     }
-    return Common::outputs_to_dict(self);
+    return Common::outputs_to_dict(self, share_outputs);
 }
 
 void regclass_InferRequest(py::module m) {
@@ -168,11 +168,12 @@ void regclass_InferRequest(py::module m) {
     // Overload for single input, it will throw error if a model has more than one input.
     cls.def(
         "infer",
-        [](InferRequestWrapper& self, const ov::Tensor& inputs) {
+        [](InferRequestWrapper& self, const ov::Tensor& inputs, bool share_outputs) {
             self.m_request.set_input_tensor(inputs);
-            return run_sync_infer(self);
+            return run_sync_infer(self, share_outputs);
         },
         py::arg("inputs"),
+        py::arg("share_outputs"),
         R"(
             Infers specified input(s) in synchronous mode.
             Blocks all methods of InferRequest while request is running.
@@ -194,13 +195,14 @@ void regclass_InferRequest(py::module m) {
     // and values are always of type: ov::Tensor.
     cls.def(
         "infer",
-        [](InferRequestWrapper& self, const py::dict& inputs) {
+        [](InferRequestWrapper& self, const py::dict& inputs, bool share_outputs) {
             // Update inputs if there are any
             Common::set_request_tensors(self.m_request, inputs);
             // Call Infer function
-            return run_sync_infer(self);
+            return run_sync_infer(self, share_outputs);
         },
         py::arg("inputs"),
+        py::arg("share_outputs"),
         R"(
             Infers specified input(s) in synchronous mode.
             Blocks all methods of InferRequest while request is running.
@@ -727,7 +729,7 @@ void regclass_InferRequest(py::module m) {
     cls.def_property_readonly(
         "results",
         [](InferRequestWrapper& self) {
-            return Common::outputs_to_dict(self);
+            return Common::outputs_to_dict(self, false);
         },
         R"(
             Gets all outputs tensors of this InferRequest.
diff --git a/src/bindings/python/src/pyopenvino/core/tensor.cpp b/src/bindings/python/src/pyopenvino/core/tensor.cpp
index 32143cc1d19..e7ed30ced26 100644
--- a/src/bindings/python/src/pyopenvino/core/tensor.cpp
+++ b/src/bindings/python/src/pyopenvino/core/tensor.cpp
@@ -207,12 +207,7 @@ void regclass_Tensor(py::module m) {
     cls.def_property_readonly(
         "data",
         [](ov::Tensor& self) {
-            auto ov_type = self.get_element_type();
-            auto dtype = Common::ov_type_to_dtype().at(ov_type);
-            if (ov_type.bitwidth() < Common::values::min_bitwidth) {
-                return py::array(dtype, self.get_byte_size(), self.data(), py::cast(self));
-            }
-            return py::array(dtype, self.get_shape(), self.get_strides(), self.data(), py::cast(self));
+            return Common::array_helpers::array_from_tensor(std::forward<ov::Tensor>(self), true);
         },
         R"(
             Access to Tensor's data.
diff --git a/src/bindings/python/tests/test_runtime/test_infer_request.py b/src/bindings/python/tests/test_runtime/test_infer_request.py
index c7330baf25b..945f7a95b91 100644
--- a/src/bindings/python/tests/test_runtime/test_infer_request.py
+++ b/src/bindings/python/tests/test_runtime/test_infer_request.py
@@ -84,7 +84,7 @@ def abs_model_with_data(device, ov_type, numpy_dtype):
 
     array1 = np.array([[-1, 2, 5, -3]]).astype(numpy_dtype)
 
-    return request, tensor1, array1
+    return compiled_model, request, tensor1, array1
 
 
 def test_get_profiling_info(device):
@@ -302,8 +302,8 @@ def test_cancel(device):
     assert "[ INFER_CANCELLED ]" in str(e.value)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_start_async(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_start_async(device, share_inputs):
     core = Core()
     model = core.read_model(test_net_xml, test_net_bin)
     compiled_model = core.compile_model(model, device)
@@ -321,15 +321,15 @@ def test_start_async(device, shared_flag):
         callbacks_info["finished"] = 0
     for request in requests:
         request.set_callback(callback, callbacks_info)
-        request.start_async({0: img}, shared_memory=shared_flag)
+        request.start_async({0: img}, share_inputs=share_inputs)
     for request in requests:
         request.wait()
         assert request.latency > 0
     assert callbacks_info["finished"] == jobs
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_infer_list_as_inputs(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_infer_list_as_inputs(device, share_inputs):
     num_inputs = 4
     input_shape = [2, 1]
     dtype = np.float32
@@ -345,18 +345,18 @@ def test_infer_list_as_inputs(device, shared_flag):
     request = compiled_model.create_infer_request()
 
     inputs = [np.random.normal(size=input_shape).astype(dtype)]
-    request.infer(inputs, shared_memory=shared_flag)
+    request.infer(inputs, share_inputs=share_inputs)
     check_fill_inputs(request, inputs)
 
     inputs = [
         np.random.normal(size=input_shape).astype(dtype) for _ in range(num_inputs)
     ]
-    request.infer(inputs, shared_memory=shared_flag)
+    request.infer(inputs, share_inputs=share_inputs)
     check_fill_inputs(request, inputs)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_infer_mixed_keys(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_infer_mixed_keys(device, share_inputs):
     core = Core()
     model = get_relu_model()
     compiled_model = core.compile_model(model, device)
@@ -368,7 +368,7 @@ def test_infer_mixed_keys(device, shared_flag):
     tensor2 = Tensor(data2)
 
     request = compiled_model.create_infer_request()
-    res = request.infer({0: tensor2, "data": tensor}, shared_memory=shared_flag)
+    res = request.infer({0: tensor2, "data": tensor}, share_inputs=share_inputs)
     assert np.argmax(res[compiled_model.output()]) == 531
 
 
@@ -387,11 +387,11 @@ def test_infer_mixed_keys(device, shared_flag):
     (Type.u64, np.uint64),
     (Type.boolean, bool),
 ])
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_infer_mixed_values(device, ov_type, numpy_dtype, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_infer_mixed_values(device, ov_type, numpy_dtype, share_inputs):
     request, tensor1, array1 = concat_model_with_data(device, ov_type, numpy_dtype)
 
-    request.infer([tensor1, array1], shared_memory=shared_flag)
+    request.infer([tensor1, array1], share_inputs=share_inputs)
 
     assert np.array_equal(request.output_tensors[0].data, np.concatenate((tensor1.data, array1)))
 
@@ -411,11 +411,11 @@ def test_infer_mixed_values(device, ov_type, numpy_dtype, shared_flag):
     (Type.u64, np.uint64),
     (Type.boolean, bool),
 ])
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_async_mixed_values(device, ov_type, numpy_dtype, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_async_mixed_values(device, ov_type, numpy_dtype, share_inputs):
     request, tensor1, array1 = concat_model_with_data(device, ov_type, numpy_dtype)
 
-    request.start_async([tensor1, array1], shared_memory=shared_flag)
+    request.start_async([tensor1, array1], share_inputs=share_inputs)
     request.wait()
 
     assert np.array_equal(request.output_tensors[0].data, np.concatenate((tensor1.data, array1)))
@@ -431,14 +431,14 @@ def test_async_mixed_values(device, ov_type, numpy_dtype, shared_flag):
     (Type.u16, np.uint16),
     (Type.i64, np.int64),
 ])
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_infer_single_input(device, ov_type, numpy_dtype, shared_flag):
-    request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_infer_single_input(device, ov_type, numpy_dtype, share_inputs):
+    _, request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
 
-    request.infer(array1, shared_memory=shared_flag)
+    request.infer(array1, share_inputs=share_inputs)
     assert np.array_equal(request.get_output_tensor().data, np.abs(array1))
 
-    request.infer(tensor1, shared_memory=shared_flag)
+    request.infer(tensor1, share_inputs=share_inputs)
     assert np.array_equal(request.get_output_tensor().data, np.abs(tensor1.data))
 
@@ -453,21 +453,21 @@ def test_infer_single_input(device, ov_type, numpy_dtype, shared_flag):
     (Type.u16, np.uint16),
     (Type.i64, np.int64),
 ])
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_async_single_input(device, ov_type, numpy_dtype, shared_flag):
-    request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_async_single_input(device, ov_type, numpy_dtype, share_inputs):
+    _, request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
 
-    request.start_async(array1, shared_memory=shared_flag)
+    request.start_async(array1, share_inputs=share_inputs)
     request.wait()
     assert np.array_equal(request.get_output_tensor().data, np.abs(array1))
 
-    request.start_async(tensor1, shared_memory=shared_flag)
+    request.start_async(tensor1, share_inputs=share_inputs)
     request.wait()
     assert np.array_equal(request.get_output_tensor().data, np.abs(tensor1.data))
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_infer_queue(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False]) +def test_infer_queue(device, share_inputs): jobs = 8 num_request = 4 core = Core() @@ -482,15 +482,15 @@ def test_infer_queue(device, shared_flag): img = None - if not shared_flag: + if not share_inputs: img = generate_image() infer_queue.set_callback(callback) assert infer_queue.is_ready() for i in range(jobs): - if shared_flag: + if share_inputs: img = generate_image() - infer_queue.start_async({"data": img}, i, shared_memory=shared_flag) + infer_queue.start_async({"data": img}, i, share_inputs=share_inputs) infer_queue.wait_all() assert all(job["finished"] for job in jobs_done) assert all(job["latency"] > 0 for job in jobs_done) @@ -694,21 +694,21 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode): assert np.allclose(res[list(res)[0]], expected_res, atol=1e-6), f"Expected values: {expected_res} \n Actual values: {res} \n" -@pytest.mark.parametrize("shared_flag", [True, False]) -def test_get_results(device, shared_flag): +@pytest.mark.parametrize("share_inputs", [True, False]) +def test_get_results(device, share_inputs): core = Core() data = ops.parameter([10], np.float64) model = Model(ops.split(data, 0, 5), [data]) compiled_model = core.compile_model(model, device) request = compiled_model.create_infer_request() inputs = [np.random.normal(size=list(compiled_model.input().shape))] - results = request.infer(inputs, shared_memory=shared_flag) + results = request.infer(inputs, share_inputs=share_inputs) for output in compiled_model.outputs: assert np.array_equal(results[output], request.results[output]) -@pytest.mark.parametrize("shared_flag", [True, False]) -def test_results_async_infer(device, shared_flag): +@pytest.mark.parametrize("share_inputs", [True, False]) +def test_results_async_infer(device, share_inputs): jobs = 8 num_request = 4 core = Core() @@ -724,7 +724,7 @@ def test_results_async_infer(device, shared_flag): img = generate_image() infer_queue.set_callback(callback) for i in range(jobs): - infer_queue.start_async({"data": img}, i, shared_memory=shared_flag) + infer_queue.start_async({"data": img}, i, share_inputs=share_inputs) infer_queue.wait_all() request = compiled_model.create_infer_request() @@ -738,8 +738,8 @@ def test_results_async_infer(device, shared_flag): os.environ.get("TEST_DEVICE") not in ["GPU"], reason="Device dependent test", ) -@pytest.mark.parametrize("shared_flag", [True, False]) -def test_infer_float16(device, shared_flag): +@pytest.mark.parametrize("share_inputs", [True, False]) +def test_infer_float16(device, share_inputs): model = bytes( b""" @@ -814,13 +814,13 @@ def test_infer_float16(device, shared_flag): compiled_model = core.compile_model(model, device) input_data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.float16) request = compiled_model.create_infer_request() - outputs = request.infer({0: input_data, 1: input_data}, shared_memory=shared_flag) + outputs = request.infer({0: input_data, 1: input_data}, share_inputs=share_inputs) assert np.allclose(list(outputs.values()), list(request.results.values())) assert np.allclose(list(outputs.values()), input_data + input_data) -@pytest.mark.parametrize("shared_flag", [True, False]) -def test_ports_as_inputs(device, shared_flag): +@pytest.mark.parametrize("share_inputs", [True, False]) +def test_ports_as_inputs(device, share_inputs): input_shape = [2, 2] param_a = ops.parameter(input_shape, np.float32) param_b = ops.parameter(input_shape, np.float32) @@ -838,64 +838,64 @@ def test_ports_as_inputs(device, 
     res = request.infer(
         {compiled_model.inputs[0]: tensor1, compiled_model.inputs[1]: tensor2},
-        shared_memory=shared_flag,
+        share_inputs=share_inputs,
     )
     assert np.array_equal(res[compiled_model.outputs[0]], tensor1.data + tensor2.data)
 
     res = request.infer(
         {request.model_inputs[0]: tensor1, request.model_inputs[1]: tensor2},
-        shared_memory=shared_flag,
+        share_inputs=share_inputs,
     )
     assert np.array_equal(res[request.model_outputs[0]], tensor1.data + tensor2.data)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_inputs_dict_not_replaced(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_inputs_dict_not_replaced(device, share_inputs):
     request, arr_1, arr_2 = create_simple_request_and_inputs(device)
 
     inputs = {0: arr_1, 1: arr_2}
     inputs_copy = deepcopy(inputs)
 
-    res = request.infer(inputs, shared_memory=shared_flag)
+    res = request.infer(inputs, share_inputs=share_inputs)
 
     np.testing.assert_equal(inputs, inputs_copy)
     assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_inputs_list_not_replaced(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_inputs_list_not_replaced(device, share_inputs):
     request, arr_1, arr_2 = create_simple_request_and_inputs(device)
 
     inputs = [arr_1, arr_2]
     inputs_copy = deepcopy(inputs)
 
-    res = request.infer(inputs, shared_memory=shared_flag)
+    res = request.infer(inputs, share_inputs=share_inputs)
 
     assert np.array_equal(inputs, inputs_copy)
     assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_inputs_tuple_not_replaced(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_inputs_tuple_not_replaced(device, share_inputs):
     request, arr_1, arr_2 = create_simple_request_and_inputs(device)
 
     inputs = (arr_1, arr_2)
     inputs_copy = deepcopy(inputs)
 
-    res = request.infer(inputs, shared_memory=shared_flag)
+    res = request.infer(inputs, share_inputs=share_inputs)
 
     assert np.array_equal(inputs, inputs_copy)
     assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_invalid_inputs(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_invalid_inputs(device, share_inputs):
     request, _, _ = create_simple_request_and_inputs(device)
 
     inputs = "some_input"
 
     with pytest.raises(TypeError) as e:
-        request.infer(inputs, shared_memory=shared_flag)
+        request.infer(inputs, share_inputs=share_inputs)
     assert "Incompatible inputs of type:" in str(e.value)
 
 
@@ -921,8 +921,8 @@ def test_infer_dynamic_model(device):
     assert request.get_input_tensor().shape == Shape(shape3)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_array_like_input_request(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_array_like_input_request(device, share_inputs):
     class ArrayLikeObject:
         # Array-like object accepted by np.array to test inputs similar to torch tensor and tf.Tensor
         def __init__(self, array) -> None:
@@ -931,13 +931,13 @@ def test_array_like_input_request(device, shared_flag):
         def __array__(self):
             return np.array(self.data)
 
-    request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
+    _, request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
     model_input_object = ArrayLikeObject(input_data.tolist())
     model_input_list = [ArrayLikeObject(input_data.tolist())]
     model_input_dict = {0: ArrayLikeObject(input_data.tolist())}
 
     # Test single array-like object in InferRequest().Infer()
-    res_object = request.infer(model_input_object, shared_memory=shared_flag)
+    res_object = request.infer(model_input_object, share_inputs=share_inputs)
     assert np.array_equal(res_object[request.model_outputs[0]], np.abs(input_data))
 
     # Test list of array-like objects to use normalize_inputs()
@@ -949,8 +949,8 @@ def test_array_like_input_request(device, shared_flag):
     assert np.array_equal(res_dict[request.model_outputs[0]], np.abs(input_data))
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_array_like_input_async(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_array_like_input_async(device, share_inputs):
     class ArrayLikeObject:
         # Array-like object accepted by np.array to test inputs similar to torch tensor and tf.Tensor
         def __init__(self, array) -> None:
@@ -959,11 +959,11 @@ def test_array_like_input_async(device, shared_flag):
         def __array__(self):
             return np.array(self.data)
 
-    request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
+    _, request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
     model_input_object = ArrayLikeObject(input_data.tolist())
     model_input_list = [ArrayLikeObject(input_data.tolist())]
     # Test single array-like object in InferRequest().start_async()
-    request.start_async(model_input_object, shared_memory=shared_flag)
+    request.start_async(model_input_object, share_inputs=share_inputs)
     request.wait()
     assert np.array_equal(request.get_output_tensor().data, np.abs(input_data))
 
@@ -973,8 +973,8 @@ def test_array_like_input_async(device, shared_flag):
     assert np.array_equal(request.get_output_tensor().data, np.abs(input_data))
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
-def test_array_like_input_async_infer_queue(device, shared_flag):
+@pytest.mark.parametrize("share_inputs", [True, False])
+def test_array_like_input_async_infer_queue(device, share_inputs):
     class ArrayLikeObject:
         # Array-like object accepted by np.array to test inputs similar to torch tensor and tf.Tensor
         def __init__(self, array) -> None:
@@ -1008,7 +1008,7 @@ def test_array_like_input_async_infer_queue(device, shared_flag):
     # Test list of array-like objects in AsyncInferQueue.start_async()
     infer_queue_list = AsyncInferQueue(compiled_model, jobs)
     for i in range(jobs):
-        infer_queue_list.start_async(model_input_list[i], shared_memory=shared_flag)
+        infer_queue_list.start_async(model_input_list[i], share_inputs=share_inputs)
     infer_queue_list.wait_all()
 
     for i in range(jobs):
@@ -1025,7 +1025,7 @@ def test_convert_infer_request(device):
     assert "cannot deepcopy 'openvino.runtime.ConstOutput' object." in str(e)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
+@pytest.mark.parametrize("share_inputs", [True, False])
 @pytest.mark.parametrize("input_data", [
     np.array(1.0, dtype=np.float32),
     np.array(1, dtype=np.int32),
@@ -1034,7 +1034,7 @@ def test_convert_infer_request(device):
     1.0,
     1,
 ])
-def test_only_scalar_infer(device, shared_flag, input_data):
+def test_only_scalar_infer(device, share_inputs, input_data):
     core = Core()
     param = ops.parameter([], np.float32, name="data")
     relu = ops.relu(param, name="relu")
@@ -1043,18 +1043,18 @@ def test_only_scalar_infer(device, shared_flag, input_data):
     compiled = core.compile_model(model=model, device_name=device)
     request = compiled.create_infer_request()
 
-    res = request.infer(input_data, shared_memory=shared_flag)
+    res = request.infer(input_data, share_inputs=share_inputs)
 
     assert res[request.model_outputs[0]] == np.maximum(input_data, 0)
 
     input_tensor = request.get_input_tensor()
-    if shared_flag and isinstance(input_data, np.ndarray) and input_data.dtype == input_tensor.data.dtype:
+    if share_inputs and isinstance(input_data, np.ndarray) and input_data.dtype == input_tensor.data.dtype:
         assert np.shares_memory(input_data, input_tensor.data)
     else:
         assert not np.shares_memory(input_data, input_tensor.data)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
+@pytest.mark.parametrize("share_inputs", [True, False])
 @pytest.mark.parametrize("input_data", [
     {0: np.array(1.0, dtype=np.float32), 1: np.array([1.0, 2.0], dtype=np.float32)},
     {0: np.array(1, dtype=np.int32), 1: np.array([1, 2], dtype=np.int32)},
@@ -1063,7 +1063,7 @@ def test_only_scalar_infer(device, shared_flag, input_data):
     {0: 1.0, 1: np.array([1.0, 2.0], dtype=np.float32)},
     {0: 1, 1: np.array([1.0, 2.0], dtype=np.int32)},
 ])
-def test_mixed_scalar_infer(device, shared_flag, input_data):
+def test_mixed_scalar_infer(device, share_inputs, input_data):
     core = Core()
     param0 = ops.parameter([], np.float32, name="data0")
     param1 = ops.parameter([2], np.float32, name="data1")
@@ -1073,14 +1073,14 @@ def test_mixed_scalar_infer(device, shared_flag, input_data):
     compiled = core.compile_model(model=model, device_name=device)
     request = compiled.create_infer_request()
 
-    res = request.infer(input_data, shared_memory=shared_flag)
+    res = request.infer(input_data, share_inputs=share_inputs)
 
     assert np.allclose(res[request.model_outputs[0]], np.add(input_data[0], input_data[1]))
 
     input_tensor0 = request.get_input_tensor(0)
     input_tensor1 = request.get_input_tensor(1)
 
-    if shared_flag:
+    if share_inputs:
         if isinstance(input_data[0], np.ndarray) and input_data[0].dtype == input_tensor0.data.dtype:
             assert np.shares_memory(input_data[0], input_tensor0.data)
         else:
@@ -1094,12 +1094,12 @@ def test_mixed_scalar_infer(device, shared_flag, input_data):
         assert not np.shares_memory(input_data[1], input_tensor1.data)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
+@pytest.mark.parametrize("share_inputs", [True, False])
 @pytest.mark.parametrize("input_data", [
     {0: np.array(1.0, dtype=np.float32), 1: np.array([3.0], dtype=np.float32)},
     {0: np.array(1.0, dtype=np.float32), 1: np.array([3.0, 3.0, 3.0], dtype=np.float32)},
 ])
-def test_mixed_dynamic_infer(device, shared_flag, input_data):
+def test_mixed_dynamic_infer(device, share_inputs, input_data):
     core = Core()
     param0 = ops.parameter([], np.float32, name="data0")
     param1 = ops.parameter(["?"], np.float32, name="data1")
@@ -1109,14 +1109,14 @@ def test_mixed_dynamic_infer(device, shared_flag, input_data):
     compiled = core.compile_model(model=model, device_name=device)
     request = compiled.create_infer_request()
 
-    res = request.infer(input_data, shared_memory=shared_flag)
+    res = request.infer(input_data, share_inputs=share_inputs)
 
     assert np.allclose(res[request.model_outputs[0]], np.add(input_data[0], input_data[1]))
 
     input_tensor0 = request.get_input_tensor(0)
     input_tensor1 = request.get_input_tensor(1)
 
-    if shared_flag:
+    if share_inputs:
         if isinstance(input_data[0], np.ndarray) and input_data[0].dtype == input_tensor0.data.dtype:
             assert np.shares_memory(input_data[0], input_tensor0.data)
         else:
@@ -1130,12 +1130,12 @@ def test_mixed_dynamic_infer(device, shared_flag, input_data):
         assert not np.shares_memory(input_data[1], input_tensor1.data)
 
 
-@pytest.mark.parametrize("shared_flag", [True, False])
+@pytest.mark.parametrize("share_inputs", [True, False])
 @pytest.mark.parametrize(("input_data", "change_flags"), [
     ({0: np.frombuffer(b"\x01\x02\x03\x04", np.uint8)}, False),
     ({0: np.array([1, 2, 3, 4], dtype=np.uint8)}, True),
 ])
-def test_not_writable_inputs_infer(device, shared_flag, input_data, change_flags):
+def test_not_writable_inputs_infer(device, share_inputs, input_data, change_flags):
     if change_flags is True:
         input_data[0].setflags(write=0)
     # identity model
@@ -1145,12 +1145,12 @@ def test_not_writable_inputs_infer(device, shared_flag, input_data, change_flags):
     model = Model(param_node, [param_node])
     compiled = core.compile_model(model, "CPU")
 
-    results = compiled(input_data, shared_memory=shared_flag)
+    results = compiled(input_data, share_inputs=share_inputs)
 
     assert np.array_equal(results[0], input_data[0])
 
     request = compiled.create_infer_request()
-    results = request.infer(input_data, shared_memory=shared_flag)
+    results = request.infer(input_data, share_inputs=share_inputs)
 
     assert np.array_equal(results[0], input_data[0])
 
@@ -1158,3 +1158,81 @@ def test_not_writable_inputs_infer(device, shared_flag, input_data, change_flags):
     # Not writable inputs should always be copied.
assert not np.shares_memory(input_data[0], input_tensor.data) + + +@pytest.mark.parametrize("shared_flag", [True, False]) +def test_shared_memory_deprecation(device, shared_flag): + compiled, request, _, input_data = abs_model_with_data(device, Type.f32, np.float32) + + with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"): + _ = compiled(input_data, shared_memory=shared_flag) + + with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"): + _ = request.infer(input_data, shared_memory=shared_flag) + + with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"): + request.start_async(input_data, shared_memory=shared_flag) + request.wait() + + queue = AsyncInferQueue(compiled, jobs=1) + + with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"): + queue.start_async(input_data, shared_memory=shared_flag) + queue.wait_all() + + +@pytest.mark.parametrize("share_inputs", [True, False]) +@pytest.mark.parametrize("share_outputs", [True, False]) +@pytest.mark.parametrize("is_positional", [True, False]) +def test_compiled_model_share_memory(device, share_inputs, share_outputs, is_positional): + compiled, _, _, input_data = abs_model_with_data(device, Type.f32, np.float32) + + if is_positional: + results = compiled(input_data, share_inputs=share_inputs, share_outputs=share_outputs) + else: + results = compiled(input_data, share_inputs, share_outputs) + + assert np.array_equal(results[0], np.abs(input_data)) + + in_tensor_shares = np.shares_memory(compiled._infer_request.get_input_tensor(0).data, input_data) + if share_inputs: + assert in_tensor_shares + else: + assert not in_tensor_shares + + out_tensor_shares = np.shares_memory(compiled._infer_request.get_output_tensor(0).data, results[0]) + if share_outputs: + assert out_tensor_shares + assert results[0].flags["OWNDATA"] is False + else: + assert not out_tensor_shares + assert results[0].flags["OWNDATA"] is True + + +@pytest.mark.parametrize("share_inputs", [True, False]) +@pytest.mark.parametrize("share_outputs", [True, False]) +@pytest.mark.parametrize("is_positional", [True, False]) +def test_infer_request_share_memory(device, share_inputs, share_outputs, is_positional): + _, request, _, input_data = abs_model_with_data(device, Type.f32, np.float32) + + if is_positional: + results = request.infer(input_data, share_inputs=share_inputs, share_outputs=share_outputs) + else: + results = request.infer(input_data, share_inputs, share_outputs) + + assert np.array_equal(results[0], np.abs(input_data)) + + in_tensor_shares = np.shares_memory(request.get_input_tensor(0).data, input_data) + + if share_inputs: + assert in_tensor_shares + else: + assert not in_tensor_shares + + out_tensor_shares = np.shares_memory(request.get_output_tensor(0).data, results[0]) + if share_outputs: + assert out_tensor_shares + assert results[0].flags["OWNDATA"] is False + else: + assert not out_tensor_shares + assert results[0].flags["OWNDATA"] is True
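
Editor's note: the sketch below shows how the new keywords behave from user code after this patch is applied. It is a minimal illustration modeled on the tests above; the trivial ReLU model, the opset8 import alias, and the "CPU" device are assumptions for the example, not part of the patch.

import numpy as np
import openvino.runtime.opset8 as ops
from openvino.runtime import Core, Model

# Build a trivial one-input model, the same way the tests above do.
param = ops.parameter([4], np.float32, name="data")
model = Model(ops.relu(param), [param])
compiled = Core().compile_model(model, "CPU")

data = np.array([-1.0, 2.0, 5.0, -3.0], dtype=np.float32)

# CompiledModel.__call__ defaults: share_inputs=True, share_outputs=False,
# so outputs are copied into numpy arrays that own their memory.
results = compiled(data)
assert np.array_equal(results[0], np.maximum(data, 0))
assert results[0].flags["OWNDATA"]

# share_outputs=True returns views over the output Tensors instead of copies;
# the arrays stay valid only while the OpenVINO objects behind them are alive.
views = compiled(data, share_inputs=True, share_outputs=True)
assert not views[0].flags["OWNDATA"]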
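Continuing the sketch above, the deprecated spelling still works during the transition: `shared_memory` is routed through `_deprecated_memory_arg()`, which emits a FutureWarning and lets its value override `share_inputs` (the stdlib `warnings` capture here is just one way to observe it; the tests use `pytest.warns`).

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = compiled(data, shared_memory=True)  # old keyword, still accepted
assert any(issubclass(w.category, FutureWarning) for w in caught)

_ = compiled(data, share_inputs=True)  # new keyword, no warning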