[Python] Relinquish infer request handle before handling exception (#13114)

When runtime exception was thrown during infer call, the infer
request handle was never relinquished nor condition variable notified.
That caused the get_idle_request_id function wait forever.
This commit is contained in:
Mateusz Tabaka 2022-09-22 13:00:44 +02:00 committed by GitHub
parent d5a274b0e4
commit 249df503eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 26 deletions

View File

@ -78,13 +78,6 @@ public:
for (size_t handle = 0; handle < _requests.size(); handle++) {
_requests[handle]._request.set_callback([this, handle /* ... */](std::exception_ptr exception_ptr) {
_requests[handle]._end_time = Time::now();
try {
if (exception_ptr) {
std::rethrow_exception(exception_ptr);
}
} catch (const std::exception& e) {
throw ov::Exception(e.what());
}
{
// acquire the mutex to access _idle_handles
std::lock_guard<std::mutex> lock(_mutex);
@ -93,6 +86,14 @@ public:
}
// Notify locks in getIdleRequestId()
_cv.notify_one();
try {
if (exception_ptr) {
std::rethrow_exception(exception_ptr);
}
} catch (const std::exception& e) {
throw ov::Exception(e.what());
}
});
}
}
@ -101,27 +102,23 @@ public:
for (size_t handle = 0; handle < _requests.size(); handle++) {
_requests[handle]._request.set_callback([this, f_callback, handle](std::exception_ptr exception_ptr) {
_requests[handle]._end_time = Time::now();
try {
if (exception_ptr) {
std::rethrow_exception(exception_ptr);
if (exception_ptr == nullptr) {
// Acquire GIL, execute Python function
py::gil_scoped_acquire acquire;
try {
f_callback(_requests[handle], _user_ids[handle]);
} catch (const py::error_already_set& py_error) {
// This should behave the same as assert(!PyErr_Occurred())
// since constructor for pybind11's error_already_set is
// performing PyErr_Fetch which clears error indicator and
// saves it inside itself.
assert(py_error.type());
// acquire the mutex to access _errors
std::lock_guard<std::mutex> lock(_mutex);
_errors.push(py_error);
}
} catch (const std::exception& e) {
throw ov::Exception(e.what());
}
// Acquire GIL, execute Python function
py::gil_scoped_acquire acquire;
try {
f_callback(_requests[handle], _user_ids[handle]);
} catch (const py::error_already_set& py_error) {
// This should behave the same as assert(!PyErr_Occurred())
// since constructor for pybind11's error_already_set is
// performing PyErr_Fetch which clears error indicator and
// saves it inside itself.
assert(py_error.type());
// acquire the mutex to access _errors
std::lock_guard<std::mutex> lock(_mutex);
_errors.push(py_error);
}
{
// acquire the mutex to access _idle_handles
std::lock_guard<std::mutex> lock(_mutex);
@ -130,6 +127,15 @@ public:
}
// Notify locks in getIdleRequestId()
_cv.notify_one();
try {
if (exception_ptr) {
std::rethrow_exception(exception_ptr);
}
} catch (const std::exception& e) {
// Notify locks in getIdleRequestId()
throw ov::Exception(e.what());
}
});
}
}

View File

@ -537,6 +537,35 @@ def test_infer_queue_fail_on_py_model(device):
assert "unsupported operand type(s) for +" in str(e.value)
@pytest.mark.parametrize("with_callback", [False, True])
def test_infer_queue_fail_in_inference(device, with_callback):
jobs = 6
num_request = 4
core = Core()
data = ops.parameter([5, 2], dtype=np.float32, name="data")
indexes = ops.parameter(Shape([3, 2]), dtype=np.int32, name="indexes")
emb = ops.embedding_bag_packed_sum(data, indexes)
model = Model(emb, [data, indexes])
compiled_model = core.compile_model(model, device)
infer_queue = AsyncInferQueue(compiled_model, num_request)
def callback(request, _):
pytest.fail("Callback should not be called")
if with_callback:
infer_queue.set_callback(callback)
data_tensor = Tensor(np.arange(10).reshape((5, 2)).astype(np.float32))
indexes_tensor = Tensor(np.array([[100, 101], [102, 103], [104, 105]], dtype=np.int32))
with pytest.raises(RuntimeError) as e:
for _ in range(jobs):
infer_queue.start_async({"data": data_tensor, "indexes": indexes_tensor})
infer_queue.wait_all()
assert "has invalid embedding bag index:" in str(e.value)
def test_infer_queue_get_idle_handle(device):
param = ops.parameter([10])
model = Model(ops.relu(param), [param])