[PyOV] Fix iteration over AsyncInferQueue (#13496)
* [PyOV] Fix iteration over AsyncInferQueue * apply comments
This commit is contained in:
committed by
GitHub
parent
152511daa8
commit
2455bb67d4
@@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import singledispatch
|
||||
from typing import Any, Union, Dict
|
||||
from typing import Any, Iterable, Union, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -280,6 +280,14 @@ class AsyncInferQueue(AsyncInferQueueBase):
|
||||
InferRequests and provides synchronization functions to control flow of
|
||||
a simple pipeline.
|
||||
"""
|
||||
def __iter__(self) -> Iterable[InferRequest]:
|
||||
"""Allows to iterate over AsyncInferQueue.
|
||||
|
||||
:return: a map object (which is an iterator) that yields InferRequests.
|
||||
:rtype: Iterable[openvino.runtime.InferRequest]
|
||||
"""
|
||||
return map(lambda x: InferRequest(x), super().__iter__())
|
||||
|
||||
def __getitem__(self, i: int) -> InferRequest:
|
||||
"""Gets InferRequest from the pool with given i id.
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Iterable
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
import os
|
||||
@@ -10,7 +11,7 @@ import datetime
|
||||
import time
|
||||
|
||||
import openvino.runtime.opset8 as ops
|
||||
from openvino.runtime import Core, AsyncInferQueue, Tensor, ProfilingInfo, Model
|
||||
from openvino.runtime import Core, AsyncInferQueue, Tensor, ProfilingInfo, Model, InferRequest
|
||||
from openvino.runtime import Type, PartialShape, Shape, Layout
|
||||
from openvino.preprocess import PrePostProcessor
|
||||
|
||||
@@ -493,6 +494,23 @@ def test_infer_queue_is_ready(device):
|
||||
infer_queue.wait_all()
|
||||
|
||||
|
||||
def test_infer_queue_iteration(device):
|
||||
core = Core()
|
||||
param = ops.parameter([10])
|
||||
model = Model(ops.relu(param), [param])
|
||||
compiled_model = core.compile_model(model, device)
|
||||
infer_queue = AsyncInferQueue(compiled_model, 1)
|
||||
assert isinstance(infer_queue, Iterable)
|
||||
for infer_req in infer_queue:
|
||||
assert isinstance(infer_req, InferRequest)
|
||||
|
||||
it = iter(infer_queue)
|
||||
infer_request = next(it)
|
||||
assert isinstance(infer_request, InferRequest)
|
||||
with pytest.raises(StopIteration):
|
||||
next(it)
|
||||
|
||||
|
||||
def test_infer_queue_userdata_is_empty(device):
|
||||
core = Core()
|
||||
param = ops.parameter([10])
|
||||
|
||||
Reference in New Issue
Block a user