[PyOV] Memory flow control with share_inputs and share_outputs (#18275)

* Added ReturnPolicy and updated common array helpers

* Clean up

* Remove ReturnPolicy initial

* Add share_inputs and share_outputs

* Tests and minor fixes

* Fix docstrings

* Fix whitespace

* Fix typing
Jan Iwaszkiewicz 2023-07-11 19:05:22 +02:00 committed by GitHub
parent 56f51135d4
commit 59052d846e
6 changed files with 305 additions and 178 deletions

View File

@ -4,6 +4,7 @@
from typing import Any, Iterable, Union, Optional
from pathlib import Path
import warnings
import numpy as np
@ -22,6 +23,19 @@ from openvino.runtime.utils.data_helpers import (
)
def _deprecated_memory_arg(shared_memory: bool, share_inputs: bool) -> bool:
if shared_memory is not None:
warnings.warn(
"`shared_memory` is deprecated and will be removed in 2024.0. "
"Value of `shared_memory` is going to override `share_inputs` value. "
"Please use only `share_inputs` explicitly.",
FutureWarning,
stacklevel=3,
)
return shared_memory
return share_inputs
class Model(ModelBase):
def __init__(self, *args: Any, **kwargs: Any) -> None:
if args and not kwargs:
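As an editorial sketch of the precedence rule implemented by `_deprecated_memory_arg` above (assuming a compiled model `compiled` and a numpy array `data` already exist; these names are illustrative, not part of the patch):

import warnings

# The deprecated keyword still works, but it raises a FutureWarning and its
# value overrides `share_inputs` (here inputs end up shared despite share_inputs=False).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    results = compiled(data, share_inputs=False, shared_memory=True)
assert any(issubclass(w.category, FutureWarning) for w in caught)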
@ -40,7 +54,14 @@ class Model(ModelBase):
class InferRequest(_InferRequestWrapper):
"""InferRequest class represents infer request which can be run in asynchronous or synchronous manners."""
def infer(self, inputs: Any = None, shared_memory: bool = False) -> OVDict:
def infer(
self,
inputs: Any = None,
share_inputs: bool = False,
share_outputs: bool = False,
*,
shared_memory: Any = None,
) -> OVDict:
"""Infers specified input(s) in synchronous mode.
Blocks all methods of InferRequest while request is running.
@ -63,7 +84,7 @@ class InferRequest(_InferRequestWrapper):
:param inputs: Data to be set on input tensors.
:type inputs: Any, optional
:param shared_memory: Enables `shared_memory` mode.
:param share_inputs: Enables `share_inputs` mode. Controls memory usage on inference's inputs.
If set to `False`, the data dispatcher will safely copy input data
to existing Tensors (including up- or down-casting according to data type,
@ -78,10 +99,32 @@ class InferRequest(_InferRequestWrapper):
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Note: Using `shared_memory` may result in the extra memory overhead.
Note: Using `share_inputs` may result in extra memory overhead.
Default value: False
:type share_inputs: bool, optional
:param share_outputs: Enables `share_outputs` mode. Controls memory usage on inference's outputs.
If set to `False`, outputs' data will be safely copied into numpy arrays.
If set to `True`, the data will be returned as views of the output Tensors.
This mode still returns the data as numpy arrays, but the lifetime of the data
is tied to OpenVINO objects.
Note: Use with extra care, shared data can be modified or lost during runtime!
Default value: False
:type share_outputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, the function uses the `share_inputs` value.
Note: Will be removed in the 2024.0 release!
Note: This is a keyword-only argument.
Default value: None
:type shared_memory: bool, optional
:return: Dictionary of results from output tensors with port/int/str keys.
:rtype: OVDict
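A hedged usage sketch of the two new flags on a synchronous request, mirroring the `OWNDATA` checks added in the tests below (the `request` and `data` names are assumed, not part of the patch):

# Default: output data is copied, so the returned arrays own their memory.
copied = request.infer(data, share_inputs=True, share_outputs=False)
assert copied[0].flags["OWNDATA"]

# Opt-in: outputs come back as views over the output Tensors and stay valid
# only as long as the underlying OpenVINO objects do.
shared = request.infer(data, share_inputs=True, share_outputs=True)
assert not shared[0].flags["OWNDATA"]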
@ -89,14 +132,16 @@ class InferRequest(_InferRequestWrapper):
return OVDict(super().infer(_data_dispatch(
self,
inputs,
is_shared=shared_memory,
)))
is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
), share_outputs=share_outputs))
def start_async(
self,
inputs: Any = None,
userdata: Any = None,
shared_memory: bool = False,
share_inputs: bool = False,
*,
shared_memory: Any = None,
) -> None:
"""Starts inference of specified input(s) in asynchronous mode.
@ -123,7 +168,7 @@ class InferRequest(_InferRequestWrapper):
:type inputs: Any, optional
:param userdata: Any data that will be passed inside the callback.
:type userdata: Any
:param shared_memory: Enables `shared_memory` mode.
:param share_inputs: Enables `share_inputs` mode. Controls memory usage on inference's inputs.
If set to `False`, the data dispatcher will safely copy input data
to existing Tensors (including up- or down-casting according to data type,
@ -138,17 +183,27 @@ class InferRequest(_InferRequestWrapper):
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Note: Using `shared_memory` may result in extra memory overhead.
Note: Using `share_inputs` may result in extra memory overhead.
Default value: False
:type share_inputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, the function uses the `share_inputs` value.
Note: Will be removed in the 2024.0 release!
Note: This is a keyword-only argument.
Default value: None
:type shared_memory: bool, optional
"""
super().start_async(
_data_dispatch(
self,
inputs,
is_shared=shared_memory,
is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
),
userdata,
)
@ -229,9 +284,14 @@ class CompiledModel(CompiledModelBase):
# overloaded functions of InferRequest class
return self.create_infer_request().infer(inputs)
def __call__(self,
inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None,
shared_memory: bool = True) -> OVDict:
def __call__(
self,
inputs: Union[dict, list, tuple, Tensor, np.ndarray] = None,
share_inputs: bool = True,
share_outputs: bool = False,
*,
shared_memory: Any = None,
) -> OVDict:
"""Callable infer wrapper for CompiledModel.
Infers specified input(s) in synchronous mode.
@ -262,7 +322,7 @@ class CompiledModel(CompiledModelBase):
:param inputs: Data to be set on input tensors.
:type inputs: Union[Dict[keys, values], List[values], Tuple[values], Tensor, numpy.ndarray], optional
:param shared_memory: Enables `shared_memory` mode.
:param share_inputs: Enables `share_inputs` mode. Controls memory usage on inference's inputs.
If set to `False`, the data dispatcher will safely copy input data
to existing Tensors (including up- or down-casting according to data type,
@ -277,12 +337,33 @@ class CompiledModel(CompiledModelBase):
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Note: Using `shared_memory` may result in extra memory overhead.
Note: Using `share_inputs` may result in extra memory overhead.
Default value: True
:type shared_memory: bool, optional
:type share_inputs: bool, optional
:param share_outputs: Enables `share_outputs` mode. Controls memory usage on inference's outputs.
If set to `False`, outputs' data will be safely copied into numpy arrays.
If set to `True`, the data will be returned as views of the output Tensors.
This mode still returns the data as numpy arrays, but the lifetime of the data
is tied to OpenVINO objects.
Note: Use with extra care, shared data can be modified or lost during runtime!
Default value: False
:type share_outputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, the function uses the `share_inputs` value.
Note: Will be removed in the 2024.0 release!
Note: This is a keyword-only argument.
Default value: None
:type shared_memory: bool, optional
:return: Dictionary of results from output tensors with port/int/str as keys.
:rtype: OVDict
"""
@ -291,7 +372,8 @@ class CompiledModel(CompiledModelBase):
return self._infer_request.infer(
inputs,
shared_memory=shared_memory,
share_inputs=_deprecated_memory_arg(shared_memory, share_inputs),
share_outputs=share_outputs,
)
@ -333,7 +415,9 @@ class AsyncInferQueue(AsyncInferQueueBase):
self,
inputs: Any = None,
userdata: Any = None,
shared_memory: bool = False,
share_inputs: bool = False,
*,
shared_memory: Any = None,
) -> None:
"""Run asynchronous inference using the next available InferRequest from the pool.
@ -356,7 +440,7 @@ class AsyncInferQueue(AsyncInferQueueBase):
:type inputs: Any, optional
:param userdata: Any data that will be passed to a callback.
:type userdata: Any, optional
:param shared_memory: Enables `shared_memory` mode.
:param share_inputs: Enables `share_inputs` mode. Controls memory usage on inference's inputs.
If set to `False`, the data dispatcher will safely copy input data
to existing Tensors (including up- or down-casting according to data type,
@ -371,16 +455,27 @@ class AsyncInferQueue(AsyncInferQueueBase):
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
Keeps Tensor inputs "as-is".
Note: Use with extra care, shared data can be modified during runtime!
Note: Using `shared_memory` may result in extra memory overhead.
Note: Using `share_inputs` may result in extra memory overhead.
Default value: False
:type share_inputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, the function uses the `share_inputs` value.
Note: Will be removed in the 2024.0 release!
Note: This is a keyword-only argument.
Default value: None
:type shared_memory: bool, optional
"""
super().start_async(
_data_dispatch(
self[self.get_idle_request_id()],
inputs,
is_shared=shared_memory,
is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
),
userdata,
)
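A hedged sketch of the queue variant (assuming `compiled_model`, a job count `jobs`, and a list `images` already exist; names are illustrative):

from openvino.runtime import AsyncInferQueue

queue = AsyncInferQueue(compiled_model, jobs)
for i, img in enumerate(images):
    # With share_inputs=True the input buffers are wrapped, not copied;
    # do not modify `img` while its request is still running.
    queue.start_async({0: img}, userdata=i, share_inputs=True)
queue.wait_all()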

View File

@ -131,65 +131,22 @@ py::array as_contiguous(py::array& array, ov::element::Type type) {
}
}
py::array array_from_tensor(ov::Tensor&& t) {
switch (t.get_element_type()) {
case ov::element::Type_t::f32: {
return py::array_t<float>(t.get_shape(), t.data<float>());
break;
}
case ov::element::Type_t::f64: {
return py::array_t<double>(t.get_shape(), t.data<double>());
break;
}
case ov::element::Type_t::bf16: {
return py::array(py::dtype("float16"), t.get_shape(), t.data<ov::bfloat16>());
break;
}
case ov::element::Type_t::f16: {
return py::array(py::dtype("float16"), t.get_shape(), t.data<ov::float16>());
break;
}
case ov::element::Type_t::i8: {
return py::array_t<int8_t>(t.get_shape(), t.data<int8_t>());
break;
}
case ov::element::Type_t::i16: {
return py::array_t<int16_t>(t.get_shape(), t.data<int16_t>());
break;
}
case ov::element::Type_t::i32: {
return py::array_t<int32_t>(t.get_shape(), t.data<int32_t>());
break;
}
case ov::element::Type_t::i64: {
return py::array_t<int64_t>(t.get_shape(), t.data<int64_t>());
break;
}
case ov::element::Type_t::u8: {
return py::array_t<uint8_t>(t.get_shape(), t.data<uint8_t>());
break;
}
case ov::element::Type_t::u16: {
return py::array_t<uint16_t>(t.get_shape(), t.data<uint16_t>());
break;
}
case ov::element::Type_t::u32: {
return py::array_t<uint32_t>(t.get_shape(), t.data<uint32_t>());
break;
}
case ov::element::Type_t::u64: {
return py::array_t<uint64_t>(t.get_shape(), t.data<uint64_t>());
break;
}
case ov::element::Type_t::boolean: {
return py::array_t<bool>(t.get_shape(), t.data<bool>());
break;
}
default: {
OPENVINO_THROW("Numpy array cannot be created from given OV Tensor!");
break;
py::array array_from_tensor(ov::Tensor&& t, bool is_shared) {
auto ov_type = t.get_element_type();
auto dtype = Common::ov_type_to_dtype().at(ov_type);
// Return the array as a view:
if (is_shared) {
if (ov_type.bitwidth() < Common::values::min_bitwidth) {
return py::array(dtype, t.get_byte_size(), t.data(), py::cast(t));
}
return py::array(dtype, t.get_shape(), t.get_strides(), t.data(), py::cast(t));
}
// Return the array as a copy:
if (ov_type.bitwidth() < Common::values::min_bitwidth) {
return py::array(dtype, t.get_byte_size(), t.data());
}
return py::array(dtype, t.get_shape(), t.get_strides(), t.data());
}
}; // namespace array_helpers
@ -342,10 +299,10 @@ uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual) {
}
}
py::dict outputs_to_dict(InferRequestWrapper& request) {
py::dict outputs_to_dict(InferRequestWrapper& request, bool share_outputs) {
py::dict res;
for (const auto& out : request.m_outputs) {
res[py::cast(out)] = array_helpers::array_from_tensor(request.m_request.get_tensor(out));
res[py::cast(out)] = array_helpers::array_from_tensor(request.m_request.get_tensor(out), share_outputs);
}
return res;
}

View File

@ -58,7 +58,7 @@ std::vector<size_t> get_strides(const py::array& array);
py::array as_contiguous(py::array& array, ov::element::Type type);
py::array array_from_tensor(ov::Tensor&& t);
py::array array_from_tensor(ov::Tensor&& t, bool is_shared);
}; // namespace array_helpers
@ -92,7 +92,7 @@ void set_request_tensors(ov::InferRequest& request, const py::dict& inputs);
uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual);
py::dict outputs_to_dict(InferRequestWrapper& request);
py::dict outputs_to_dict(InferRequestWrapper& request, bool share_outputs);
ov::pass::Serialize::Version convert_to_version(const std::string& version);

View File

@ -15,14 +15,14 @@
namespace py = pybind11;
inline py::dict run_sync_infer(InferRequestWrapper& self) {
inline py::object run_sync_infer(InferRequestWrapper& self, bool share_outputs) {
{
py::gil_scoped_release release;
*self.m_start_time = Time::now();
self.m_request.infer();
*self.m_end_time = Time::now();
}
return Common::outputs_to_dict(self);
return Common::outputs_to_dict(self, share_outputs);
}
void regclass_InferRequest(py::module m) {
@ -168,11 +168,12 @@ void regclass_InferRequest(py::module m) {
// Overload for single input, it will throw error if a model has more than one input.
cls.def(
"infer",
[](InferRequestWrapper& self, const ov::Tensor& inputs) {
[](InferRequestWrapper& self, const ov::Tensor& inputs, bool share_outputs) {
self.m_request.set_input_tensor(inputs);
return run_sync_infer(self);
return run_sync_infer(self, share_outputs);
},
py::arg("inputs"),
py::arg("share_outputs"),
R"(
Infers specified input(s) in synchronous mode.
Blocks all methods of InferRequest while request is running.
@ -194,13 +195,14 @@ void regclass_InferRequest(py::module m) {
// and values are always of type: ov::Tensor.
cls.def(
"infer",
[](InferRequestWrapper& self, const py::dict& inputs) {
[](InferRequestWrapper& self, const py::dict& inputs, bool share_outputs) {
// Update inputs if there are any
Common::set_request_tensors(self.m_request, inputs);
// Call Infer function
return run_sync_infer(self);
return run_sync_infer(self, share_outputs);
},
py::arg("inputs"),
py::arg("share_outputs"),
R"(
Infers specified input(s) in synchronous mode.
Blocks all methods of InferRequest while request is running.
@ -727,7 +729,7 @@ void regclass_InferRequest(py::module m) {
cls.def_property_readonly(
"results",
[](InferRequestWrapper& self) {
return Common::outputs_to_dict(self);
return Common::outputs_to_dict(self, false);
},
R"(
Gets all output tensors of this InferRequest.
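Since the `results` property now calls `outputs_to_dict` with `share_outputs=false`, property access always returns copies; only an explicit `infer(..., share_outputs=True)` hands out views. A hedged sketch, assuming `request` and `data` exist:

request.infer(data)

# Property access: a safe copy that owns its memory.
copied = list(request.results.values())[0]
assert copied.flags["OWNDATA"]

# Explicit opt-in on infer(): a view tied to the output Tensor's lifetime.
view = request.infer(data, share_outputs=True)[0]
assert not view.flags["OWNDATA"]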

View File

@ -207,12 +207,7 @@ void regclass_Tensor(py::module m) {
cls.def_property_readonly(
"data",
[](ov::Tensor& self) {
auto ov_type = self.get_element_type();
auto dtype = Common::ov_type_to_dtype().at(ov_type);
if (ov_type.bitwidth() < Common::values::min_bitwidth) {
return py::array(dtype, self.get_byte_size(), self.data(), py::cast(self));
}
return py::array(dtype, self.get_shape(), self.get_strides(), self.data(), py::cast(self));
return Common::array_helpers::array_from_tensor(std::forward<ov::Tensor>(self), true);
},
R"(
Access to Tensor's data.
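The `data` property is now routed through the same `array_from_tensor` helper with `is_shared=true`, so it keeps returning a zero-copy, writable view whose lifetime is tied to the Tensor. A minimal sketch, assuming a float32 `tensor` already exists:

import numpy as np

view = tensor.data                      # numpy view over the Tensor's buffer
assert not view.flags["OWNDATA"]        # the Tensor, not numpy, owns the memory
assert np.shares_memory(view, tensor.data)
view[...] = 0                           # in-place edits are visible to OpenVINO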

View File

@ -84,7 +84,7 @@ def abs_model_with_data(device, ov_type, numpy_dtype):
array1 = np.array([[-1, 2, 5, -3]]).astype(numpy_dtype)
return request, tensor1, array1
return compiled_model, request, tensor1, array1
def test_get_profiling_info(device):
@ -302,8 +302,8 @@ def test_cancel(device):
assert "[ INFER_CANCELLED ]" in str(e.value)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_start_async(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_start_async(device, share_inputs):
core = Core()
model = core.read_model(test_net_xml, test_net_bin)
compiled_model = core.compile_model(model, device)
@ -321,15 +321,15 @@ def test_start_async(device, shared_flag):
callbacks_info["finished"] = 0
for request in requests:
request.set_callback(callback, callbacks_info)
request.start_async({0: img}, shared_memory=shared_flag)
request.start_async({0: img}, share_inputs=share_inputs)
for request in requests:
request.wait()
assert request.latency > 0
assert callbacks_info["finished"] == jobs
@pytest.mark.parametrize("shared_flag", [True, False])
def test_infer_list_as_inputs(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_infer_list_as_inputs(device, share_inputs):
num_inputs = 4
input_shape = [2, 1]
dtype = np.float32
@ -345,18 +345,18 @@ def test_infer_list_as_inputs(device, shared_flag):
request = compiled_model.create_infer_request()
inputs = [np.random.normal(size=input_shape).astype(dtype)]
request.infer(inputs, shared_memory=shared_flag)
request.infer(inputs, share_inputs=share_inputs)
check_fill_inputs(request, inputs)
inputs = [
np.random.normal(size=input_shape).astype(dtype) for _ in range(num_inputs)
]
request.infer(inputs, shared_memory=shared_flag)
request.infer(inputs, share_inputs=share_inputs)
check_fill_inputs(request, inputs)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_infer_mixed_keys(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_infer_mixed_keys(device, share_inputs):
core = Core()
model = get_relu_model()
compiled_model = core.compile_model(model, device)
@ -368,7 +368,7 @@ def test_infer_mixed_keys(device, shared_flag):
tensor2 = Tensor(data2)
request = compiled_model.create_infer_request()
res = request.infer({0: tensor2, "data": tensor}, shared_memory=shared_flag)
res = request.infer({0: tensor2, "data": tensor}, share_inputs=share_inputs)
assert np.argmax(res[compiled_model.output()]) == 531
@ -387,11 +387,11 @@ def test_infer_mixed_keys(device, shared_flag):
(Type.u64, np.uint64),
(Type.boolean, bool),
])
@pytest.mark.parametrize("shared_flag", [True, False])
def test_infer_mixed_values(device, ov_type, numpy_dtype, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_infer_mixed_values(device, ov_type, numpy_dtype, share_inputs):
request, tensor1, array1 = concat_model_with_data(device, ov_type, numpy_dtype)
request.infer([tensor1, array1], shared_memory=shared_flag)
request.infer([tensor1, array1], share_inputs=share_inputs)
assert np.array_equal(request.output_tensors[0].data, np.concatenate((tensor1.data, array1)))
@ -411,11 +411,11 @@ def test_infer_mixed_values(device, ov_type, numpy_dtype, shared_flag):
(Type.u64, np.uint64),
(Type.boolean, bool),
])
@pytest.mark.parametrize("shared_flag", [True, False])
def test_async_mixed_values(device, ov_type, numpy_dtype, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_async_mixed_values(device, ov_type, numpy_dtype, share_inputs):
request, tensor1, array1 = concat_model_with_data(device, ov_type, numpy_dtype)
request.start_async([tensor1, array1], shared_memory=shared_flag)
request.start_async([tensor1, array1], share_inputs=share_inputs)
request.wait()
assert np.array_equal(request.output_tensors[0].data, np.concatenate((tensor1.data, array1)))
@ -431,14 +431,14 @@ def test_async_mixed_values(device, ov_type, numpy_dtype, shared_flag):
(Type.u16, np.uint16),
(Type.i64, np.int64),
])
@pytest.mark.parametrize("shared_flag", [True, False])
def test_infer_single_input(device, ov_type, numpy_dtype, shared_flag):
request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
@pytest.mark.parametrize("share_inputs", [True, False])
def test_infer_single_input(device, ov_type, numpy_dtype, share_inputs):
_, request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
request.infer(array1, shared_memory=shared_flag)
request.infer(array1, share_inputs=share_inputs)
assert np.array_equal(request.get_output_tensor().data, np.abs(array1))
request.infer(tensor1, shared_memory=shared_flag)
request.infer(tensor1, share_inputs=share_inputs)
assert np.array_equal(request.get_output_tensor().data, np.abs(tensor1.data))
@ -453,21 +453,21 @@ def test_infer_single_input(device, ov_type, numpy_dtype, shared_flag):
(Type.u16, np.uint16),
(Type.i64, np.int64),
])
@pytest.mark.parametrize("shared_flag", [True, False])
def test_async_single_input(device, ov_type, numpy_dtype, shared_flag):
request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
@pytest.mark.parametrize("share_inputs", [True, False])
def test_async_single_input(device, ov_type, numpy_dtype, share_inputs):
_, request, tensor1, array1 = abs_model_with_data(device, ov_type, numpy_dtype)
request.start_async(array1, shared_memory=shared_flag)
request.start_async(array1, share_inputs=share_inputs)
request.wait()
assert np.array_equal(request.get_output_tensor().data, np.abs(array1))
request.start_async(tensor1, shared_memory=shared_flag)
request.start_async(tensor1, share_inputs=share_inputs)
request.wait()
assert np.array_equal(request.get_output_tensor().data, np.abs(tensor1.data))
@pytest.mark.parametrize("shared_flag", [True, False])
def test_infer_queue(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_infer_queue(device, share_inputs):
jobs = 8
num_request = 4
core = Core()
@ -482,15 +482,15 @@ def test_infer_queue(device, shared_flag):
img = None
if not shared_flag:
if not share_inputs:
img = generate_image()
infer_queue.set_callback(callback)
assert infer_queue.is_ready()
for i in range(jobs):
if shared_flag:
if share_inputs:
img = generate_image()
infer_queue.start_async({"data": img}, i, shared_memory=shared_flag)
infer_queue.start_async({"data": img}, i, share_inputs=share_inputs)
infer_queue.wait_all()
assert all(job["finished"] for job in jobs_done)
assert all(job["latency"] > 0 for job in jobs_done)
@ -694,21 +694,21 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode):
assert np.allclose(res[list(res)[0]], expected_res, atol=1e-6), f"Expected values: {expected_res} \n Actual values: {res} \n"
@pytest.mark.parametrize("shared_flag", [True, False])
def test_get_results(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_get_results(device, share_inputs):
core = Core()
data = ops.parameter([10], np.float64)
model = Model(ops.split(data, 0, 5), [data])
compiled_model = core.compile_model(model, device)
request = compiled_model.create_infer_request()
inputs = [np.random.normal(size=list(compiled_model.input().shape))]
results = request.infer(inputs, shared_memory=shared_flag)
results = request.infer(inputs, share_inputs=share_inputs)
for output in compiled_model.outputs:
assert np.array_equal(results[output], request.results[output])
@pytest.mark.parametrize("shared_flag", [True, False])
def test_results_async_infer(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_results_async_infer(device, share_inputs):
jobs = 8
num_request = 4
core = Core()
@ -724,7 +724,7 @@ def test_results_async_infer(device, shared_flag):
img = generate_image()
infer_queue.set_callback(callback)
for i in range(jobs):
infer_queue.start_async({"data": img}, i, shared_memory=shared_flag)
infer_queue.start_async({"data": img}, i, share_inputs=share_inputs)
infer_queue.wait_all()
request = compiled_model.create_infer_request()
@ -738,8 +738,8 @@ def test_results_async_infer(device, shared_flag):
os.environ.get("TEST_DEVICE") not in ["GPU"],
reason="Device dependent test",
)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_infer_float16(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_infer_float16(device, share_inputs):
model = bytes(
b"""<net name="add_model" version="10">
<layers>
@ -814,13 +814,13 @@ def test_infer_float16(device, shared_flag):
compiled_model = core.compile_model(model, device)
input_data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.float16)
request = compiled_model.create_infer_request()
outputs = request.infer({0: input_data, 1: input_data}, shared_memory=shared_flag)
outputs = request.infer({0: input_data, 1: input_data}, share_inputs=share_inputs)
assert np.allclose(list(outputs.values()), list(request.results.values()))
assert np.allclose(list(outputs.values()), input_data + input_data)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_ports_as_inputs(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_ports_as_inputs(device, share_inputs):
input_shape = [2, 2]
param_a = ops.parameter(input_shape, np.float32)
param_b = ops.parameter(input_shape, np.float32)
@ -838,64 +838,64 @@ def test_ports_as_inputs(device, shared_flag):
res = request.infer(
{compiled_model.inputs[0]: tensor1, compiled_model.inputs[1]: tensor2},
shared_memory=shared_flag,
share_inputs=share_inputs,
)
assert np.array_equal(res[compiled_model.outputs[0]], tensor1.data + tensor2.data)
res = request.infer(
{request.model_inputs[0]: tensor1, request.model_inputs[1]: tensor2},
shared_memory=shared_flag,
share_inputs=share_inputs,
)
assert np.array_equal(res[request.model_outputs[0]], tensor1.data + tensor2.data)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_inputs_dict_not_replaced(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_inputs_dict_not_replaced(device, share_inputs):
request, arr_1, arr_2 = create_simple_request_and_inputs(device)
inputs = {0: arr_1, 1: arr_2}
inputs_copy = deepcopy(inputs)
res = request.infer(inputs, shared_memory=shared_flag)
res = request.infer(inputs, share_inputs=share_inputs)
np.testing.assert_equal(inputs, inputs_copy)
assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_inputs_list_not_replaced(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_inputs_list_not_replaced(device, share_inputs):
request, arr_1, arr_2 = create_simple_request_and_inputs(device)
inputs = [arr_1, arr_2]
inputs_copy = deepcopy(inputs)
res = request.infer(inputs, shared_memory=shared_flag)
res = request.infer(inputs, share_inputs=share_inputs)
assert np.array_equal(inputs, inputs_copy)
assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_inputs_tuple_not_replaced(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_inputs_tuple_not_replaced(device, share_inputs):
request, arr_1, arr_2 = create_simple_request_and_inputs(device)
inputs = (arr_1, arr_2)
inputs_copy = deepcopy(inputs)
res = request.infer(inputs, shared_memory=shared_flag)
res = request.infer(inputs, share_inputs=share_inputs)
assert np.array_equal(inputs, inputs_copy)
assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_invalid_inputs(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_invalid_inputs(device, share_inputs):
request, _, _ = create_simple_request_and_inputs(device)
inputs = "some_input"
with pytest.raises(TypeError) as e:
request.infer(inputs, shared_memory=shared_flag)
request.infer(inputs, share_inputs=share_inputs)
assert "Incompatible inputs of type:" in str(e.value)
@ -921,8 +921,8 @@ def test_infer_dynamic_model(device):
assert request.get_input_tensor().shape == Shape(shape3)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_array_like_input_request(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_array_like_input_request(device, share_inputs):
class ArrayLikeObject:
# Array-like object accepted by np.array to test inputs similar to torch tensor and tf.Tensor
def __init__(self, array) -> None:
@ -931,13 +931,13 @@ def test_array_like_input_request(device, shared_flag):
def __array__(self):
return np.array(self.data)
request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
_, request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
model_input_object = ArrayLikeObject(input_data.tolist())
model_input_list = [ArrayLikeObject(input_data.tolist())]
model_input_dict = {0: ArrayLikeObject(input_data.tolist())}
# Test single array-like object in InferRequest().Infer()
res_object = request.infer(model_input_object, shared_memory=shared_flag)
res_object = request.infer(model_input_object, share_inputs=share_inputs)
assert np.array_equal(res_object[request.model_outputs[0]], np.abs(input_data))
# Test list of array-like objects to use normalize_inputs()
@ -949,8 +949,8 @@ def test_array_like_input_request(device, shared_flag):
assert np.array_equal(res_dict[request.model_outputs[0]], np.abs(input_data))
@pytest.mark.parametrize("shared_flag", [True, False])
def test_array_like_input_async(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_array_like_input_async(device, share_inputs):
class ArrayLikeObject:
# Array-like object accepted by np.array to test inputs similar to torch tensor and tf.Tensor
def __init__(self, array) -> None:
@ -959,11 +959,11 @@ def test_array_like_input_async(device, shared_flag):
def __array__(self):
return np.array(self.data)
request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
_, request, _, input_data = abs_model_with_data(device, Type.f32, np.single)
model_input_object = ArrayLikeObject(input_data.tolist())
model_input_list = [ArrayLikeObject(input_data.tolist())]
# Test single array-like object in InferRequest().start_async()
request.start_async(model_input_object, shared_memory=shared_flag)
request.start_async(model_input_object, share_inputs=share_inputs)
request.wait()
assert np.array_equal(request.get_output_tensor().data, np.abs(input_data))
@ -973,8 +973,8 @@ def test_array_like_input_async(device, shared_flag):
assert np.array_equal(request.get_output_tensor().data, np.abs(input_data))
@pytest.mark.parametrize("shared_flag", [True, False])
def test_array_like_input_async_infer_queue(device, shared_flag):
@pytest.mark.parametrize("share_inputs", [True, False])
def test_array_like_input_async_infer_queue(device, share_inputs):
class ArrayLikeObject:
# Array-like object accepted by np.array to test inputs similar to torch tensor and tf.Tensor
def __init__(self, array) -> None:
@ -1008,7 +1008,7 @@ def test_array_like_input_async_infer_queue(device, shared_flag):
# Test list of array-like objects in AsyncInferQueue.start_async()
infer_queue_list = AsyncInferQueue(compiled_model, jobs)
for i in range(jobs):
infer_queue_list.start_async(model_input_list[i], shared_memory=shared_flag)
infer_queue_list.start_async(model_input_list[i], share_inputs=share_inputs)
infer_queue_list.wait_all()
for i in range(jobs):
@ -1025,7 +1025,7 @@ def test_convert_infer_request(device):
assert "cannot deepcopy 'openvino.runtime.ConstOutput' object." in str(e)
@pytest.mark.parametrize("shared_flag", [True, False])
@pytest.mark.parametrize("share_inputs", [True, False])
@pytest.mark.parametrize("input_data", [
np.array(1.0, dtype=np.float32),
np.array(1, dtype=np.int32),
@ -1034,7 +1034,7 @@ def test_convert_infer_request(device):
1.0,
1,
])
def test_only_scalar_infer(device, shared_flag, input_data):
def test_only_scalar_infer(device, share_inputs, input_data):
core = Core()
param = ops.parameter([], np.float32, name="data")
relu = ops.relu(param, name="relu")
@ -1043,18 +1043,18 @@ def test_only_scalar_infer(device, shared_flag, input_data):
compiled = core.compile_model(model=model, device_name=device)
request = compiled.create_infer_request()
res = request.infer(input_data, shared_memory=shared_flag)
res = request.infer(input_data, share_inputs=share_inputs)
assert res[request.model_outputs[0]] == np.maximum(input_data, 0)
input_tensor = request.get_input_tensor()
if shared_flag and isinstance(input_data, np.ndarray) and input_data.dtype == input_tensor.data.dtype:
if share_inputs and isinstance(input_data, np.ndarray) and input_data.dtype == input_tensor.data.dtype:
assert np.shares_memory(input_data, input_tensor.data)
else:
assert not np.shares_memory(input_data, input_tensor.data)
@pytest.mark.parametrize("shared_flag", [True, False])
@pytest.mark.parametrize("share_inputs", [True, False])
@pytest.mark.parametrize("input_data", [
{0: np.array(1.0, dtype=np.float32), 1: np.array([1.0, 2.0], dtype=np.float32)},
{0: np.array(1, dtype=np.int32), 1: np.array([1, 2], dtype=np.int32)},
@ -1063,7 +1063,7 @@ def test_only_scalar_infer(device, shared_flag, input_data):
{0: 1.0, 1: np.array([1.0, 2.0], dtype=np.float32)},
{0: 1, 1: np.array([1.0, 2.0], dtype=np.int32)},
])
def test_mixed_scalar_infer(device, shared_flag, input_data):
def test_mixed_scalar_infer(device, share_inputs, input_data):
core = Core()
param0 = ops.parameter([], np.float32, name="data0")
param1 = ops.parameter([2], np.float32, name="data1")
@ -1073,14 +1073,14 @@ def test_mixed_scalar_infer(device, shared_flag, input_data):
compiled = core.compile_model(model=model, device_name=device)
request = compiled.create_infer_request()
res = request.infer(input_data, shared_memory=shared_flag)
res = request.infer(input_data, share_inputs=share_inputs)
assert np.allclose(res[request.model_outputs[0]], np.add(input_data[0], input_data[1]))
input_tensor0 = request.get_input_tensor(0)
input_tensor1 = request.get_input_tensor(1)
if shared_flag:
if share_inputs:
if isinstance(input_data[0], np.ndarray) and input_data[0].dtype == input_tensor0.data.dtype:
assert np.shares_memory(input_data[0], input_tensor0.data)
else:
@ -1094,12 +1094,12 @@ def test_mixed_scalar_infer(device, shared_flag, input_data):
assert not np.shares_memory(input_data[1], input_tensor1.data)
@pytest.mark.parametrize("shared_flag", [True, False])
@pytest.mark.parametrize("share_inputs", [True, False])
@pytest.mark.parametrize("input_data", [
{0: np.array(1.0, dtype=np.float32), 1: np.array([3.0], dtype=np.float32)},
{0: np.array(1.0, dtype=np.float32), 1: np.array([3.0, 3.0, 3.0], dtype=np.float32)},
])
def test_mixed_dynamic_infer(device, shared_flag, input_data):
def test_mixed_dynamic_infer(device, share_inputs, input_data):
core = Core()
param0 = ops.parameter([], np.float32, name="data0")
param1 = ops.parameter(["?"], np.float32, name="data1")
@ -1109,14 +1109,14 @@ def test_mixed_dynamic_infer(device, shared_flag, input_data):
compiled = core.compile_model(model=model, device_name=device)
request = compiled.create_infer_request()
res = request.infer(input_data, shared_memory=shared_flag)
res = request.infer(input_data, share_inputs=share_inputs)
assert np.allclose(res[request.model_outputs[0]], np.add(input_data[0], input_data[1]))
input_tensor0 = request.get_input_tensor(0)
input_tensor1 = request.get_input_tensor(1)
if shared_flag:
if share_inputs:
if isinstance(input_data[0], np.ndarray) and input_data[0].dtype == input_tensor0.data.dtype:
assert np.shares_memory(input_data[0], input_tensor0.data)
else:
@ -1130,12 +1130,12 @@ def test_mixed_dynamic_infer(device, shared_flag, input_data):
assert not np.shares_memory(input_data[1], input_tensor1.data)
@pytest.mark.parametrize("shared_flag", [True, False])
@pytest.mark.parametrize("share_inputs", [True, False])
@pytest.mark.parametrize(("input_data", "change_flags"), [
({0: np.frombuffer(b"\x01\x02\x03\x04", np.uint8)}, False),
({0: np.array([1, 2, 3, 4], dtype=np.uint8)}, True),
])
def test_not_writable_inputs_infer(device, shared_flag, input_data, change_flags):
def test_not_writable_inputs_infer(device, share_inputs, input_data, change_flags):
if change_flags is True:
input_data[0].setflags(write=0)
# identity model
@ -1145,12 +1145,12 @@ def test_not_writable_inputs_infer(device, shared_flag, input_data, change_flags
model = Model(param_node, [param_node])
compiled = core.compile_model(model, "CPU")
results = compiled(input_data, shared_memory=shared_flag)
results = compiled(input_data, share_inputs=share_inputs)
assert np.array_equal(results[0], input_data[0])
request = compiled.create_infer_request()
results = request.infer(input_data, shared_memory=shared_flag)
results = request.infer(input_data, share_inputs=share_inputs)
assert np.array_equal(results[0], input_data[0])
@ -1158,3 +1158,81 @@ def test_not_writable_inputs_infer(device, shared_flag, input_data, change_flags
# Not writable inputs should always be copied.
assert not np.shares_memory(input_data[0], input_tensor.data)
@pytest.mark.parametrize("shared_flag", [True, False])
def test_shared_memory_deprecation(device, shared_flag):
compiled, request, _, input_data = abs_model_with_data(device, Type.f32, np.float32)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
_ = compiled(input_data, shared_memory=shared_flag)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
_ = request.infer(input_data, shared_memory=shared_flag)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
request.start_async(input_data, shared_memory=shared_flag)
request.wait()
queue = AsyncInferQueue(compiled, jobs=1)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
queue.start_async(input_data, shared_memory=shared_flag)
queue.wait_all()
@pytest.mark.parametrize("share_inputs", [True, False])
@pytest.mark.parametrize("share_outputs", [True, False])
@pytest.mark.parametrize("is_positional", [True, False])
def test_compiled_model_share_memory(device, share_inputs, share_outputs, is_positional):
compiled, _, _, input_data = abs_model_with_data(device, Type.f32, np.float32)
if is_positional:
results = compiled(input_data, share_inputs=share_inputs, share_outputs=share_outputs)
else:
results = compiled(input_data, share_inputs, share_outputs)
assert np.array_equal(results[0], np.abs(input_data))
in_tensor_shares = np.shares_memory(compiled._infer_request.get_input_tensor(0).data, input_data)
if share_inputs:
assert in_tensor_shares
else:
assert not in_tensor_shares
out_tensor_shares = np.shares_memory(compiled._infer_request.get_output_tensor(0).data, results[0])
if share_outputs:
assert out_tensor_shares
assert results[0].flags["OWNDATA"] is False
else:
assert not out_tensor_shares
assert results[0].flags["OWNDATA"] is True
@pytest.mark.parametrize("share_inputs", [True, False])
@pytest.mark.parametrize("share_outputs", [True, False])
@pytest.mark.parametrize("is_positional", [True, False])
def test_infer_request_share_memory(device, share_inputs, share_outputs, is_positional):
_, request, _, input_data = abs_model_with_data(device, Type.f32, np.float32)
if is_positional:
results = request.infer(input_data, share_inputs=share_inputs, share_outputs=share_outputs)
else:
results = request.infer(input_data, share_inputs, share_outputs)
assert np.array_equal(results[0], np.abs(input_data))
in_tensor_shares = np.shares_memory(request.get_input_tensor(0).data, input_data)
if share_inputs:
assert in_tensor_shares
else:
assert not in_tensor_shares
out_tensor_shares = np.shares_memory(request.get_output_tensor(0).data, results[0])
if share_outputs:
assert out_tensor_shares
assert results[0].flags["OWNDATA"] is False
else:
assert not out_tensor_shares
assert results[0].flags["OWNDATA"] is True