Optimized Infer Request Scheduling (#3300)

* Optimized Infer Request Scheduling

* Fixed misprint

* Brushing the code and comments a bit

* further brushing of the ScheduleToWorkerRequest: moving the task execution directly into the loop over devices (avoids pointers and 'else' clause)

Co-authored-by: Maxim Shevtsov <maxim.y.shevtsov@intel.com>
This commit is contained in:
Anton Pankratv 2020-11-27 16:37:57 +03:00 committed by GitHub
parent 6aa7c51de9
commit 4b44608b3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 26 deletions

View File

@ -36,7 +36,7 @@ struct IdleGuard {
}
~IdleGuard() {
if (nullptr != _notBusyWorkerRequests) {
_notBusyWorkerRequests->push(_workerInferRequestPtr);
_notBusyWorkerRequests->try_push(_workerInferRequestPtr);
}
}
MultiDeviceExecutableNetwork::NotBusyWorkerRequests* Release() {
@ -79,10 +79,11 @@ MultiDeviceExecutableNetwork::MultiDeviceExecutableNetwork(const DeviceMap<Infer
auto& idleWorkerRequests = _idleWorkerRequests[device];
workerRequests.resize(numRequests);
auto* idleWorkerRequestsPtr = &(idleWorkerRequests);
idleWorkerRequests.set_capacity(numRequests);
for (auto&& workerRequest : workerRequests) {
workerRequest._inferRequest = network.CreateInferRequest();
auto* workerRequestPtr = &workerRequest;
idleWorkerRequests.push(workerRequestPtr);
IE_ASSERT(idleWorkerRequests.try_push(workerRequestPtr) == true);
workerRequest._inferRequest.SetCompletionCallback<std::function<void(InferRequest, StatusCode)>>(
[workerRequestPtr, this, device, idleWorkerRequestsPtr] (InferRequest , StatusCode status) mutable {
IdleGuard idleGuard{workerRequestPtr, *idleWorkerRequestsPtr};
@ -91,41 +92,44 @@ MultiDeviceExecutableNetwork::MultiDeviceExecutableNetwork(const DeviceMap<Infer
auto capturedTask = std::move(workerRequestPtr->_task);
capturedTask();
}
if (!_terminate) {
idleGuard.Release()->push(workerRequestPtr);
ScheduleToWorkerInferRequest();
// try to return the request to the idle list (fails if the overall object destruction has begun)
if (idleGuard.Release()->try_push(workerRequestPtr)) {
// try to pop a task, as we know there is at least one idle request
if (_inferPipelineTasks.try_pop(workerRequestPtr->_task)) {
// if succeeded, let's schedule that
ScheduleToWorkerInferRequest(std::move(workerRequestPtr->_task));
}
}
});
}
}
}
void MultiDeviceExecutableNetwork::ScheduleToWorkerInferRequest() {
void MultiDeviceExecutableNetwork::ScheduleToWorkerInferRequest(Task inferPipelineTask) {
auto devices = [&] {
std::lock_guard<std::mutex> lock(_mutex);
return _devicePriorities;
}();
for (auto&& device : devices) {
auto& idleWorkerRequests = _idleWorkerRequests[device.deviceName];
WorkerInferRequest* workerRequestPtr = nullptr;
NotBusyWorkerRequests& idleWorkerRequests = _idleWorkerRequests[device.deviceName];
if (idleWorkerRequests.try_pop(workerRequestPtr)) {
IdleGuard idleGuard{workerRequestPtr, idleWorkerRequests};
Task inferPipelineTask;
if (_inferPipelineTasks.try_pop(inferPipelineTask)) {
_thisWorkerInferRequest = workerRequestPtr;
inferPipelineTask();
idleGuard.Release();
break;
_thisWorkerInferRequest = workerRequestPtr;
{
auto capturedTask = std::move(inferPipelineTask);
capturedTask();
}
idleGuard.Release();
return;
}
}
// no vacant requests this time, storing the task to the queue
_inferPipelineTasks.push(std::move(inferPipelineTask));
}
void MultiDeviceExecutableNetwork::run(Task inferPipelineTask) {
if (!_terminate) {
_inferPipelineTasks.push(std::move(inferPipelineTask));
ScheduleToWorkerInferRequest();
}
ScheduleToWorkerInferRequest(std::move(inferPipelineTask));
}
MultiDeviceExecutableNetwork::~MultiDeviceExecutableNetwork() {
@ -133,10 +137,13 @@ MultiDeviceExecutableNetwork::~MultiDeviceExecutableNetwork() {
std::lock_guard<std::mutex> lock(_mutex);
_devicePriorities.clear();
}
_terminate = true;
/* NOTE: The only threads that use `MultiDeviceExecutableNetwork` Context are those that are used by Worker infer requests.
* But AsyncInferRequest destructor should waits for all asynchronous tasks that are used by the request
/* NOTE: The only threads that use the `MultiDeviceExecutableNetwork` context are the worker infer requests' threads.
 * But the AsyncInferRequest destructor should wait for all asynchronous tasks that are used by the request
*/
for (auto&& networkValue : _networksPerDevice) {
// stop accepting any idle requests back (for re-scheduling)
_idleWorkerRequests.at(networkValue.first).set_capacity(0);
}
_workerRequests.clear();
}

View File

@ -37,6 +37,8 @@ using DeviceMap = std::unordered_map<DeviceName, T>;
#if ((IE_THREAD == IE_THREAD_TBB) || (IE_THREAD == IE_THREAD_TBB_AUTO))
template <typename T>
using ThreadSafeQueue = tbb::concurrent_queue<T>;
template <typename T>
using ThreadSafeBoundedQueue = tbb::concurrent_bounded_queue<T>;
#else
template <typename T>
class ThreadSafeQueue {
@ -45,7 +47,6 @@ public:
std::lock_guard<std::mutex> lock(_mutex);
_queue.push(std::move(value));
}
bool try_pop(T& value) {
std::lock_guard<std::mutex> lock(_mutex);
if (!_queue.empty()) {
@ -56,15 +57,40 @@ public:
return false;
}
}
bool empty() {
protected:
std::queue<T> _queue;
std::mutex _mutex;
};
template <typename T>
class ThreadSafeBoundedQueue {
public:
    ThreadSafeBoundedQueue() = default;
    // Enqueues the value only while the queue is enabled (i.e. set_capacity()
    // was last called with a non-zero value). Returns true iff the value was
    // actually pushed, so callers can detect shutdown (capacity reset to 0).
    bool try_push(T value) {
        std::lock_guard<std::mutex> lock(_mutex);
        if (_capacity) {
            _queue.push(std::move(value));
        }
        return _capacity;
    }
    // Pops the front element into 'value' if the queue is enabled and
    // non-empty. Returns true iff an element was produced.
    bool try_pop(T& value) {
        std::lock_guard<std::mutex> lock(_mutex);
        if (_capacity && !_queue.empty()) {
            value = std::move(_queue.front());
            _queue.pop();
            return true;
        } else {
            return false;
        }
    }
    // NOTE: capacity is tracked only as an on/off flag — a non-zero value
    // enables the queue, zero disables it; the element count is not bounded.
    // Setting it to 0 during teardown makes all subsequent try_push calls fail.
    void set_capacity(std::size_t newCapacity) {
        std::lock_guard<std::mutex> lock(_mutex);
        _capacity = (newCapacity != 0);  // avoid implicit size_t -> bool narrowing
    }
protected:
    std::queue<T> _queue;   // guarded by _mutex
    std::mutex _mutex;
    bool _capacity = false; // disabled until set_capacity() is called with a non-zero value
};
#endif
@ -77,7 +103,7 @@ public:
InferenceEngine::Task _task;
InferenceEngine::StatusCode _status = InferenceEngine::StatusCode::OK;
};
using NotBusyWorkerRequests = ThreadSafeQueue<WorkerInferRequest*>;
using NotBusyWorkerRequests = ThreadSafeBoundedQueue<WorkerInferRequest*>;
explicit MultiDeviceExecutableNetwork(const DeviceMap<InferenceEngine::ExecutableNetwork>& networksPerDevice,
const std::vector<DeviceInformation>& networkDevices,
@ -93,7 +119,7 @@ public:
InferenceEngine::OutputsDataMap networkOutputs) override;
~MultiDeviceExecutableNetwork() override;
void ScheduleToWorkerInferRequest();
void ScheduleToWorkerInferRequest(InferenceEngine::Task);
static thread_local WorkerInferRequest* _thisWorkerInferRequest;
std::atomic_bool _terminate = {false};