[PYTHON API] infer helper (#9478)

* Accept inputs as a list in infer (see the usage sketch after this list)

* Fix import

* Fix import (follow-up)

* Refactor test
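With this change, InferRequest.infer, start_async, and CompiledModel.infer_new_request accept a plain list of inputs in addition to a dict: element i of the list is mapped to input port i. A minimal usage sketch, assuming the openvino.runtime import path and a hypothetical model file (neither is part of this PR):

import numpy as np
from openvino.runtime import Core

core = Core()
compiled_model = core.compile_model("model.xml", "CPU")  # hypothetical model path
request = compiled_model.create_infer_request()

data = np.zeros([2, 1], dtype=np.float32)

# Dict form, keyed by port index or tensor name (already supported):
request.infer({0: data})

# New list form: element i fills input port i.
request.infer([data])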
Author: Alexey Lebedev, 2022-01-11 16:12:11 +03:00 (committed by GitHub)
Parent: 6ddc1e981b
Commit: 42c5be23b1
2 changed files with 34 additions and 8 deletions

@@ -22,21 +22,23 @@ def tensor_from_file(path: str) -> Tensor:
     return Tensor(np.fromfile(path, dtype=np.uint8))
 
 
-def normalize_inputs(py_dict: dict, py_types: dict) -> dict:
+def normalize_inputs(inputs: Union[dict, list], py_types: dict) -> dict:
     """Normalize a dictionary of inputs to Tensors."""
-    for k, val in py_dict.items():
+    if isinstance(inputs, list):
+        inputs = {index: input for index, input in enumerate(inputs)}
+    for k, val in inputs.items():
         if not isinstance(k, (str, int)):
             raise TypeError("Incompatible key type for tensor named: {}".format(k))
         try:
             ov_type = py_types[k]
         except KeyError:
             raise KeyError("Port for tensor named {} was not found!".format(k))
-        py_dict[k] = (
+        inputs[k] = (
             val
             if isinstance(val, Tensor)
             else Tensor(np.array(val, get_dtype(ov_type)))
         )
-    return py_dict
+    return inputs
 
 
 def get_input_types(obj: Union[InferRequestBase, CompiledModelBase]) -> dict:
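The heart of the change is the list-to-dict conversion at the top of normalize_inputs. A standalone sketch of just that step (as_indexed_dict is a hypothetical name; the Tensor conversion is omitted because it depends on the native bindings):

from typing import Union

def as_indexed_dict(inputs: Union[dict, list]) -> dict:
    # Mirror of the diff above: a list becomes a dict keyed by positional
    # index, so the rest of the function can treat both forms uniformly.
    if isinstance(inputs, list):
        inputs = {index: value for index, value in enumerate(inputs)}
    return inputs

assert as_indexed_dict(["a", "b"]) == {0: "a", 1: "b"}
assert as_indexed_dict({"name": "a"}) == {"name": "a"}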
@@ -55,14 +57,14 @@ def get_input_types(obj: Union[InferRequestBase, CompiledModelBase]) -> dict:
 class InferRequest(InferRequestBase):
     """InferRequest wrapper."""
 
-    def infer(self, inputs: dict = None) -> dict:
+    def infer(self, inputs: Union[dict, list] = None) -> dict:
         """Infer wrapper for InferRequest."""
         inputs = (
             {} if inputs is None else normalize_inputs(inputs, get_input_types(self))
         )
         return super().infer(inputs)
 
-    def start_async(self, inputs: dict = None, userdata: Any = None) -> None:
+    def start_async(self, inputs: Union[dict, list] = None, userdata: Any = None) -> None:
         """Asynchronous infer wrapper for InferRequest."""
         inputs = (
             {} if inputs is None else normalize_inputs(inputs, get_input_types(self))
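The asynchronous path accepts the same forms. A hedged sketch, reusing request and data from the first example (the wait and get_output_tensor calls are assumptions about the surrounding API, not part of this diff):

request.start_async([data])  # list form, same normalization as infer
request.wait()
result = request.get_output_tensor(0).data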
@@ -77,7 +79,7 @@ class CompiledModel(CompiledModelBase):
         """Create new InferRequest object."""
         return InferRequest(super().create_infer_request())
 
-    def infer_new_request(self, inputs: dict = None) -> dict:
+    def infer_new_request(self, inputs: Union[dict, list] = None) -> dict:
         """Infer wrapper for CompiledModel."""
         inputs = (
             {} if inputs is None else normalize_inputs(inputs, get_input_types(self))
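infer_new_request gets the same treatment, so one-shot inference on a CompiledModel also accepts a list (compiled_model and data as in the earlier sketch):

results = compiled_model.infer_new_request([data])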
@@ -92,7 +94,7 @@ class AsyncInferQueue(AsyncInferQueueBase):
         """Return i-th InferRequest from AsyncInferQueue."""
         return InferRequest(super().__getitem__(i))
 
-    def start_async(self, inputs: dict = None, userdata: Any = None) -> None:
+    def start_async(self, inputs: Union[dict, list] = None, userdata: Any = None) -> None:
         """Asynchronous infer wrapper for AsyncInferQueue."""
         inputs = (
             {}
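Because AsyncInferQueue.start_async routes through the same normalize_inputs helper, the list form works for the async queue as well. A sketch reusing compiled_model from above (the two-job queue size and the wait_all call are assumptions about the surrounding API):

import numpy as np
from openvino.runtime import AsyncInferQueue

queue = AsyncInferQueue(compiled_model, 2)  # queue backed by two infer requests
queue.start_async([np.zeros([2, 1], dtype=np.float32)])
queue.wait_all()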

@@ -175,6 +175,30 @@ def test_start_async(device):
     assert callbacks_info["finished"] == jobs
 
 
+def test_infer_list_as_inputs(device):
+    num_inputs = 4
+    input_shape = [2, 1]
+    dtype = np.float32
+    params = [ops.parameter(input_shape, dtype) for _ in range(num_inputs)]
+    model = Model(ops.relu(ops.concat(params, 1)), params)
+    core = Core()
+    compiled_model = core.compile_model(model, device)
+
+    def check_fill_inputs(request, inputs):
+        for input_idx in range(len(inputs)):
+            assert np.array_equal(request.get_input_tensor(input_idx).data, inputs[input_idx])
+
+    request = compiled_model.create_infer_request()
+
+    inputs = [np.random.normal(size=input_shape).astype(dtype)]
+    request.infer(inputs)
+    check_fill_inputs(request, inputs)
+
+    inputs = [np.random.normal(size=input_shape).astype(dtype) for _ in range(num_inputs)]
+    request.infer(inputs)
+    check_fill_inputs(request, inputs)
+
+
+def test_infer_mixed_keys(device):
     core = Core()
     func = core.read_model(test_net_xml, test_net_bin)
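A condensed restatement of what test_infer_list_as_inputs establishes: a list is equivalent to a dict keyed by port index, and a list shorter than the model's input count fills only the leading ports. The sketch below rebuilds the test's model; the import paths and the "CPU" device are assumptions based on the test's context:

import numpy as np
import openvino.runtime.opset8 as ops
from openvino.runtime import Core, Model

# Four-input model, matching the test: relu(concat(inputs, axis=1)).
params = [ops.parameter([2, 1], np.float32) for _ in range(4)]
model = Model(ops.relu(ops.concat(params, 1)), params)
request = Core().compile_model(model, "CPU").create_infer_request()

data = [np.random.normal(size=[2, 1]).astype(np.float32) for _ in range(4)]
request.infer(data)                                # new list form
request.infer({i: d for i, d in enumerate(data)})  # equivalent dict form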