[PYTHON API] Fix InferQueue.is_ready() call (#10096)

* Fix is_ready and add tests

* remove wrong comment

* refactor test

* Fix code style
Alexey Lebedev 2022-02-04 11:57:56 +03:00 committed by GitHub
parent da02951d67
commit 7478915ef3
2 changed files with 41 additions and 14 deletions
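
The new behaviour, sketched as a usage snippet (not part of the commit): is_ready() becomes a cheap, non-blocking probe, while wait_all() still blocks until the queue drains. Import paths follow the openvino.runtime API of this era, and the tiny ReLU model mirrors the test added below; treat it as an illustration only and adjust the device for your setup.

    import time
    from openvino.runtime import AsyncInferQueue, Core, Model
    import openvino.runtime.opset8 as ops

    core = Core()
    param = ops.parameter([10])
    model = Model(ops.relu(param), [param])
    compiled = core.compile_model(model, "CPU")

    # One request in the pool, so the queue is "ready" only while that request is idle
    infer_queue = AsyncInferQueue(compiled, 1)
    infer_queue.set_callback(lambda request, userdata: None)

    assert infer_queue.is_ready()   # non-blocking check: an idle request exists
    infer_queue.start_async()       # occupy the only request
    while not infer_queue.is_ready():
        time.sleep(0.001)           # poll without blocking inside is_ready()
    infer_queue.wait_all()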


@@ -35,24 +35,26 @@ public:
     }
     bool _is_ready() {
         // Check if any request has finished already
         py::gil_scoped_release release;
-        std::unique_lock<std::mutex> lock(_mutex);
-        _cv.wait(lock, [this] {
-            return !(_idle_handles.empty());
-        });
+        // acquire the mutex to access _errors and _idle_handles
+        std::lock_guard<std::mutex> lock(_mutex);
         if (_errors.size() > 0)
             throw _errors.front();
         return !(_idle_handles.empty());
     }
     size_t get_idle_request_id() {
-        // Wait for any of _idle_handles
+        // Wait for any request to complete and return its id
+        // release GIL to avoid deadlock on python callback
         py::gil_scoped_release release;
+        // acquire the mutex to access _errors and _idle_handles
         std::unique_lock<std::mutex> lock(_mutex);
         _cv.wait(lock, [this] {
             return !(_idle_handles.empty());
         });
         size_t idle_handle = _idle_handles.front();
+        // wait for request to make sure it returned from callback
+        _requests[idle_handle]._request.wait();
         if (_errors.size() > 0)
             throw _errors.front();
@@ -60,11 +62,14 @@ public:
     }
     void wait_all() {
-        // Wait for all requests to complete
+        // Wait for all request to complete
+        // release GIL to avoid deadlock on python callback
         py::gil_scoped_release release;
         for (auto&& request : _requests) {
             request._request.wait();
         }
+        // acquire the mutex to access _errors
+        std::lock_guard<std::mutex> lock(_mutex);
         if (_errors.size() > 0)
             throw _errors.front();
     }
@@ -80,8 +85,12 @@ public:
                } catch (const std::exception& e) {
                    throw ov::Exception(e.what());
                }
-               // Add idle handle to queue
-               _idle_handles.push(handle);
+               {
+                   // acquire the mutex to access _idle_handles
+                   std::lock_guard<std::mutex> lock(_mutex);
+                   // Add idle handle to queue
+                   _idle_handles.push(handle);
+               }
                // Notify locks in getIdleRequestId()
                _cv.notify_one();
            });
@@ -105,10 +114,16 @@ public:
                    f_callback(_requests[handle], _user_ids[handle]);
                } catch (py::error_already_set py_error) {
                    assert(PyErr_Occurred());
+                   // acquire the mutex to access _errors
+                   std::lock_guard<std::mutex> lock(_mutex);
                    _errors.push(py_error);
                }
-               // Add idle handle to queue
-               _idle_handles.push(handle);
+               {
+                   // acquire the mutex to access _idle_handles
+                   std::lock_guard<std::mutex> lock(_mutex);
+                   // Add idle handle to queue
+                   _idle_handles.push(handle);
+               }
                // Notify locks in getIdleRequestId()
                _cv.notify_one();
            });
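
The pattern behind the C++ change above, restated as a minimal Python sketch (illustrative only; IdlePool and mark_idle are not the binding's internals): a readiness check should only take the mutex and inspect the idle queue, while handing out an idle request may block on the condition variable that completion callbacks notify.

    import threading
    from collections import deque

    class IdlePool:
        """Sketch of the idle-handle bookkeeping, not the pybind11 code."""

        def __init__(self, handles):
            self._idle = deque(handles)
            self._lock = threading.Lock()
            self._cv = threading.Condition(self._lock)

        def is_ready(self):
            # Non-blocking: take the lock and look, which is what the fix does
            with self._lock:
                return bool(self._idle)

        def get_idle_request_id(self):
            # Blocking: wait until a completion callback pushes a handle back
            with self._cv:
                self._cv.wait_for(lambda: bool(self._idle))
                return self._idle[0]

        def mark_idle(self, handle):
            # Called from the completion callback
            with self._cv:
                self._idle.append(handle)
                self._cv.notify()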


@@ -277,7 +277,6 @@ def test_infer_queue(device):
     img = read_image()
     infer_queue.set_callback(callback)
-    assert infer_queue.is_ready
     for i in range(jobs):
         infer_queue.start_async({"data": img}, i)
     infer_queue.wait_all()
@@ -285,6 +284,22 @@ def test_infer_queue(device):
     assert all(job["latency"] > 0 for job in jobs_done)
 
 
+def test_infer_queue_is_ready(device):
+    core = Core()
+    param = ops.parameter([10])
+    model = Model(ops.relu(param), [param])
+    compiled = core.compile_model(model, device)
+    infer_queue = AsyncInferQueue(compiled, 1)
+
+    def callback(request, _):
+        time.sleep(0.001)
+
+    infer_queue.set_callback(callback)
+    assert infer_queue.is_ready()
+    infer_queue.start_async()
+    assert not infer_queue.is_ready()
+    infer_queue.wait_all()
+
+
 def test_infer_queue_fail_on_cpp_model(device):
     jobs = 6
     num_request = 4
@@ -298,7 +313,6 @@ def test_infer_queue_fail_on_cpp_model(device):
     img = read_image()
     infer_queue.set_callback(callback)
-    assert infer_queue.is_ready
     with pytest.raises(RuntimeError) as e:
         for _ in range(jobs):
@@ -321,7 +335,6 @@ def test_infer_queue_fail_on_py_model(device):
     img = read_image()
     infer_queue.set_callback(callback)
-    assert infer_queue.is_ready
     with pytest.raises(TypeError) as e:
         for _ in range(jobs):
@@ -434,7 +447,6 @@ def test_results_async_infer(device):
     img = read_image()
     infer_queue.set_callback(callback)
-    assert infer_queue.is_ready
     for i in range(jobs):
         infer_queue.start_async({"data": img}, i)
     infer_queue.wait_all()
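
A side note on the asserts deleted above: written without parentheses, assert infer_queue.is_ready evaluates the bound method object itself, which is always truthy, so those lines could never fail; the new test_infer_queue_is_ready calls the method instead. A minimal illustration of why the old form is a no-op:

    class Q:
        def is_ready(self):
            return False

    q = Q()
    assert q.is_ready              # passes: a bound method object is truthy
    assert q.is_ready() is False   # the actual return value must be called for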