[PYTHON API] infer helper (#9478)
* inputs as list in infer * fix import * fix import 2 * refactor test
This commit is contained in:
parent
6ddc1e981b
commit
42c5be23b1
@ -22,21 +22,23 @@ def tensor_from_file(path: str) -> Tensor:
|
||||
return Tensor(np.fromfile(path, dtype=np.uint8))
|
||||
|
||||
|
||||
def normalize_inputs(py_dict: dict, py_types: dict) -> dict:
|
||||
def normalize_inputs(inputs: Union[dict, list], py_types: dict) -> dict:
|
||||
"""Normalize a dictionary of inputs to Tensors."""
|
||||
for k, val in py_dict.items():
|
||||
if isinstance(inputs, list):
|
||||
inputs = {index: input for index, input in enumerate(inputs)}
|
||||
for k, val in inputs.items():
|
||||
if not isinstance(k, (str, int)):
|
||||
raise TypeError("Incompatible key type for tensor named: {}".format(k))
|
||||
try:
|
||||
ov_type = py_types[k]
|
||||
except KeyError:
|
||||
raise KeyError("Port for tensor named {} was not found!".format(k))
|
||||
py_dict[k] = (
|
||||
inputs[k] = (
|
||||
val
|
||||
if isinstance(val, Tensor)
|
||||
else Tensor(np.array(val, get_dtype(ov_type)))
|
||||
)
|
||||
return py_dict
|
||||
return inputs
|
||||
|
||||
|
||||
def get_input_types(obj: Union[InferRequestBase, CompiledModelBase]) -> dict:
|
||||
@ -55,14 +57,14 @@ def get_input_types(obj: Union[InferRequestBase, CompiledModelBase]) -> dict:
|
||||
class InferRequest(InferRequestBase):
|
||||
"""InferRequest wrapper."""
|
||||
|
||||
def infer(self, inputs: dict = None) -> dict:
|
||||
def infer(self, inputs: Union[dict, list] = None) -> dict:
|
||||
"""Infer wrapper for InferRequest."""
|
||||
inputs = (
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
)
|
||||
return super().infer(inputs)
|
||||
|
||||
def start_async(self, inputs: dict = None, userdata: Any = None) -> None:
|
||||
def start_async(self, inputs: Union[dict, list] = None, userdata: Any = None) -> None:
|
||||
"""Asynchronous infer wrapper for InferRequest."""
|
||||
inputs = (
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
@ -77,7 +79,7 @@ class CompiledModel(CompiledModelBase):
|
||||
"""Create new InferRequest object."""
|
||||
return InferRequest(super().create_infer_request())
|
||||
|
||||
def infer_new_request(self, inputs: dict = None) -> dict:
|
||||
def infer_new_request(self, inputs: Union[dict, list] = None) -> dict:
|
||||
"""Infer wrapper for CompiledModel."""
|
||||
inputs = (
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
@ -92,7 +94,7 @@ class AsyncInferQueue(AsyncInferQueueBase):
|
||||
"""Return i-th InferRequest from AsyncInferQueue."""
|
||||
return InferRequest(super().__getitem__(i))
|
||||
|
||||
def start_async(self, inputs: dict = None, userdata: Any = None) -> None:
|
||||
def start_async(self, inputs: Union[dict, list] = None, userdata: Any = None) -> None:
|
||||
"""Asynchronous infer wrapper for AsyncInferQueue."""
|
||||
inputs = (
|
||||
{}
|
||||
|
@ -175,6 +175,30 @@ def test_start_async(device):
|
||||
assert callbacks_info["finished"] == jobs
|
||||
|
||||
|
||||
def test_infer_list_as_inputs(device):
|
||||
num_inputs = 4
|
||||
input_shape = [2, 1]
|
||||
dtype = np.float32
|
||||
params = [ops.parameter(input_shape, dtype) for _ in range(num_inputs)]
|
||||
model = Model(ops.relu(ops.concat(params, 1)), params)
|
||||
core = Core()
|
||||
compiled_model = core.compile_model(model, device)
|
||||
|
||||
def check_fill_inputs(request, inputs):
|
||||
for input_idx in range(len(inputs)):
|
||||
assert np.array_equal(request.get_input_tensor(input_idx).data, inputs[input_idx])
|
||||
|
||||
request = compiled_model.create_infer_request()
|
||||
|
||||
inputs = [np.random.normal(size=input_shape).astype(dtype)]
|
||||
request.infer(inputs)
|
||||
check_fill_inputs(request, inputs)
|
||||
|
||||
inputs = [np.random.normal(size=input_shape).astype(dtype) for _ in range(num_inputs)]
|
||||
request.infer(inputs)
|
||||
check_fill_inputs(request, inputs)
|
||||
|
||||
|
||||
def test_infer_mixed_keys(device):
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
|
Loading…
Reference in New Issue
Block a user