Remove from unwaited requests ids if request aborted (#3160)

This commit is contained in:
Krzysztof Bruniecki 2020-11-19 16:08:04 +01:00 committed by GitHub
parent 3f2ac0ff55
commit 0342f51cd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 6 deletions

View File

@ -86,7 +86,7 @@ uint32_t GNADeviceHelper::propagate(const uint32_t requestConfigId, Gna2Accelera
const auto status2 = Gna2RequestEnqueue(requestConfigId, &reqId); const auto status2 = Gna2RequestEnqueue(requestConfigId, &reqId);
checkGna2Status(status2, "Gna2RequestEnqueue"); checkGna2Status(status2, "Gna2RequestEnqueue");
unwaitedRequestIds.push_back(reqId); unwaitedRequestIds.insert(reqId);
return reqId; return reqId;
} }
@ -329,14 +329,14 @@ const std::map <const std::pair<Gna2OperationType, int32_t>, const std::string>
GnaWaitStatus GNADeviceHelper::wait(uint32_t reqId, int64_t millisTimeout) { GnaWaitStatus GNADeviceHelper::wait(uint32_t reqId, int64_t millisTimeout) {
#if GNA_LIB_VER == 2 #if GNA_LIB_VER == 2
const auto status = Gna2RequestWait(reqId, millisTimeout); const auto status = Gna2RequestWait(reqId, millisTimeout);
if (status == Gna2StatusDriverQoSTimeoutExceeded) {
return GNA_REQUEST_ABORTED;
}
if (status == Gna2StatusWarningDeviceBusy) { if (status == Gna2StatusWarningDeviceBusy) {
return GNA_REQUEST_PENDING; return GNA_REQUEST_PENDING;
} }
unwaitedRequestIds.erase(reqId);
if (status == Gna2StatusDriverQoSTimeoutExceeded) {
return GNA_REQUEST_ABORTED;
}
checkGna2Status(status, "Gna2RequestWait"); checkGna2Status(status, "Gna2RequestWait");
unwaitedRequestIds.erase(std::remove(unwaitedRequestIds.begin(), unwaitedRequestIds.end(), reqId));
#else #else
if (isPerformanceMeasuring) { if (isPerformanceMeasuring) {
nGNAStatus = GNAWaitPerfRes(nGNAHandle, millisTimeout, reqId, &nGNAPerfResults); nGNAStatus = GNAWaitPerfRes(nGNAHandle, millisTimeout, reqId, &nGNAPerfResults);

View File

@ -9,6 +9,7 @@
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <map> #include <map>
#include <set>
#include <vector> #include <vector>
#include <thread> #include <thread>
@ -61,7 +62,7 @@ class GNADeviceHelper {
uint64_t instrumentationResults[TotalGna2InstrumentationPoints] = {}; uint64_t instrumentationResults[TotalGna2InstrumentationPoints] = {};
uint64_t instrumentationTotal[TotalGna2InstrumentationPoints] = {}; uint64_t instrumentationTotal[TotalGna2InstrumentationPoints] = {};
uint32_t instrumentationConfigId = 0; uint32_t instrumentationConfigId = 0;
std::vector<uint32_t> unwaitedRequestIds; std::set<uint32_t> unwaitedRequestIds;
#define MAX_TIMEOUT 500000 #define MAX_TIMEOUT 500000
#endif #endif
bool isPerformanceMeasuring = false; bool isPerformanceMeasuring = false;