Fix queue fails on callbacks (#8809)
This commit is contained in:
parent
31ae69ac2d
commit
bcf0879785
@ -40,7 +40,8 @@ public:
|
||||
_cv.wait(lock, [this] {
|
||||
return !(_idle_handles.empty());
|
||||
});
|
||||
|
||||
if (_errors.size() > 0)
|
||||
throw _errors.front();
|
||||
return !(_idle_handles.empty());
|
||||
}
|
||||
|
||||
@ -51,7 +52,8 @@ public:
|
||||
_cv.wait(lock, [this] {
|
||||
return !(_idle_handles.empty());
|
||||
});
|
||||
|
||||
if (_errors.size() > 0)
|
||||
throw _errors.front();
|
||||
return _idle_handles.front();
|
||||
}
|
||||
|
||||
@ -63,6 +65,8 @@ public:
|
||||
_cv.wait(lock, [this] {
|
||||
return _idle_handles.size() == _requests.size();
|
||||
});
|
||||
if (_errors.size() > 0)
|
||||
throw _errors.front();
|
||||
}
|
||||
|
||||
void set_default_callbacks() {
|
||||
@ -90,7 +94,12 @@ public:
|
||||
}
|
||||
// Acquire GIL, execute Python function
|
||||
py::gil_scoped_acquire acquire;
|
||||
f_callback(_requests[handle], _user_ids[handle]);
|
||||
try {
|
||||
f_callback(_requests[handle], _user_ids[handle]);
|
||||
} catch (py::error_already_set py_error) {
|
||||
assert(PyErr_Occurred());
|
||||
_errors.push(py_error);
|
||||
}
|
||||
// Add idle handle to queue
|
||||
_idle_handles.push(handle);
|
||||
// Notify locks in getIdleRequestId() or waitAll() functions
|
||||
@ -104,6 +113,7 @@ public:
|
||||
std::vector<py::object> _user_ids; // user ID can be any Python object
|
||||
std::mutex _mutex;
|
||||
std::condition_variable _cv;
|
||||
std::queue<py::error_already_set> _errors;
|
||||
};
|
||||
|
||||
void regclass_AsyncInferQueue(py::module m) {
|
||||
|
@ -212,6 +212,52 @@ def test_infer_queue(device):
|
||||
assert all(job["latency"] > 0 for job in jobs_done)
|
||||
|
||||
|
||||
def test_infer_queue_fail_on_cpp_func(device):
|
||||
jobs = 6
|
||||
num_request = 4
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
infer_queue = AsyncInferQueue(exec_net, num_request)
|
||||
|
||||
def callback(request, _):
|
||||
request.get_tensor("Unknown")
|
||||
|
||||
img = read_image()
|
||||
infer_queue.set_callback(callback)
|
||||
assert infer_queue.is_ready
|
||||
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
for _ in range(jobs):
|
||||
infer_queue.start_async({"data": img})
|
||||
infer_queue.wait_all()
|
||||
|
||||
assert "Port for tensor name Unknown was not found" in str(e.value)
|
||||
|
||||
|
||||
def test_infer_queue_fail_on_py_func(device):
|
||||
jobs = 1
|
||||
num_request = 1
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
infer_queue = AsyncInferQueue(exec_net, num_request)
|
||||
|
||||
def callback(request, _):
|
||||
request = request + 21
|
||||
|
||||
img = read_image()
|
||||
infer_queue.set_callback(callback)
|
||||
assert infer_queue.is_ready
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
for _ in range(jobs):
|
||||
infer_queue.start_async({"data": img})
|
||||
infer_queue.wait_all()
|
||||
|
||||
assert "unsupported operand type(s) for +" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data_type",
|
||||
[np.float32,
|
||||
np.int32,
|
||||
|
Loading…
Reference in New Issue
Block a user