[PyOV] Force copy of not writable numpy arrays (#18194)

This commit is contained in:
Jan Iwaszkiewicz 2023-06-28 07:56:29 +02:00 committed by GitHub
parent 50897e86e6
commit 4fc0b22012
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 4 deletions

View File

@ -58,7 +58,7 @@ class InferRequest(_InferRequestWrapper):
Tensors for every input in form of:
* `numpy.ndarray` and all the types that are castable to it, e.g. `torch.Tensor`
Data that is going to be copied:
* `numpy.ndarray` which are not C contiguous
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
@ -118,7 +118,7 @@ class InferRequest(_InferRequestWrapper):
Tensors for every input in form of:
* `numpy.ndarray` and all the types that are castable to it, e.g. `torch.Tensor`
Data that is going to be copied:
* `numpy.ndarray` which are not C contiguous
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
@ -246,7 +246,7 @@ class CompiledModel(CompiledModelBase):
Tensors for every input in form of:
* `numpy.ndarray` and all the types that are castable to it, e.g. `torch.Tensor`
Data that is going to be copied:
* `numpy.ndarray` which are not C contiguous
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)
@ -340,7 +340,7 @@ class AsyncInferQueue(AsyncInferQueueBase):
Tensors for every input in form of:
* `numpy.ndarray` and all the types that are castable to it, e.g. `torch.Tensor`
Data that is going to be copied:
* `numpy.ndarray` which are not C contiguous
* `numpy.ndarray` which are not C contiguous and/or not writable (WRITEABLE flag is set to False)
* inputs which data types are mismatched from Infer Request's inputs
* inputs that should be in `BF16` data type
* scalar inputs (i.e. `np.float_`/`int`/`float`)

View File

@ -70,6 +70,11 @@ def _(
tensor = Tensor(tensor_type, value.shape)
tensor.data[:] = value.view(tensor_dtype)
return tensor
# WA for "not writeable" edge-case, always copy.
if value.flags["WRITEABLE"] is False:
tensor = Tensor(tensor_type, value.shape)
tensor.data[:] = value.astype(tensor_dtype) if tensor_dtype != value.dtype else value
return tensor
# If types are mismatched, convert and always copy.
if tensor_dtype != value.dtype:
return Tensor(value.astype(tensor_dtype), shared_memory=False)

View File

@ -1112,3 +1112,33 @@ def test_mixed_dynamic_infer(device, shared_flag, input_data):
else:
assert not np.shares_memory(input_data[0], input_tensor0.data)
assert not np.shares_memory(input_data[1], input_tensor1.data)
@pytest.mark.parametrize("shared_flag", [True, False])
@pytest.mark.parametrize(("input_data", "change_flags"), [
({0: np.frombuffer(b"\x01\x02\x03\x04", np.uint8)}, False),
({0: np.array([1, 2, 3, 4], dtype=np.uint8)}, True),
])
def test_not_writable_inputs_infer(device, shared_flag, input_data, change_flags):
if change_flags is True:
input_data[0].setflags(write=0)
# identity model
input_shape = [4]
param_node = ops.parameter(input_shape, np.uint8, name="data0")
core = Core()
model = Model(param_node, [param_node])
compiled = core.compile_model(model, "CPU")
results = compiled(input_data, shared_memory=shared_flag)
assert np.array_equal(results[0], input_data[0])
request = compiled.create_infer_request()
results = request.infer(input_data, shared_memory=shared_flag)
assert np.array_equal(results[0], input_data[0])
input_tensor = request.get_input_tensor(0)
# Not writable inputs should always be copied.
assert not np.shares_memory(input_data[0], input_tensor.data)