[PYTHON API] fix request busy using AsyncInferQueue (#10020)

* Add request.wait() call to get_idle_request_id and wait_all()

* rethrow exception on default callback

* update comment

* fix code style
This commit is contained in:
Alexey Lebedev 2022-01-31 15:30:04 +03:00 committed by GitHub
parent b56fd07169
commit 495931673d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 9 deletions

View File

@ -52,19 +52,19 @@ public:
_cv.wait(lock, [this] {
return !(_idle_handles.empty());
});
size_t idle_handle = _idle_handles.front();
_requests[idle_handle]._request.wait();
if (_errors.size() > 0)
throw _errors.front();
return _idle_handles.front();
return idle_handle;
}
void wait_all() {
// Wait for all requests to return with callback thus updating
// _idle_handles so it matches the size of requests
// Wait for all requests to complete
py::gil_scoped_release release;
std::unique_lock<std::mutex> lock(_mutex);
_cv.wait(lock, [this] {
return _idle_handles.size() == _requests.size();
});
for (auto&& request : _requests) {
request._request.wait();
}
if (_errors.size() > 0)
throw _errors.front();
}
@ -73,9 +73,16 @@ public:
for (size_t handle = 0; handle < _requests.size(); handle++) {
_requests[handle]._request.set_callback([this, handle /* ... */](std::exception_ptr exception_ptr) {
_requests[handle]._end_time = Time::now();
try {
if (exception_ptr) {
std::rethrow_exception(exception_ptr);
}
} catch (const std::exception& e) {
throw ov::Exception(e.what());
}
// Add idle handle to queue
_idle_handles.push(handle);
// Notify locks in getIdleRequestId() or waitAll() functions
// Notify locks in getIdleRequestId()
_cv.notify_one();
});
}
@ -102,7 +109,7 @@ public:
}
// Add idle handle to queue
_idle_handles.push(handle);
// Notify locks in getIdleRequestId() or waitAll() functions
// Notify locks in getIdleRequestId()
_cv.notify_one();
});
}

View File

@ -28,6 +28,13 @@ public:
{
_request.set_callback([this](std::exception_ptr exception_ptr) {
_end_time = Time::now();
try {
if (exception_ptr) {
std::rethrow_exception(exception_ptr);
}
} catch (const std::exception& e) {
throw ov::Exception("Caught exception: " + std::string(e.what()));
}
});
}
// ~InferRequestWrapper() = default;

View File

@ -331,6 +331,27 @@ def test_infer_queue_fail_on_py_model(device):
assert "unsupported operand type(s) for +" in str(e.value)
def test_infer_queue_get_idle_handle(device):
param = ops.parameter([10])
model = Model(ops.relu(param), [param])
core = Core()
compiled = core.compile_model(model, device)
queue = AsyncInferQueue(compiled, 2)
niter = 10
for _ in range(len(queue)):
queue.start_async()
queue.wait_all()
for request in queue:
assert request.wait_for(0)
for _ in range(niter):
idle_id = queue.get_idle_request_id()
assert queue[idle_id].wait_for(0)
queue.start_async()
queue.wait_all()
@pytest.mark.parametrize("data_type",
[np.float32,
np.int32,