[PYTHON API] fix request busy using AsyncInferQueue (#10020)
* Add request.wait() call to get_idle_request_id and wait_all()
* rethrow exception on default callback
* update comment
* fix code style
This commit is contained in:
parent b56fd07169
commit 495931673d
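For orientation (not part of the diff): get_idle_request_id() could previously return a handle whose underlying request had not yet returned from its callback, so touching that request from Python raised a "request busy" error. Below is a minimal Python sketch of the intended usage pattern, assuming the openvino.runtime import paths, a "CPU" placeholder device, and the same tiny relu model the new test builds; it is an illustration, not code from this commit:

import numpy as np
import openvino.runtime.opset8 as ops
from openvino.runtime import AsyncInferQueue, Core, Model

param = ops.parameter([10])                      # float32 parameter of shape [10]
model = Model(ops.relu(param), [param])
compiled = Core().compile_model(model, "CPU")    # "CPU" is a placeholder device

queue = AsyncInferQueue(compiled, 2)             # pool of two infer requests
for _ in range(10):
    # After this fix the returned handle's request has fully completed
    # (request.wait() is called internally), so it is safe to reuse.
    idle_id = queue.get_idle_request_id()
    queue.start_async({0: np.random.rand(10).astype(np.float32)})
queue.wait_all()                                 # now also waits on every request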
@@ -52,19 +52,19 @@ public:
         _cv.wait(lock, [this] {
             return !(_idle_handles.empty());
         });
+        size_t idle_handle = _idle_handles.front();
+        _requests[idle_handle]._request.wait();
         if (_errors.size() > 0)
             throw _errors.front();
-        return _idle_handles.front();
+        return idle_handle;
     }
 
     void wait_all() {
-        // Wait for all requests to return with callback thus updating
-        // _idle_handles so it matches the size of requests
+        // Wait for all requests to complete
         py::gil_scoped_release release;
         std::unique_lock<std::mutex> lock(_mutex);
         _cv.wait(lock, [this] {
             return _idle_handles.size() == _requests.size();
         });
+        for (auto&& request : _requests) {
+            request._request.wait();
+        }
         if (_errors.size() > 0)
             throw _errors.front();
     }
@@ -73,9 +73,16 @@ public:
         for (size_t handle = 0; handle < _requests.size(); handle++) {
             _requests[handle]._request.set_callback([this, handle /* ... */](std::exception_ptr exception_ptr) {
                 _requests[handle]._end_time = Time::now();
+                try {
+                    if (exception_ptr) {
+                        std::rethrow_exception(exception_ptr);
+                    }
+                } catch (const std::exception& e) {
+                    throw ov::Exception(e.what());
+                }
                 // Add idle handle to queue
                 _idle_handles.push(handle);
-                // Notify locks in getIdleRequestId() or waitAll() functions
+                // Notify locks in getIdleRequestId()
                 _cv.notify_one();
             });
         }
@@ -102,7 +109,7 @@ public:
                 }
                 // Add idle handle to queue
                 _idle_handles.push(handle);
-                // Notify locks in getIdleRequestId() or waitAll() functions
+                // Notify locks in getIdleRequestId()
                 _cv.notify_one();
             });
         }
@@ -28,6 +28,13 @@ public:
     {
         _request.set_callback([this](std::exception_ptr exception_ptr) {
             _end_time = Time::now();
+            try {
+                if (exception_ptr) {
+                    std::rethrow_exception(exception_ptr);
+                }
+            } catch (const std::exception& e) {
+                throw ov::Exception("Caught exception: " + std::string(e.what()));
+            }
         });
     }
     // ~InferRequestWrapper() = default;
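Taken together, the callback changes mean an exception raised while an asynchronous job runs is no longer swallowed: it is rethrown inside the callback, collected by the queue, and resurfaces from the next wait_all() or get_idle_request_id() call, which both check _errors. A rough Python sketch of observing this, reusing the queue from the sketch above; the failing callback and the exact exception type that reaches Python are assumptions (the existing test_infer_queue_fail_on_py_model checks a propagated TypeError):

def on_done(request, userdata):
    # Hypothetical user callback that fails during post-processing.
    raise ValueError("post-processing failed")

queue.set_callback(on_done)
queue.start_async({0: np.random.rand(10).astype(np.float32)})
try:
    queue.wait_all()               # the stored error is rethrown here
except Exception as err:           # exact wrapper type depends on the bindings
    print("job failed:", err)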
@@ -331,6 +331,27 @@ def test_infer_queue_fail_on_py_model(device):
     assert "unsupported operand type(s) for +" in str(e.value)
 
 
+def test_infer_queue_get_idle_handle(device):
+    param = ops.parameter([10])
+    model = Model(ops.relu(param), [param])
+    core = Core()
+    compiled = core.compile_model(model, device)
+    queue = AsyncInferQueue(compiled, 2)
+    niter = 10
+
+    for _ in range(len(queue)):
+        queue.start_async()
+    queue.wait_all()
+    for request in queue:
+        assert request.wait_for(0)
+
+    for _ in range(niter):
+        idle_id = queue.get_idle_request_id()
+        assert queue[idle_id].wait_for(0)
+        queue.start_async()
+    queue.wait_all()
+
+
 @pytest.mark.parametrize("data_type",
                          [np.float32,
                           np.int32,