Stub pipeline task to communicate the readiness rather than promise/future

This commit is contained in:
myshevts 2021-10-14 17:55:43 +03:00
parent c5c18f9eac
commit 3d2cd5fb71
2 changed files with 26 additions and 25 deletions

View File

@ -138,31 +138,37 @@ std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> AutoBatchInfe
}
// Synchronous inference over a shared, batched worker request.
// Every sync request arriving at the same worker increments the worker's
// ready counter; the request that brings the count up to _batchSize resets
// the counter and starts the single batched inference (StartAsync) on the
// underlying big-batch request. All participants then block on the shared
// future (_event) until the worker's completion callback fulfils the
// associated promise (see the SetCallback lambda in CreateInferRequestImpl).
void AutoBatchInferRequest::InferImpl() {
// Grab the shared_future BEFORE incrementing the counter: the completion
// callback swaps in a fresh promise/future pair for the next batch, so a
// late read of _workerInferRequest->_event could observe the new cycle.
auto _event = _workerInferRequest->_event;
// _numRequestsReady is a std::atomic_int, so the increment itself is safe
// across the threads that share this worker.
auto numReady = ++_workerInferRequest->_numRequestsReady;
// printf("!!! numReady: %d \n", numReady);
if (numReady == _workerInferRequest->_batchSize) {
// NOTE(review): plain (non-atomic-RMW) reset of the counter — presumably
// safe only because all _batchSize requests have already checked in at
// this point; confirm no new request can race in before the reset.
_workerInferRequest->_numRequestsReady = 0;
_workerInferRequest->_inferRequest->StartAsync();
}
// Block until the batched inference completes (promise is set in the
// worker's completion callback).
_event.get();
if (_needPerfCounters) {
// NOTE(review): these are the counters of the whole batched request, not
// of this individual logical request — verify that is the intent.
_perfMap = _workerInferRequest->_inferRequest->GetPerformanceCounts();
}
}
// Async wrapper over AutoBatchInferRequest.
// NOTE(review): this span is rendered from a diff with the +/- markers
// stripped — it contains BOTH the pre-commit base-class initializer (the
// CPUStreamsExecutor variant below) AND the post-commit one (the nullptr
// variant), plus the removed _AutoBatchExecutableNetwork initializer. As
// written it is not valid C++; consult the original commit to see which
// side of the diff each line belongs to.
AutoBatchAsyncInferRequest::AutoBatchAsyncInferRequest(
const AutoBatchInferRequest::Ptr& inferRequest,
const bool needPerfCounters,
const AutoBatchExecutableNetwork::Ptr& autoBatchExecutableNetwork,
const ITaskExecutor::Ptr& callbackExecutor) :
// Old (removed) base initializer: dedicated single-thread executor.
AsyncInferRequestThreadSafeDefault(inferRequest,
std::make_shared<CPUStreamsExecutor>(
IStreamsExecutor::Config{"AutoBatch", 1, 1,
IStreamsExecutor::ThreadBindingType::NONE, 1, 0, 1}),
callbackExecutor),
_AutoBatchExecutableNetwork{autoBatchExecutableNetwork},
// New (added) base initializer: no request executor — the custom
// ThisRequestExecutor below drives the pipeline instead.
AsyncInferRequestThreadSafeDefault(inferRequest, nullptr, callbackExecutor),
_inferRequest{inferRequest} {
// this executor starts the inference while the task (checking the result) is passed to the next stage
struct ThisRequestExecutor : public ITaskExecutor {
explicit ThisRequestExecutor(AutoBatchAsyncInferRequest* _this_) : _this{_this_} {}
void run(Task task) override {
// Queue the continuation on the shared worker, then join the batch:
// the worker's completion callback pops and runs the queued tasks.
auto& workerInferRequest = _this->_inferRequest->_workerInferRequest;
workerInferRequest->_tasks.push(std::move(task));
_this->_inferRequest->InferImpl();
};
AutoBatchAsyncInferRequest* _this = nullptr;
};
_pipeline = {
{ /*TaskExecutor*/ std::make_shared<ThisRequestExecutor>(this), /*task*/ [this, needPerfCounters] {
// TODO: exception checking
// if (needPerfCounters)
// _inferRequest->_perfMap = _inferRequest->_workerInferRequest->_inferRequest->GetPerformanceCounts();
}}
};
}
void AutoBatchAsyncInferRequest::Infer_ThreadUnsafe() {
@ -209,16 +215,15 @@ InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateIn
auto workerRequestPtr = _workerRequests.back();
workerRequestPtr->_inferRequest = {_network._so, _network->CreateInferRequest()};
workerRequestPtr->_batchSize = _device.batchForDevice;
workerRequestPtr->_cond = std::promise<void>();
workerRequestPtr->_event = workerRequestPtr->_cond.get_future().share();
// _idleWorkerRequests.push(workerRequestPtr);
workerRequestPtr->_inferRequest->SetCallback(
[workerRequestPtr, this] (std::exception_ptr exceptionPtr) mutable {
auto signal = std::move(workerRequestPtr->_cond);
// reset the promise/future for next use
workerRequestPtr->_cond = std::promise<void>();
workerRequestPtr->_event = workerRequestPtr->_cond.get_future().share();
signal.set_value();
Task t;
int num = 0;
while (workerRequestPtr->_tasks.try_pop(t)) {
t();
if (workerRequestPtr->_batchSize == ++num)
break;
}
});
}
return std::make_shared<AutoBatchInferRequest>(networkInputs, networkOutputs, _workerRequests.back().get(),
@ -230,7 +235,6 @@ InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateIn
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
return std::make_shared<AutoBatchAsyncInferRequest>(std::static_pointer_cast<AutoBatchInferRequest>(syncRequestImpl),
_needPerfCounters,
std::static_pointer_cast<AutoBatchExecutableNetwork>(shared_from_this()),
_callbackExecutor);
}

View File

@ -77,9 +77,8 @@ public:
InferenceEngine::SoIInferRequestInternal _inferRequest;
InferenceEngine::StatusCode _status = InferenceEngine::StatusCode::OK;
int _batchSize;
std::promise<void> _cond;
std::shared_future<void> _event;
std::atomic_int _numRequestsReady = {0};
ThreadSafeQueue<InferenceEngine::Task> _tasks;
};
using NotBusyWorkerRequests = ThreadSafeQueue<WorkerInferRequest*>;
@ -129,13 +128,11 @@ public:
explicit AutoBatchAsyncInferRequest(const AutoBatchInferRequest::Ptr& inferRequest,
const bool needPerfCounters,
const AutoBatchExecutableNetwork::Ptr& AutoBatchExecutableNetwork,
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
void Infer_ThreadUnsafe() override;
virtual ~AutoBatchAsyncInferRequest();
protected:
AutoBatchExecutableNetwork::Ptr _AutoBatchExecutableNetwork;
AutoBatchInferRequest::Ptr _inferRequest;
};