Merged base implementation files (#3699)

This commit is contained in:
Anton Pankratv 2021-01-12 12:04:47 +03:00 committed by GitHub
parent bfd8f1372c
commit 943e511c58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 113 additions and 453 deletions

View File

@ -87,7 +87,8 @@ void MultiDeviceAsyncInferRequest::Infer_ThreadUnsafe() {
InferUsingAsync();
}
void MultiDeviceAsyncInferRequest::GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
// Thread-safe counterpart of the former GetPerformanceCounts_ThreadUnsafe:
// returns performance counters of the last inference, throwing REQUEST_BUSY
// while an inference is still in flight.
void MultiDeviceAsyncInferRequest::GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
CheckBusy();
// NOTE(review): std::move on a member inside a const method degenerates to a
// copy unless _perfMap is declared mutable — confirm the intent.
perfMap = std::move(_perfMap);
}

View File

@ -26,7 +26,7 @@ public:
const MultiDeviceExecutableNetwork::Ptr& multiDeviceExecutableNetwork,
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
void Infer_ThreadUnsafe() override;
void GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &_perfMap) const override;
void GetPerformanceCounts(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &_perfMap) const override;
~MultiDeviceAsyncInferRequest() override;
protected:

View File

@ -8,8 +8,8 @@
#include <threading/ie_itask_executor.hpp>
#include <threading/ie_istreams_executor.hpp>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_internal.hpp>
#include <cpp_interfaces/exception2status.hpp>
#include <ie_system_conf.h>
@ -40,7 +40,7 @@ namespace InferenceEngine {
*
* @snippet example_async_infer_request.cpp async_infer_request:define_pipeline
*/
class AsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeInternal {
class AsyncInferRequestThreadSafeDefault : public IAsyncInferRequestInternal {
using AtomicCallback = std::atomic<IInferRequest::CompletionCallback>;
using Futures = std::vector<std::shared_future<void>>;
using Promise = std::shared_ptr<std::promise<void>>;
@ -143,6 +143,72 @@ public:
}
}
// Begins asynchronous execution. Atomically flips the busy flag and throws
// REQUEST_BUSY if the request was already running; on failure of the
// thread-unsafe implementation the flag is restored so the request stays usable.
// On success the flag intentionally stays set — it is cleared when the async
// pipeline completes (not visible in this hunk).
void StartAsync() override {
if (setIsRequestBusy(true)) ThrowBusy();
try {
StartAsync_ThreadUnsafe();
} catch (...) {
setIsRequestBusy(false);
throw;
}
}
// Runs a synchronous inference under the busy gate: throws REQUEST_BUSY if
// already running, guarantees the flag is cleared both on failure (catch)
// and on normal completion (final setIsRequestBusy(false)).
void Infer() override {
if (setIsRequestBusy(true)) ThrowBusy();
try {
Infer_ThreadUnsafe();
} catch (...) {
setIsRequestBusy(false);
throw;
}
setIsRequestBusy(false);
}
// Busy-gated delegation to the wrapped synchronous request.
void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap) const override {
CheckBusy();
_syncRequest->GetPerformanceCounts(perfMap);
}
// Busy-gated blob assignment, delegated to the synchronous request.
void SetBlob(const char* name, const Blob::Ptr& data) override {
CheckBusy();
_syncRequest->SetBlob(name, data);
}
// Busy-gated blob assignment with preprocessing info, delegated to the
// synchronous request.
void SetBlob(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) override {
CheckBusy();
_syncRequest->SetBlob(name, data, info);
}
// Busy-gated blob retrieval, delegated to the synchronous request.
void GetBlob(const char* name, Blob::Ptr& data) override {
CheckBusy();
_syncRequest->GetBlob(name, data);
}
// Deliberately NOT busy-gated — mirrors the removed
// AsyncInferRequestThreadSafeInternal::GetPreProcess, which also skipped CheckBusy().
void GetPreProcess(const char* name, const PreProcessInfo** info) const override {
_syncRequest->GetPreProcess(name, info);
}
// Busy-gated dynamic-batch setter. The trailing semicolon after the body is a
// stray (harmless) empty declaration carried over from the merged class.
void SetBatch(int batch) override {
CheckBusy();
_syncRequest->SetBatch(batch);
};
// Busy-gated read of the opaque user pointer; rejects a null output argument
// with NOT_ALLOCATED.
void GetUserData(void** data) override {
CheckBusy();
if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
*data = _userData;
}
// Busy-gated write of the opaque user pointer.
void SetUserData(void* data) override {
CheckBusy();
_userData = data;
}
// Busy-gated update of the completion callback (stored in an atomic — see
// AtomicCallback alias above).
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
CheckBusy();
_callback = callback;
}
/**
* @brief Sets the pointer to public interface.
* @note Needed to correctly handle ownership between objects
@ -174,6 +240,37 @@ protected:
*/
using Pipeline = std::vector<Stage>;
/**
* @brief Determines if the request is currently busy (plain read of the atomic flag).
* @return `true` if the request is busy, `false` otherwise.
*/
bool isRequestBusy() const {
return _isRequestBusy;
}
/**
* @brief Atomically sets the busy flag.
* @param[in] isBusy Indicates if busy
* @return Previous value of the flag; callers pass `true` and treat a `true`
*         result as "already busy".
*/
bool setIsRequestBusy(bool isBusy) {
return _isRequestBusy.exchange(isBusy);
}
/**
* @brief Throws an exception carrying StatusCode::REQUEST_BUSY to signal that
*        an inference request is busy.
*/
[[noreturn]] static void ThrowBusy() {
THROW_IE_EXCEPTION << InferenceEngine::details::as_status << StatusCode::REQUEST_BUSY << REQUEST_BUSY_str;
}
/**
* @brief Checks whether an inference request is busy and calls ThrowBusy if `true`.
*/
void CheckBusy() const {
if (isRequestBusy()) ThrowBusy();
}
/**
* @brief Creates and run the first stage task. If destructor was not called add a new std::future to the
* AsyncInferRequestThreadSafeDefault::_futures list that would be used to wait
@ -262,52 +359,23 @@ protected:
Pipeline _pipeline; //!< Pipeline variable that should be filled by inherited class.
Pipeline _syncPipeline; //!< Synchronous pipeline variable that should be filled by inherited class.
void StartAsync_ThreadUnsafe() override {
/**
* @brief Starts an asynchronous pipeline thread unsafe.
* @note Used by StartAsync which ensures thread-safety and calls this method after.
*/
virtual void StartAsync_ThreadUnsafe() {
// Validate input/output blobs up front, then schedule the first pipeline stage.
_syncRequest->checkBlobs();
RunFirstStage(_pipeline.begin(), _pipeline.end(), _callbackExecutor);
}
void Infer_ThreadUnsafe() override {
/**
* @brief Performs inference of pipeline in synchronous mode
* @note Used by Infer which ensures thread-safety and calls this method after.
*/
virtual void Infer_ThreadUnsafe() {
InferUsingSync();
}
// Removed by this commit: thread-unsafe delegate superseded by the public
// GetPerformanceCounts override added above.
void GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo>& perfMap) const override {
_syncRequest->GetPerformanceCounts(perfMap);
}
// Removed delegate: replaced by the busy-gated public SetBlob override.
void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data) override {
_syncRequest->SetBlob(name, data);
}
// Removed delegate: replaced by the busy-gated public SetBlob(name, data, info) override.
void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) override {
_syncRequest->SetBlob(name, data, info);
}
// Removed delegate: replaced by the busy-gated public GetBlob override.
void GetBlob_ThreadUnsafe(const char* name, Blob::Ptr& data) override {
_syncRequest->GetBlob(name, data);
}
// Removed delegate: replaced by the public GetPreProcess override (still no busy check).
void GetPreProcess_ThreadUnsafe(const char* name, const PreProcessInfo** info) const override {
_syncRequest->GetPreProcess(name, info);
}
// Removed delegate: replaced by the busy-gated public SetCompletionCallback override.
void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) override {
_callback = callback;
}
// Removed delegate: null-check plus pointer read, now inlined into the public GetUserData.
void GetUserData_ThreadUnsafe(void** data) override {
if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
*data = _userData;
}
// Removed delegate: replaced by the busy-gated public SetUserData override.
void SetUserData_ThreadUnsafe(void* data) override {
_userData = data;
}
// Removed delegate: replaced by the busy-gated public SetBatch override.
void SetBatch_ThreadUnsafe(int batch) override {
_syncRequest->SetBatch(batch);
}
private:
/**
* @brief Create a task with next pipeline stage.
@ -378,6 +446,7 @@ private:
}, std::move(callbackExecutor));
}
std::atomic_bool _isRequestBusy = {false};
void* _userData = nullptr;
AtomicCallback _callback = {nullptr};
IInferRequest::Ptr _publicInterface;

View File

@ -1,228 +0,0 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <atomic>
#include <map>
#include <memory>
#include <string>
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
namespace InferenceEngine {
/**
* @brief Wrapper of asynchronous inference request to support thread-safe execution.
* @ingroup ie_dev_api_async_infer_request_api
*/
// NOTE(review): this entire class is DELETED by commit 943e511c58 — its busy-flag
// gate and *_ThreadUnsafe contract were merged directly into
// AsyncInferRequestThreadSafeDefault. Kept here as the record of the removed API.
class AsyncInferRequestThreadSafeInternal : public IAsyncInferRequestInternal {
std::atomic_bool _isRequestBusy = {false};
protected:
/**
* @brief Determines if request busy.
* @return `true` if request busy, `false` otherwise.
*/
virtual bool isRequestBusy() const {
return _isRequestBusy;
}
/**
* @brief Atomically sets the busy flag.
* @param[in] isBusy Indicates if busy
* @return Previous value of the flag (so `setIsRequestBusy(true)` returning
*         `true` means the request was already busy).
*/
virtual bool setIsRequestBusy(bool isBusy) {
return _isRequestBusy.exchange(isBusy);
}
/**
* @brief Throws an exception that an inference request is busy.
*/
[[noreturn]] static void ThrowBusy() {
THROW_IE_EXCEPTION << InferenceEngine::details::as_status << StatusCode::REQUEST_BUSY << REQUEST_BUSY_str;
}
/**
* @brief Checks whether an inference request is busy and calls ThrowBusy if `true`
*/
void CheckBusy() const {
if (isRequestBusy()) ThrowBusy();
}
public:
/**
* @brief A shared pointer to a AsyncInferRequestThreadSafeInternal implementation
*/
typedef std::shared_ptr<AsyncInferRequestThreadSafeInternal> Ptr;
/**
* @brief Constructs a new instance; the request starts in the idle state.
*/
AsyncInferRequestThreadSafeInternal() {
setIsRequestBusy(false);
}
// Acquire the busy gate, run the derived implementation, and release the gate
// only on failure — success leaves the request busy until the async work ends.
void StartAsync() override {
if (setIsRequestBusy(true)) ThrowBusy();
try {
StartAsync_ThreadUnsafe();
} catch (...) {
setIsRequestBusy(false);
throw;
}
}
void GetUserData(void** data) override {
CheckBusy();
GetUserData_ThreadUnsafe(data);
}
void SetUserData(void* data) override {
CheckBusy();
SetUserData_ThreadUnsafe(data);
}
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
CheckBusy();
SetCompletionCallback_ThreadUnsafe(callback);
}
// Synchronous inference: gate on entry, release on both failure and success.
void Infer() override {
if (setIsRequestBusy(true)) ThrowBusy();
try {
Infer_ThreadUnsafe();
} catch (...) {
setIsRequestBusy(false);
throw;
}
setIsRequestBusy(false);
}
void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap) const override {
CheckBusy();
GetPerformanceCounts_ThreadUnsafe(perfMap);
}
void SetBlob(const char* name, const Blob::Ptr& data) override {
CheckBusy();
SetBlob_ThreadUnsafe(name, data);
}
void SetBlob(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) override {
CheckBusy();
SetBlob_ThreadUnsafe(name, data, info);
}
void GetBlob(const char* name, Blob::Ptr& data) override {
CheckBusy();
GetBlob_ThreadUnsafe(name, data);
}
// Intentionally not busy-gated (no CheckBusy here, unlike the other accessors).
void GetPreProcess(const char* name, const PreProcessInfo** info) const override {
GetPreProcess_ThreadUnsafe(name, info);
}
void SetBatch(int batch) override {
CheckBusy();
SetBatch_ThreadUnsafe(batch);
};
protected:
/**
* @brief Starts an asynchronous pipeline thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::StartAsync which ensures thread-safety
* and calls this method after.
*/
virtual void StartAsync_ThreadUnsafe() = 0;
/**
* @brief Gets the user data thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::GetUserData which ensures thread-safety
* and calls this method after.
* @param data The user data
*/
virtual void GetUserData_ThreadUnsafe(void** data) = 0;
/**
* @brief Sets the user data thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::SetUserData which ensures thread-safety
* and calls this method after.
* @param data The user data
*/
virtual void SetUserData_ThreadUnsafe(void* data) = 0;
/**
* @brief Sets the completion callback thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::SetCompletionCallback which ensures thread-safety
* and calls this method after.
* @param[in] callback The callback to set
*/
virtual void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) = 0;
/**
* @brief Performs inference of pipeline in synchronous mode
* @note Used by AsyncInferRequestThreadSafeInternal::Infer which ensures thread-safety
* and calls this method after.
*/
virtual void Infer_ThreadUnsafe() = 0;
/**
* @brief Gets the performance counts thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::GetPerformanceCounts which ensures thread-safety
* and calls this method after.
* @param perfMap The performance map
*/
virtual void GetPerformanceCounts_ThreadUnsafe(
std::map<std::string, InferenceEngineProfileInfo>& perfMap) const = 0;
/**
* @brief Sets the blob thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::SetBlob which ensures thread-safety
* and calls this method after.
* @param[in] name The name of input / output data to set a blob to
* @param[in] data The blob to set
*/
virtual void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data) = 0;
/**
* @brief Sets the blob with preprocessing information thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::SetBlob which ensures thread-safety
* and calls this method after.
* @param[in] name The name of input / output data to set a blob to
* @param[in] data The blob to set
* @param[in] info The preprocessing information
*/
virtual void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) = 0;
/**
* @brief Gets the input or output blob thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::GetBlob which ensures thread-safety
* and calls this method after.
* @param[in] name The name of input / output data to get a blob for
* @param data The data
*/
virtual void GetBlob_ThreadUnsafe(const char* name, Blob::Ptr& data) = 0;
/**
* @brief Gets the preprocessing information thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::GetPreProcess which ensures thread-safety
* and calls this method after.
* @param[in] name The name of input / output data to get a processing information for
* @param info The preprocessing information
*/
virtual void GetPreProcess_ThreadUnsafe(const char* name, const PreProcessInfo** info) const = 0;
/**
* @brief Sets the dynamic batch thread unsafe.
* @note Used by AsyncInferRequestThreadSafeInternal::SetBatch which ensures thread-safety
* and calls this method after.
* @param[in] batch The dynamic batch value
*/
virtual void SetBatch_ThreadUnsafe(int batch) = 0;
};
} // namespace InferenceEngine

View File

@ -13,7 +13,6 @@
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_network_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_async_only.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_default.hpp"

View File

@ -1,67 +0,0 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_internal.hpp>
using namespace InferenceEngine;
// NOTE(review): deleted by this commit together with
// AsyncInferRequestThreadSafeInternal, the class it mocks. The mock stubs out
// every *_ThreadUnsafe hook and exposes the protected busy-flag helpers to tests.
class MockAsyncInferRequestThreadSafeInternal : public AsyncInferRequestThreadSafeInternal {
public:
typedef std::shared_ptr<MockAsyncInferRequestThreadSafeInternal> Ptr;
// Test helper: force the request into the busy state.
void setRequestBusy() {
AsyncInferRequestThreadSafeInternal::setIsRequestBusy(true);
}
using AsyncInferRequestThreadSafeInternal::isRequestBusy;
// Non-const shadow so tests can query the flag on a non-const mock.
bool isRequestBusy() {
return AsyncInferRequestThreadSafeInternal::isRequestBusy();
}
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
MOCK_METHOD0(StartAsync_ThreadUnsafe, void());
MOCK_METHOD1(GetUserData_ThreadUnsafe, void(void * *));
MOCK_METHOD1(SetUserData_ThreadUnsafe, void(void *));
MOCK_METHOD0(Infer_ThreadUnsafe, void());
MOCK_CONST_METHOD1(GetPerformanceCounts_ThreadUnsafe, void(std::map<std::string, InferenceEngineProfileInfo>
&));
MOCK_METHOD2(GetBlob_ThreadUnsafe, void(
const char *name, Blob::Ptr
&));
MOCK_CONST_METHOD2(GetPreProcess_ThreadUnsafe, void(
const char* name,
const PreProcessInfo** info));
MOCK_METHOD2(SetBlob_ThreadUnsafe, void(
const char *name,
const Blob::Ptr &));
MOCK_METHOD3(SetBlob_ThreadUnsafe, void(
const char* name,
const Blob::Ptr&,
const PreProcessInfo&));
MOCK_METHOD1(SetCompletionCallback_ThreadUnsafe, void(IInferRequest::CompletionCallback));
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD1(SetBatch_ThreadUnsafe, void(int));
MOCK_METHOD0(QueryState, std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(void));
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
};

View File

@ -15,7 +15,6 @@
#include "unit_test_utils/mocks/cpp_interfaces/mock_task_executor.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp"
using namespace ::testing;
using namespace std;
@ -260,116 +259,3 @@ TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailed
testRequest->StartAsync();
EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
}
// Fixture for the busy-flag tests (removed by this commit along with the class
// under test). Creates a fresh mock per test and provides a helper that checks
// whether calling `func` throws an InferenceEngineException containing `refError`.
class AsyncInferRequestThreadSafeInternalTests : public ::testing::Test {
protected:
MockAsyncInferRequestThreadSafeInternal::Ptr testRequest;
ResponseDesc dsc;
// Returns true iff func() throws an InferenceEngineException whose message
// contains refError; any other outcome (no throw, other exception text) is false.
bool _doesThrowExceptionWithMessage(std::function<void()> func, string refError) {
std::string whatMessage;
try {
func();
} catch (const InferenceEngineException &iee) {
whatMessage = iee.what();
}
return whatMessage.find(refError) != std::string::npos;
}
virtual void SetUp() {
testRequest = make_shared<MockAsyncInferRequestThreadSafeInternal>();
}
};
// StartAsync
// A busy request must reject a second StartAsync with REQUEST_BUSY.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnStartAsync) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->StartAsync(); }, REQUEST_BUSY_str));
}
// If StartAsync_ThreadUnsafe throws, the busy flag must be rolled back so the
// next StartAsync succeeds.
TEST_F(AsyncInferRequestThreadSafeInternalTests, canResetBusyStatusIfStartAsyncTaskFails) {
EXPECT_CALL(*testRequest.get(), StartAsync_ThreadUnsafe()).Times(2)
.WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
.WillOnce(Return());
ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { testRequest->StartAsync(); }, "compare"));
ASSERT_NO_THROW(testRequest->StartAsync());
}
// A successful StartAsync leaves the request in the busy state.
TEST_F(AsyncInferRequestThreadSafeInternalTests, deviceBusyAfterStartAsync) {
EXPECT_CALL(*testRequest.get(), StartAsync_ThreadUnsafe()).WillOnce(Return());
ASSERT_NO_THROW(testRequest->StartAsync());
ASSERT_TRUE(testRequest->isRequestBusy());
}
// GetUserData
// GetUserData is busy-gated: CheckBusy fires before the null-pointer check.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetUserData) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->GetUserData(nullptr); }, REQUEST_BUSY_str));
}
// SetUserData
// SetUserData is busy-gated.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetUserData) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetUserData(nullptr); }, REQUEST_BUSY_str));
}
// Wait
// Wait is NOT busy-gated; the mocked Wait maps the call to INFER_NOT_STARTED.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnInferNotStartedOnWait) {
testRequest->setRequestBusy();
int64_t ms = 0;
EXPECT_CALL(*testRequest.get(), Wait(ms)).WillOnce(Return(INFER_NOT_STARTED));
StatusCode actual = testRequest->Wait(ms);
ASSERT_EQ(INFER_NOT_STARTED, actual);
}
// Infer
// A busy request must reject a synchronous Infer with REQUEST_BUSY.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnInfer) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, REQUEST_BUSY_str));
}
// If Infer_ThreadUnsafe throws, the busy flag must be rolled back so the next
// Infer succeeds.
TEST_F(AsyncInferRequestThreadSafeInternalTests, canResetBusyStatusIfInferFails) {
EXPECT_CALL(*testRequest.get(), Infer_ThreadUnsafe()).Times(2)
.WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
.WillOnce(Return());
ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { testRequest->Infer(); }, "compare"));
ASSERT_NO_THROW(testRequest->Infer());
}
// GetPerformanceCounts
// GetPerformanceCounts is busy-gated.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetPerformanceCounts) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
testRequest->GetPerformanceCounts(info);
}, REQUEST_BUSY_str));
}
// GetBlob
// GetBlob is busy-gated; the null name never reaches the implementation.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetBlob) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
Blob::Ptr data;
testRequest->GetBlob(nullptr, data);
}, REQUEST_BUSY_str));
}
// SetBlob
// SetBlob is busy-gated.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetBlob) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetBlob(nullptr, nullptr); }, REQUEST_BUSY_str));
}
// SetCompletionCallback
// SetCompletionCallback is busy-gated.
TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetCompletionCallback) {
testRequest->setRequestBusy();
ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetCompletionCallback(nullptr); },
REQUEST_BUSY_str));
}