Fixed canStartSeveralAsyncInsideCompletionCallbackWithSafeDtor (#2404)

This commit is contained in:
Anton Pankratv 2020-09-28 11:08:34 +03:00 committed by GitHub
parent a14e12ee63
commit c8233b7b7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <memory> #include <memory>
#include <future>
#include "ie_extension.h" #include "ie_extension.h"
#include <condition_variable> #include <condition_variable>
#include "functional_test_utils/layer_test_utils.hpp" #include "functional_test_utils/layer_test_utils.hpp"
@ -100,12 +101,8 @@ TEST_P(CallbackTests, canStartSeveralAsyncInsideCompletionCallbackWithSafeDtor)
const int NUM_ITER = 10; const int NUM_ITER = 10;
struct TestUserData { struct TestUserData {
int numIter = NUM_ITER; std::atomic<int> numIter = {0};
bool startAsyncOK = true; std::promise<InferenceEngine::StatusCode> promise;
std::atomic<int> numIsCalled{0};
std::mutex mutex_block_emulation;
std::condition_variable cv_block_emulation;
bool isBlocked = true;
}; };
TestUserData data; TestUserData data;
@ -117,33 +114,28 @@ TEST_P(CallbackTests, canStartSeveralAsyncInsideCompletionCallbackWithSafeDtor)
InferenceEngine::InferRequest req = execNet.CreateInferRequest(); InferenceEngine::InferRequest req = execNet.CreateInferRequest();
req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>( req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
[&](InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode status) { [&](InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode status) {
// HSD_1805940120: Wait on starting callback return HDDL_ERROR_INVAL_TASK_HANDLE if (status != InferenceEngine::StatusCode::OK) {
if (targetDevice != CommonTestUtils::DEVICE_HDDL) { data.promise.set_value(status);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), status); } else {
} if (data.numIter.fetch_add(1) != NUM_ITER) {
if (--data.numIter) { InferenceEngine::StatusCode sts = request->StartAsync(nullptr);
InferenceEngine::StatusCode sts = request->StartAsync(nullptr); if (sts != InferenceEngine::StatusCode::OK) {
if (sts != InferenceEngine::StatusCode::OK) { data.promise.set_value(sts);
data.startAsyncOK = false; }
} else {
data.promise.set_value(InferenceEngine::StatusCode::OK);
} }
} }
data.numIsCalled++;
if (!data.numIter) {
data.isBlocked = false;
data.cv_block_emulation.notify_all();
}
}); });
auto future = data.promise.get_future();
req.StartAsync(); req.StartAsync();
InferenceEngine::ResponseDesc responseWait;
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY); InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
// intentionally block until notification from callback ASSERT_EQ((int) InferenceEngine::StatusCode::OK, waitStatus);
std::unique_lock<std::mutex> lock(data.mutex_block_emulation); future.wait();
data.cv_block_emulation.wait(lock, [&]() { return !data.isBlocked; }); auto callbackStatus = future.get();
ASSERT_EQ((int) InferenceEngine::StatusCode::OK, callbackStatus);
ASSERT_EQ((int) InferenceEngine::StatusCode::OK, waitStatus) << responseWait.msg; auto dataNumIter = data.numIter - 1;
ASSERT_EQ(NUM_ITER, data.numIsCalled); ASSERT_EQ(NUM_ITER, dataNumIter);
ASSERT_TRUE(data.startAsyncOK);
} }
TEST_P(CallbackTests, inferDoesNotCallCompletionCallback) { TEST_P(CallbackTests, inferDoesNotCallCompletionCallback) {