Merged base implementation files (#3699)
This commit is contained in:
parent
bfd8f1372c
commit
943e511c58
@ -87,7 +87,8 @@ void MultiDeviceAsyncInferRequest::Infer_ThreadUnsafe() {
|
||||
InferUsingAsync();
|
||||
}
|
||||
|
||||
void MultiDeviceAsyncInferRequest::GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
|
||||
void MultiDeviceAsyncInferRequest::GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
|
||||
CheckBusy();
|
||||
perfMap = std::move(_perfMap);
|
||||
}
|
||||
|
||||
|
@ -26,7 +26,7 @@ public:
|
||||
const MultiDeviceExecutableNetwork::Ptr& multiDeviceExecutableNetwork,
|
||||
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
|
||||
void Infer_ThreadUnsafe() override;
|
||||
void GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &_perfMap) const override;
|
||||
void GetPerformanceCounts(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &_perfMap) const override;
|
||||
~MultiDeviceAsyncInferRequest() override;
|
||||
|
||||
protected:
|
||||
|
@ -8,8 +8,8 @@
|
||||
#include <threading/ie_itask_executor.hpp>
|
||||
#include <threading/ie_istreams_executor.hpp>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_internal.hpp>
|
||||
#include <cpp_interfaces/exception2status.hpp>
|
||||
#include <ie_system_conf.h>
|
||||
|
||||
@ -40,7 +40,7 @@ namespace InferenceEngine {
|
||||
*
|
||||
* @snippet example_async_infer_request.cpp async_infer_request:define_pipeline
|
||||
*/
|
||||
class AsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeInternal {
|
||||
class AsyncInferRequestThreadSafeDefault : public IAsyncInferRequestInternal {
|
||||
using AtomicCallback = std::atomic<IInferRequest::CompletionCallback>;
|
||||
using Futures = std::vector<std::shared_future<void>>;
|
||||
using Promise = std::shared_ptr<std::promise<void>>;
|
||||
@ -143,6 +143,72 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void StartAsync() override {
|
||||
if (setIsRequestBusy(true)) ThrowBusy();
|
||||
try {
|
||||
StartAsync_ThreadUnsafe();
|
||||
} catch (...) {
|
||||
setIsRequestBusy(false);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
void Infer() override {
|
||||
if (setIsRequestBusy(true)) ThrowBusy();
|
||||
try {
|
||||
Infer_ThreadUnsafe();
|
||||
} catch (...) {
|
||||
setIsRequestBusy(false);
|
||||
throw;
|
||||
}
|
||||
setIsRequestBusy(false);
|
||||
}
|
||||
|
||||
void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap) const override {
|
||||
CheckBusy();
|
||||
_syncRequest->GetPerformanceCounts(perfMap);
|
||||
}
|
||||
|
||||
void SetBlob(const char* name, const Blob::Ptr& data) override {
|
||||
CheckBusy();
|
||||
_syncRequest->SetBlob(name, data);
|
||||
}
|
||||
|
||||
void SetBlob(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) override {
|
||||
CheckBusy();
|
||||
_syncRequest->SetBlob(name, data, info);
|
||||
}
|
||||
|
||||
void GetBlob(const char* name, Blob::Ptr& data) override {
|
||||
CheckBusy();
|
||||
_syncRequest->GetBlob(name, data);
|
||||
}
|
||||
|
||||
void GetPreProcess(const char* name, const PreProcessInfo** info) const override {
|
||||
_syncRequest->GetPreProcess(name, info);
|
||||
}
|
||||
|
||||
void SetBatch(int batch) override {
|
||||
CheckBusy();
|
||||
_syncRequest->SetBatch(batch);
|
||||
};
|
||||
|
||||
void GetUserData(void** data) override {
|
||||
CheckBusy();
|
||||
if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
|
||||
*data = _userData;
|
||||
}
|
||||
|
||||
void SetUserData(void* data) override {
|
||||
CheckBusy();
|
||||
_userData = data;
|
||||
}
|
||||
|
||||
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
|
||||
CheckBusy();
|
||||
_callback = callback;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Sets the pointer to public interface.
|
||||
* @note Needed to correctly handle ownership between objects
|
||||
@ -174,6 +240,37 @@ protected:
|
||||
*/
|
||||
using Pipeline = std::vector<Stage>;
|
||||
|
||||
/**
|
||||
* @brief Determines if request busy.
|
||||
* @return `True` if request busy, `false` otherwise.
|
||||
*/
|
||||
bool isRequestBusy() const {
|
||||
return _isRequestBusy;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Sets the is request busy.
|
||||
* @param[in] isBusy Indicates if busy
|
||||
* @return `True` is case of success, `false` otherwise.
|
||||
*/
|
||||
bool setIsRequestBusy(bool isBusy) {
|
||||
return _isRequestBusy.exchange(isBusy);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Throws an exception that an inference request is busy.
|
||||
*/
|
||||
[[noreturn]] static void ThrowBusy() {
|
||||
THROW_IE_EXCEPTION << InferenceEngine::details::as_status << StatusCode::REQUEST_BUSY << REQUEST_BUSY_str;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks whether an inference request is busy and calls ThrowBusy if `true`
|
||||
*/
|
||||
void CheckBusy() const {
|
||||
if (isRequestBusy()) ThrowBusy();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creates and run the first stage task. If destructor was not called add a new std::future to the
|
||||
* AsyncInferRequestThreadSafeDefault::_futures list that would be used to wait
|
||||
@ -262,52 +359,23 @@ protected:
|
||||
Pipeline _pipeline; //!< Pipeline variable that should be filled by inherited class.
|
||||
Pipeline _syncPipeline; //!< Synchronous pipeline variable that should be filled by inherited class.
|
||||
|
||||
void StartAsync_ThreadUnsafe() override {
|
||||
/**
|
||||
* @brief Starts an asynchronous pipeline thread unsafe.
|
||||
* @note Used by StartAsync which ensures thread-safety and calls this method after.
|
||||
*/
|
||||
virtual void StartAsync_ThreadUnsafe() {
|
||||
_syncRequest->checkBlobs();
|
||||
RunFirstStage(_pipeline.begin(), _pipeline.end(), _callbackExecutor);
|
||||
}
|
||||
|
||||
void Infer_ThreadUnsafe() override {
|
||||
/**
|
||||
* @brief Performs inference of pipeline in syncronous mode
|
||||
* @note Used by Infer which ensures thread-safety and calls this method after.
|
||||
*/
|
||||
virtual void Infer_ThreadUnsafe() {
|
||||
InferUsingSync();
|
||||
}
|
||||
|
||||
void GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo>& perfMap) const override {
|
||||
_syncRequest->GetPerformanceCounts(perfMap);
|
||||
}
|
||||
|
||||
void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data) override {
|
||||
_syncRequest->SetBlob(name, data);
|
||||
}
|
||||
|
||||
void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) override {
|
||||
_syncRequest->SetBlob(name, data, info);
|
||||
}
|
||||
|
||||
void GetBlob_ThreadUnsafe(const char* name, Blob::Ptr& data) override {
|
||||
_syncRequest->GetBlob(name, data);
|
||||
}
|
||||
|
||||
void GetPreProcess_ThreadUnsafe(const char* name, const PreProcessInfo** info) const override {
|
||||
_syncRequest->GetPreProcess(name, info);
|
||||
}
|
||||
|
||||
void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) override {
|
||||
_callback = callback;
|
||||
}
|
||||
|
||||
void GetUserData_ThreadUnsafe(void** data) override {
|
||||
if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
|
||||
*data = _userData;
|
||||
}
|
||||
|
||||
void SetUserData_ThreadUnsafe(void* data) override {
|
||||
_userData = data;
|
||||
}
|
||||
|
||||
void SetBatch_ThreadUnsafe(int batch) override {
|
||||
_syncRequest->SetBatch(batch);
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Create a task with next pipeline stage.
|
||||
@ -378,6 +446,7 @@ private:
|
||||
}, std::move(callbackExecutor));
|
||||
}
|
||||
|
||||
std::atomic_bool _isRequestBusy = {false};
|
||||
void* _userData = nullptr;
|
||||
AtomicCallback _callback = {nullptr};
|
||||
IInferRequest::Ptr _publicInterface;
|
||||
|
@ -1,228 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
/**
|
||||
* @brief Wrapper of asynchronous inference request to support thread-safe execution.
|
||||
* @ingroup ie_dev_api_async_infer_request_api
|
||||
*/
|
||||
class AsyncInferRequestThreadSafeInternal : public IAsyncInferRequestInternal {
|
||||
std::atomic_bool _isRequestBusy = {false};
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Determines if request busy.
|
||||
* @return `True` if request busy, `false` otherwise.
|
||||
*/
|
||||
virtual bool isRequestBusy() const {
|
||||
return _isRequestBusy;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Sets the is request busy.
|
||||
* @param[in] isBusy Indicates if busy
|
||||
* @return `True` is case of success, `false` otherwise.
|
||||
*/
|
||||
virtual bool setIsRequestBusy(bool isBusy) {
|
||||
return _isRequestBusy.exchange(isBusy);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Throws an exception that an inference request is busy.
|
||||
*/
|
||||
[[noreturn]] static void ThrowBusy() {
|
||||
THROW_IE_EXCEPTION << InferenceEngine::details::as_status << StatusCode::REQUEST_BUSY << REQUEST_BUSY_str;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks whether an inference request is busy and calls ThrowBusy if `true`
|
||||
*/
|
||||
void CheckBusy() const {
|
||||
if (isRequestBusy()) ThrowBusy();
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to a AsyncInferRequestThreadSafeInternal implementation
|
||||
*/
|
||||
typedef std::shared_ptr<AsyncInferRequestThreadSafeInternal> Ptr;
|
||||
|
||||
/**
|
||||
* @brief Constructs a new instance.
|
||||
*/
|
||||
AsyncInferRequestThreadSafeInternal() {
|
||||
setIsRequestBusy(false);
|
||||
}
|
||||
|
||||
void StartAsync() override {
|
||||
if (setIsRequestBusy(true)) ThrowBusy();
|
||||
try {
|
||||
StartAsync_ThreadUnsafe();
|
||||
} catch (...) {
|
||||
setIsRequestBusy(false);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
void GetUserData(void** data) override {
|
||||
CheckBusy();
|
||||
GetUserData_ThreadUnsafe(data);
|
||||
}
|
||||
|
||||
void SetUserData(void* data) override {
|
||||
CheckBusy();
|
||||
SetUserData_ThreadUnsafe(data);
|
||||
}
|
||||
|
||||
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
|
||||
CheckBusy();
|
||||
SetCompletionCallback_ThreadUnsafe(callback);
|
||||
}
|
||||
|
||||
void Infer() override {
|
||||
if (setIsRequestBusy(true)) ThrowBusy();
|
||||
try {
|
||||
Infer_ThreadUnsafe();
|
||||
} catch (...) {
|
||||
setIsRequestBusy(false);
|
||||
throw;
|
||||
}
|
||||
setIsRequestBusy(false);
|
||||
}
|
||||
|
||||
void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap) const override {
|
||||
CheckBusy();
|
||||
GetPerformanceCounts_ThreadUnsafe(perfMap);
|
||||
}
|
||||
|
||||
void SetBlob(const char* name, const Blob::Ptr& data) override {
|
||||
CheckBusy();
|
||||
SetBlob_ThreadUnsafe(name, data);
|
||||
}
|
||||
|
||||
void SetBlob(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) override {
|
||||
CheckBusy();
|
||||
SetBlob_ThreadUnsafe(name, data, info);
|
||||
}
|
||||
|
||||
void GetBlob(const char* name, Blob::Ptr& data) override {
|
||||
CheckBusy();
|
||||
GetBlob_ThreadUnsafe(name, data);
|
||||
}
|
||||
|
||||
void GetPreProcess(const char* name, const PreProcessInfo** info) const override {
|
||||
GetPreProcess_ThreadUnsafe(name, info);
|
||||
}
|
||||
|
||||
void SetBatch(int batch) override {
|
||||
CheckBusy();
|
||||
SetBatch_ThreadUnsafe(batch);
|
||||
};
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Starts an asynchronous pipeline thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::StartAsync which ensures thread-safety
|
||||
* and calls this method after.
|
||||
*/
|
||||
virtual void StartAsync_ThreadUnsafe() = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the user data thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::GetUserData which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param data The user data
|
||||
*/
|
||||
virtual void GetUserData_ThreadUnsafe(void** data) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets the user data thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::SetUserData which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param data The user data
|
||||
*/
|
||||
virtual void SetUserData_ThreadUnsafe(void* data) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets the completion callback thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::SetCompletionCallback which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param[in] callback The callback to set
|
||||
*/
|
||||
virtual void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) = 0;
|
||||
|
||||
/**
|
||||
* @brief Performs inference of pipeline in syncronous mode
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::Infer which ensures thread-safety
|
||||
* and calls this method after.
|
||||
*/
|
||||
virtual void Infer_ThreadUnsafe() = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the performance counts thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::GetPerformanceCounts which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param perfMap The performance map
|
||||
*/
|
||||
virtual void GetPerformanceCounts_ThreadUnsafe(
|
||||
std::map<std::string, InferenceEngineProfileInfo>& perfMap) const = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets the blob thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::SetBlob which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param[in] name The name of input / output data to set a blob to
|
||||
* @param[in] data The blob to set
|
||||
*/
|
||||
virtual void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets the blob with preprocessing information thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::SetBlob which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param[in] name The name of input / output data to set a blob to
|
||||
* @param[in] data The blob to set
|
||||
* @param[in] info The preprocessing information
|
||||
*/
|
||||
virtual void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the input or output blob thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::GetBlob which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param[in] name The name of input / output data to get a blob for
|
||||
* @param data The data
|
||||
*/
|
||||
virtual void GetBlob_ThreadUnsafe(const char* name, Blob::Ptr& data) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the preprocessing information thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::GetPreProcess which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param[in] name The name of input / output data to get a processing information for
|
||||
* @param info The preprocessing information
|
||||
*/
|
||||
virtual void GetPreProcess_ThreadUnsafe(const char* name, const PreProcessInfo** info) const = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets the dynamic batch thread unsafe.
|
||||
* @note Used by AsyncInferRequestThreadSafeInternal::SetBatch which ensures thread-safety
|
||||
* and calls this method after.
|
||||
* @param[in] batch The dynamic batch value
|
||||
*/
|
||||
virtual void SetBatch_ThreadUnsafe(int batch) = 0;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
@ -13,7 +13,6 @@
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_network_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_async_only.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_default.hpp"
|
||||
|
@ -1,67 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_internal.hpp>
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class MockAsyncInferRequestThreadSafeInternal : public AsyncInferRequestThreadSafeInternal {
|
||||
public:
|
||||
typedef std::shared_ptr<MockAsyncInferRequestThreadSafeInternal> Ptr;
|
||||
|
||||
void setRequestBusy() {
|
||||
AsyncInferRequestThreadSafeInternal::setIsRequestBusy(true);
|
||||
}
|
||||
using AsyncInferRequestThreadSafeInternal::isRequestBusy;
|
||||
bool isRequestBusy() {
|
||||
return AsyncInferRequestThreadSafeInternal::isRequestBusy();
|
||||
}
|
||||
|
||||
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
|
||||
|
||||
MOCK_METHOD0(StartAsync_ThreadUnsafe, void());
|
||||
|
||||
MOCK_METHOD1(GetUserData_ThreadUnsafe, void(void * *));
|
||||
|
||||
MOCK_METHOD1(SetUserData_ThreadUnsafe, void(void *));
|
||||
|
||||
MOCK_METHOD0(Infer_ThreadUnsafe, void());
|
||||
|
||||
MOCK_CONST_METHOD1(GetPerformanceCounts_ThreadUnsafe, void(std::map<std::string, InferenceEngineProfileInfo>
|
||||
&));
|
||||
|
||||
MOCK_METHOD2(GetBlob_ThreadUnsafe, void(
|
||||
const char *name, Blob::Ptr
|
||||
&));
|
||||
|
||||
MOCK_CONST_METHOD2(GetPreProcess_ThreadUnsafe, void(
|
||||
const char* name,
|
||||
const PreProcessInfo** info));
|
||||
|
||||
MOCK_METHOD2(SetBlob_ThreadUnsafe, void(
|
||||
const char *name,
|
||||
const Blob::Ptr &));
|
||||
|
||||
MOCK_METHOD3(SetBlob_ThreadUnsafe, void(
|
||||
const char* name,
|
||||
const Blob::Ptr&,
|
||||
const PreProcessInfo&));
|
||||
|
||||
MOCK_METHOD1(SetCompletionCallback_ThreadUnsafe, void(IInferRequest::CompletionCallback));
|
||||
|
||||
MOCK_METHOD1(SetBatch, void(int));
|
||||
MOCK_METHOD1(SetBatch_ThreadUnsafe, void(int));
|
||||
MOCK_METHOD0(QueryState, std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(void));
|
||||
|
||||
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
|
||||
};
|
@ -15,7 +15,6 @@
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/mock_task_executor.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp"
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace std;
|
||||
@ -260,116 +259,3 @@ TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailed
|
||||
testRequest->StartAsync();
|
||||
EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
|
||||
}
|
||||
|
||||
|
||||
class AsyncInferRequestThreadSafeInternalTests : public ::testing::Test {
|
||||
protected:
|
||||
MockAsyncInferRequestThreadSafeInternal::Ptr testRequest;
|
||||
ResponseDesc dsc;
|
||||
|
||||
bool _doesThrowExceptionWithMessage(std::function<void()> func, string refError) {
|
||||
std::string whatMessage;
|
||||
try {
|
||||
func();
|
||||
} catch (const InferenceEngineException &iee) {
|
||||
whatMessage = iee.what();
|
||||
}
|
||||
return whatMessage.find(refError) != std::string::npos;
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
testRequest = make_shared<MockAsyncInferRequestThreadSafeInternal>();
|
||||
}
|
||||
};
|
||||
|
||||
// StartAsync
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnStartAsync) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->StartAsync(); }, REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, canResetBusyStatusIfStartAsyncTaskFails) {
|
||||
EXPECT_CALL(*testRequest.get(), StartAsync_ThreadUnsafe()).Times(2)
|
||||
.WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
|
||||
.WillOnce(Return());
|
||||
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { testRequest->StartAsync(); }, "compare"));
|
||||
ASSERT_NO_THROW(testRequest->StartAsync());
|
||||
}
|
||||
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, deviceBusyAfterStartAsync) {
|
||||
EXPECT_CALL(*testRequest.get(), StartAsync_ThreadUnsafe()).WillOnce(Return());
|
||||
|
||||
ASSERT_NO_THROW(testRequest->StartAsync());
|
||||
|
||||
ASSERT_TRUE(testRequest->isRequestBusy());
|
||||
}
|
||||
|
||||
// GetUserData
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetUserData) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->GetUserData(nullptr); }, REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
// SetUserData
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetUserData) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetUserData(nullptr); }, REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
// Wait
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnInferNotStartedOnWait) {
|
||||
testRequest->setRequestBusy();
|
||||
int64_t ms = 0;
|
||||
EXPECT_CALL(*testRequest.get(), Wait(ms)).WillOnce(Return(INFER_NOT_STARTED));
|
||||
|
||||
StatusCode actual = testRequest->Wait(ms);
|
||||
ASSERT_EQ(INFER_NOT_STARTED, actual);
|
||||
}
|
||||
|
||||
// Infer
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnInfer) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, canResetBusyStatusIfInferFails) {
|
||||
EXPECT_CALL(*testRequest.get(), Infer_ThreadUnsafe()).Times(2)
|
||||
.WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
|
||||
.WillOnce(Return());
|
||||
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { testRequest->Infer(); }, "compare"));
|
||||
ASSERT_NO_THROW(testRequest->Infer());
|
||||
}
|
||||
|
||||
// GetPerformanceCounts
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetPerformanceCounts) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
|
||||
testRequest->GetPerformanceCounts(info);
|
||||
}, REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
// GetBlob
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetBlob) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
|
||||
Blob::Ptr data;
|
||||
testRequest->GetBlob(nullptr, data);
|
||||
}, REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
// SetBlob
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetBlob) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetBlob(nullptr, nullptr); }, REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
// SetCompletionCallback
|
||||
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetCompletionCallback) {
|
||||
testRequest->setRequestBusy();
|
||||
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetCompletionCallback(nullptr); },
|
||||
REQUEST_BUSY_str));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user