Fix queue fails on callbacks (#8809)

This commit is contained in:
Jan Iwaszkiewicz 2021-11-26 12:44:11 +01:00 committed by GitHub
parent 31ae69ac2d
commit bcf0879785
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 3 deletions

View File

@ -40,7 +40,8 @@ public:
_cv.wait(lock, [this] {
return !(_idle_handles.empty());
});
if (_errors.size() > 0)
throw _errors.front();
return !(_idle_handles.empty());
}
@ -51,7 +52,8 @@ public:
_cv.wait(lock, [this] {
return !(_idle_handles.empty());
});
if (_errors.size() > 0)
throw _errors.front();
return _idle_handles.front();
}
@ -63,6 +65,8 @@ public:
_cv.wait(lock, [this] {
return _idle_handles.size() == _requests.size();
});
if (_errors.size() > 0)
throw _errors.front();
}
void set_default_callbacks() {
@ -90,7 +94,12 @@ public:
}
// Acquire GIL, execute Python function
py::gil_scoped_acquire acquire;
f_callback(_requests[handle], _user_ids[handle]);
try {
f_callback(_requests[handle], _user_ids[handle]);
} catch (py::error_already_set py_error) {
assert(PyErr_Occurred());
_errors.push(py_error);
}
// Add idle handle to queue
_idle_handles.push(handle);
// Notify locks in getIdleRequestId() or waitAll() functions
@ -104,6 +113,7 @@ public:
std::vector<py::object> _user_ids; // user ID can be any Python object
std::mutex _mutex;
std::condition_variable _cv;
std::queue<py::error_already_set> _errors;
};
void regclass_AsyncInferQueue(py::module m) {

View File

@ -212,6 +212,52 @@ def test_infer_queue(device):
assert all(job["latency"] > 0 for job in jobs_done)
def test_infer_queue_fail_on_cpp_func(device):
jobs = 6
num_request = 4
core = Core()
func = core.read_model(test_net_xml, test_net_bin)
exec_net = core.compile_model(func, device)
infer_queue = AsyncInferQueue(exec_net, num_request)
def callback(request, _):
request.get_tensor("Unknown")
img = read_image()
infer_queue.set_callback(callback)
assert infer_queue.is_ready
with pytest.raises(RuntimeError) as e:
for _ in range(jobs):
infer_queue.start_async({"data": img})
infer_queue.wait_all()
assert "Port for tensor name Unknown was not found" in str(e.value)
def test_infer_queue_fail_on_py_func(device):
jobs = 1
num_request = 1
core = Core()
func = core.read_model(test_net_xml, test_net_bin)
exec_net = core.compile_model(func, device)
infer_queue = AsyncInferQueue(exec_net, num_request)
def callback(request, _):
request = request + 21
img = read_image()
infer_queue.set_callback(callback)
assert infer_queue.is_ready
with pytest.raises(TypeError) as e:
for _ in range(jobs):
infer_queue.start_async({"data": img})
infer_queue.wait_all()
assert "unsupported operand type(s) for +" in str(e.value)
@pytest.mark.parametrize("data_type",
[np.float32,
np.int32,