fix the pre-mature timeout restaring via waiting for batch1 requests completion

This commit is contained in:
myshevts 2021-10-25 17:30:50 +03:00
parent 7560a9319e
commit 74040fed95

View File

@ -193,17 +193,17 @@ void AutoBatchInferRequest::InferImpl() {
std::unique_lock<std::mutex> lock(_workerInferRequest->_mutex);
int sz = _workerInferRequest->_tasks.unsafe_size();
if (sz == _workerInferRequest->_batchSize) {
// printf("!!! BATCH : %ld \n", _workerInferRequest->_tasks.unsafe_size());
// auto start = std::chrono::high_resolution_clock::now();
std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task> t;
for (int c = 0; c < _workerInferRequest->_batchSize; c++) {
if (_workerInferRequest->_tasks.try_pop(t)) {
_workerInferRequest->_completionTasks[c] = std::move(t.second);
t.first->_inferRequest->CopyInputsIfNeeded();
} else {
printf("!!! BUG !!! \n");
}
IE_ASSERT(_workerInferRequest->_tasks.try_pop(t));
_workerInferRequest->_completionTasks[c] = std::move(t.second);
t.first->_inferRequest->CopyInputsIfNeeded();
}
_workerInferRequest->_inferRequest->StartAsync();
// auto waitTime = std::chrono::duration_cast<std::chrono::microseconds>
// (std::chrono::high_resolution_clock::now() - start).count();
// std::cout << "START BATCH: " << waitTime << std::endl;
}
}
@ -291,7 +291,7 @@ InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateIn
workerRequestPtr->_inferRequest->SetCallback(
[workerRequestPtr, this] (std::exception_ptr exceptionPtr) mutable {
IE_ASSERT(workerRequestPtr->_completionTasks.size() == (size_t)workerRequestPtr->_batchSize);
// notify the ibvidual requests on the completion
// notify the individual requests on the completion
for (int c = 0; c < workerRequestPtr->_batchSize; c++) {
workerRequestPtr->_completionTasks[c]();
}
@ -300,21 +300,40 @@ InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateIn
});
workerRequestPtr->_thread = std::thread([workerRequestPtr, this] {
while (!_terminate) {
while (1) {
std::unique_lock<std::mutex> lock(workerRequestPtr->_mutex);
auto status = workerRequestPtr->_cond.wait_for(lock, std::chrono::milliseconds(100));
if (!_terminate && status == std::cv_status::timeout) {
auto status = workerRequestPtr->_cond.wait_for(lock, std::chrono::milliseconds(30));
if (_terminate) {
break;
} else if (status == std::cv_status::timeout) {
// timeout to collect the batch is over, have to execute the requests in the batch1 mode
auto sz = workerRequestPtr->_tasks.unsafe_size();
IE_ASSERT(sz < (size_t)_device.batchForDevice);
if (sz)
std::cout << "TIME_OUT with tasks: " << sz << std::endl;
std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task> t;
// popping all tasks and execute with batch1
while (workerRequestPtr->_tasks.try_pop(t)) {
t.first->_inferRequestWithoutBatch->SetCallback([t](std::exception_ptr){t.second();});
t.first->_inferRequest->SetBlobsToAnotherRequest(t.first->_inferRequestWithoutBatch);
t.first->_inferRequestWithoutBatch->StartAsync();
// as we pop the tasks from the queue only here (and when the batch has been fully collected)
// both places are guarded with the same mutex
// it is ok to call unsafe_size (as the _tasks can only grow in parallel)
int sz = workerRequestPtr->_tasks.unsafe_size();
if (sz) {
auto start = std::chrono::high_resolution_clock::now();
std::pair<AutoBatchAsyncInferRequest *, InferenceEngine::Task> t;
// popping all tasks collected by the moment of the time-out and execute each with batch1
std::atomic<int> arrived = {0};
std::promise<void> all_completed;
auto all_completed_future = all_completed.get_future();
for (int n =0; n < sz; n++) {
IE_ASSERT(workerRequestPtr->_tasks.try_pop(t));
t.first->_inferRequestWithoutBatch->SetCallback(
[t, sz, &arrived, &all_completed](std::exception_ptr) {
t.second();
if (sz == ++arrived)
all_completed.set_value();
});
t.first->_inferRequest->SetBlobsToAnotherRequest(t.first->_inferRequestWithoutBatch);
t.first->_inferRequestWithoutBatch->StartAsync();
}
auto execTime = std::chrono::duration_cast<std::chrono::microseconds>
(std::chrono::high_resolution_clock::now() - start).count();
std::cout << "thread::timeout: " << execTime << " micros, tasks: " << sz << std::endl;
all_completed_future.get();
// now when all the tasks for this batch are completed, start waiting for the timeout again
}
}
}