[AUTO BATCH PLUGIN] Rename classes in the Auto Batch plugin & clean up namespaces (#18145)

* [AUTO BATCH PLUGIN] change namespace AutoBatchPlugin to ov::autobatch_plugin & remove the MockAutoBatchPlugin macro

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] rename class AutoBatchInferencePlugin to Plugin

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] rename class AutoBatchExecutableNetwork to CompiledModel & update the class member naming style

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] rename class AutoBatchInferRequest to SyncInferRequest & update the class member naming style

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] rename class AutoBatchAsyncInferRequest to AsyncInferRequest & update the class member naming style

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] fix code formatting issues

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] remove the using namespace InferenceEngine directive

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] remove explicit keyword & rename network to model

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] remove namespace MockAutoBatchPlugin

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

* [AUTO BATCH PLUGIN] fix static build issue

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>

---------

Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
Authored by Xuejun Zhai, 2023-06-21 10:11:14 +08:00; committed by GitHub
commit 08dc2cf1d6 (parent 05e8bd375e)
14 changed files with 510 additions and 480 deletions
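The diffs below are largely a mechanical rename. As an orientation, the resulting identifier layout is roughly the following (a sketch distilled from the changes in this commit, not a complete header):

// Sketch of the renamed API surface, inferred from this commit's diffs.
namespace ov {
namespace autobatch_plugin {         // was the flat namespace AutoBatchPlugin

struct DeviceInformation;            // members renamed: deviceName -> device_name, batchForDevice -> batch_for_device
class Plugin;                        // was AutoBatchInferencePlugin
class CompiledModel;                 // was AutoBatchExecutableNetwork
class SyncInferRequest;              // was AutoBatchInferRequest
class AsyncInferRequest;             // was AutoBatchAsyncInferRequest

}  // namespace autobatch_plugin
}  // namespace ov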


@@ -6,23 +6,21 @@
 #include "async_infer_request.hpp"
 
-namespace AutoBatchPlugin {
-
-using namespace InferenceEngine;
-
-AutoBatchAsyncInferRequest::AutoBatchAsyncInferRequest(
-    const AutoBatchInferRequest::Ptr& inferRequest,
-    InferenceEngine::SoIInferRequestInternal& inferRequestWithoutBatch,
-    const ITaskExecutor::Ptr& callbackExecutor)
+namespace ov {
+namespace autobatch_plugin {
+
+AsyncInferRequest::AsyncInferRequest(const SyncInferRequest::Ptr& inferRequest,
+                                     InferenceEngine::SoIInferRequestInternal& inferRequestWithoutBatch,
+                                     const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor)
     : AsyncInferRequestThreadSafeDefault(inferRequest, nullptr, callbackExecutor),
-      _inferRequestWithoutBatch(inferRequestWithoutBatch),
-      _inferRequest{inferRequest} {
+      m_infer_request_without_batch(inferRequestWithoutBatch),
+      m_sync_infer_request{inferRequest} {
     // this executor starts the inference while the task (checking the result) is passed to the next stage
-    struct ThisRequestExecutor : public ITaskExecutor {
-        explicit ThisRequestExecutor(AutoBatchAsyncInferRequest* _this_) : _this{_this_} {}
-        void run(Task task) override {
-            auto& workerInferRequest = _this->_inferRequest->_myBatchedRequestWrapper;
-            std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task> t;
+    struct ThisRequestExecutor : public InferenceEngine::ITaskExecutor {
+        explicit ThisRequestExecutor(AsyncInferRequest* _this_) : _this{_this_} {}
+        void run(InferenceEngine::Task task) override {
+            auto& workerInferRequest = _this->m_sync_infer_request->m_batched_request_wrapper;
+            std::pair<AsyncInferRequest*, InferenceEngine::Task> t;
             t.first = _this;
             t.second = std::move(task);
             workerInferRequest._tasks.push(t);
@@ -32,35 +30,36 @@ AutoBatchAsyncInferRequest::AutoBatchAsyncInferRequest(
                 workerInferRequest._cond.notify_one();
             }
         };
-        AutoBatchAsyncInferRequest* _this = nullptr;
+        AsyncInferRequest* _this = nullptr;
     };
-    _pipeline = {{/*TaskExecutor*/ std::make_shared<ThisRequestExecutor>(this), /*task*/ [this] {
-        if (this->_inferRequest->_exceptionPtr)  // if the exception happened in the batch1 fallback
-            std::rethrow_exception(this->_inferRequest->_exceptionPtr);
-        auto& batchReq = this->_inferRequest->_myBatchedRequestWrapper;
-        if (batchReq._exceptionPtr)  // when the batchN execution failed
-            std::rethrow_exception(batchReq._exceptionPtr);
-        // in the case of non-batched execution the blobs were set explicitly
-        if (AutoBatchInferRequest::eExecutionFlavor::BATCH_EXECUTED ==
-            this->_inferRequest->_wasBatchedRequestUsed)
-            this->_inferRequest->CopyOutputsIfNeeded();
-    }}};
+    _pipeline = {
+        {/*TaskExecutor*/ std::make_shared<ThisRequestExecutor>(this), /*task*/ [this] {
+             if (this->m_sync_infer_request->m_exceptionPtr)  // if the exception happened in the batch1 fallback
+                 std::rethrow_exception(this->m_sync_infer_request->m_exceptionPtr);
+             auto& batchReq = this->m_sync_infer_request->m_batched_request_wrapper;
+             if (batchReq.m_exceptionPtr)  // when the batchN execution failed
+                 std::rethrow_exception(batchReq.m_exceptionPtr);
+             // in the case of non-batched execution the blobs were set explicitly
+             if (SyncInferRequest::eExecutionFlavor::BATCH_EXECUTED ==
+                 this->m_sync_infer_request->m_batched_request_status)
+                 this->m_sync_infer_request->CopyOutputsIfNeeded();
+         }}};
 }
 
-std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> AutoBatchAsyncInferRequest::GetPerformanceCounts()
-    const {
+std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> AsyncInferRequest::GetPerformanceCounts() const {
     CheckState();
-    if (AutoBatchInferRequest::eExecutionFlavor::BATCH_EXECUTED == _inferRequest->_wasBatchedRequestUsed)
-        return _inferRequest->_myBatchedRequestWrapper._inferRequestBatched->GetPerformanceCounts();
+    if (SyncInferRequest::eExecutionFlavor::BATCH_EXECUTED == m_sync_infer_request->m_batched_request_status)
+        return m_sync_infer_request->m_batched_request_wrapper._inferRequestBatched->GetPerformanceCounts();
     else
-        return _inferRequestWithoutBatch->GetPerformanceCounts();
+        return m_infer_request_without_batch->GetPerformanceCounts();
 }
 
-void AutoBatchAsyncInferRequest::Infer_ThreadUnsafe() {
+void AsyncInferRequest::Infer_ThreadUnsafe() {
     InferUsingAsync();
 }
 
-AutoBatchAsyncInferRequest::~AutoBatchAsyncInferRequest() {
+AsyncInferRequest::~AsyncInferRequest() {
     StopAndWait();
 }
-}  // namespace AutoBatchPlugin
+
+}  // namespace autobatch_plugin
+}  // namespace ov


@@ -7,19 +7,25 @@
 #include "cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp"
 #include "sync_infer_request.hpp"
 
-namespace AutoBatchPlugin {
+namespace ov {
+namespace autobatch_plugin {
 
-class AutoBatchAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
+class AsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
 public:
-    using Ptr = std::shared_ptr<AutoBatchAsyncInferRequest>;
+    using Ptr = std::shared_ptr<AsyncInferRequest>;
 
-    explicit AutoBatchAsyncInferRequest(const AutoBatchInferRequest::Ptr& inferRequest,
-                                        InferenceEngine::SoIInferRequestInternal& inferRequestWithoutBatch,
-                                        const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
+    explicit AsyncInferRequest(const SyncInferRequest::Ptr& inferRequest,
+                               InferenceEngine::SoIInferRequestInternal& inferRequestWithoutBatch,
+                               const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
 
     void Infer_ThreadUnsafe() override;
 
-    virtual ~AutoBatchAsyncInferRequest();
+    virtual ~AsyncInferRequest();
 
     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
 
-    InferenceEngine::SoIInferRequestInternal _inferRequestWithoutBatch;
-    AutoBatchInferRequest::Ptr _inferRequest;
+    InferenceEngine::SoIInferRequestInternal m_infer_request_without_batch;
+    SyncInferRequest::Ptr m_sync_infer_request;
 };
-}  // namespace AutoBatchPlugin
+
+}  // namespace autobatch_plugin
+}  // namespace ov


@@ -9,90 +9,89 @@
#include "ie_performance_hints.hpp" #include "ie_performance_hints.hpp"
#include "sync_infer_request.hpp" #include "sync_infer_request.hpp"
namespace AutoBatchPlugin { namespace ov {
using namespace InferenceEngine; namespace autobatch_plugin {
AutoBatchExecutableNetwork::AutoBatchExecutableNetwork( CompiledModel::CompiledModel(const InferenceEngine::SoExecutableNetworkInternal& networkWithBatch,
const InferenceEngine::SoExecutableNetworkInternal& networkWithBatch, const InferenceEngine::SoExecutableNetworkInternal& networkWithoutBatch,
const InferenceEngine::SoExecutableNetworkInternal& networkWithoutBatch, const DeviceInformation& networkDevice,
const DeviceInformation& networkDevice, const std::unordered_map<std::string, InferenceEngine::Parameter>& config,
const std::unordered_map<std::string, InferenceEngine::Parameter>& config, const std::set<std::string>& batchedInputs,
const std::set<std::string>& batchedInputs, const std::set<std::string>& batchedOutputs)
const std::set<std::string>& batchedOutputs)
: InferenceEngine::ExecutableNetworkThreadSafeDefault(nullptr, : InferenceEngine::ExecutableNetworkThreadSafeDefault(nullptr,
std::make_shared<InferenceEngine::ImmediateExecutor>()), std::make_shared<InferenceEngine::ImmediateExecutor>()),
_network{networkWithBatch}, m_model_with_batch{networkWithBatch},
_networkWithoutBatch{networkWithoutBatch}, m_model_without_batch{networkWithoutBatch},
_config{config}, m_config{config},
_batchedInputs(batchedInputs), m_batched_inputs(batchedInputs),
_batchedOutputs(batchedOutputs) { m_batched_outputs(batchedOutputs) {
// WA for gcc 4.8 ( fails compilation with member init-list) // WA for gcc 4.8 ( fails compilation with member init-list)
_device = networkDevice; m_device_info = networkDevice;
auto time_out = config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT)); auto time_out = config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT));
IE_ASSERT(time_out != config.end()); IE_ASSERT(time_out != config.end());
_timeOut = ParseTimeoutValue(time_out->second.as<std::string>()); m_timeout = ParseTimeoutValue(time_out->second.as<std::string>());
} }
AutoBatchExecutableNetwork::~AutoBatchExecutableNetwork() { CompiledModel::~CompiledModel() {
_terminate = true; m_terminate = true;
for (const auto& w : _workerRequests) { for (const auto& w : m_worker_requests) {
w->_thread.join(); w->_thread.join();
} }
_workerRequests.clear(); m_worker_requests.clear();
} }
unsigned int AutoBatchExecutableNetwork::ParseTimeoutValue(const std::string& s) { unsigned int CompiledModel::ParseTimeoutValue(const std::string& s) {
auto val = std::stoi(s); auto val = std::stoi(s);
if (val < 0) if (val < 0)
IE_THROW(ParameterMismatch) << "Value for the " << CONFIG_KEY(AUTO_BATCH_TIMEOUT) << " should be unsigned int"; IE_THROW(ParameterMismatch) << "Value for the " << CONFIG_KEY(AUTO_BATCH_TIMEOUT) << " should be unsigned int";
return val; return val;
} }
std::shared_ptr<InferenceEngine::RemoteContext> AutoBatchExecutableNetwork::GetContext() const { std::shared_ptr<InferenceEngine::RemoteContext> CompiledModel::GetContext() const {
return _networkWithoutBatch->GetContext(); return m_model_without_batch->GetContext();
} }
InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateInferRequestImpl( InferenceEngine::IInferRequestInternal::Ptr CompiledModel::CreateInferRequestImpl(
InferenceEngine::InputsDataMap networkInputs, InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) { InferenceEngine::OutputsDataMap networkOutputs) {
auto workerRequestPtrAndId = GetWorkerInferRequest(); auto workerRequestPtrAndId = GetWorkerInferRequest();
return std::make_shared<AutoBatchInferRequest>(networkInputs, return std::make_shared<SyncInferRequest>(networkInputs,
networkOutputs, networkOutputs,
workerRequestPtrAndId.first, workerRequestPtrAndId.first,
workerRequestPtrAndId.second, workerRequestPtrAndId.second,
_device.batchForDevice, m_device_info.batch_for_device,
_batchedInputs, m_batched_inputs,
_batchedOutputs); m_batched_outputs);
} }
InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateInferRequestImpl( InferenceEngine::IInferRequestInternal::Ptr CompiledModel::CreateInferRequestImpl(
const std::vector<std::shared_ptr<const ov::Node>>& inputs, const std::vector<std::shared_ptr<const ov::Node>>& inputs,
const std::vector<std::shared_ptr<const ov::Node>>& outputs) { const std::vector<std::shared_ptr<const ov::Node>>& outputs) {
if (!this->_plugin || !_plugin->IsNewAPI()) if (!this->_plugin || !_plugin->IsNewAPI())
return nullptr; return nullptr;
auto workerRequestPtrAndId = GetWorkerInferRequest(); auto workerRequestPtrAndId = GetWorkerInferRequest();
return std::make_shared<AutoBatchInferRequest>(inputs, return std::make_shared<SyncInferRequest>(inputs,
outputs, outputs,
workerRequestPtrAndId.first, workerRequestPtrAndId.first,
workerRequestPtrAndId.second, workerRequestPtrAndId.second,
_device.batchForDevice, m_device_info.batch_for_device,
_batchedInputs, m_batched_inputs,
_batchedOutputs); m_batched_outputs);
} }
std::pair<AutoBatchExecutableNetwork::WorkerInferRequest&, int> AutoBatchExecutableNetwork::GetWorkerInferRequest() { std::pair<CompiledModel::WorkerInferRequest&, int> CompiledModel::GetWorkerInferRequest() {
auto num = _numRequestsCreated++; auto num = m_num_requests_created++;
std::lock_guard<std::mutex> lock(_workerRequestsMutex); std::lock_guard<std::mutex> lock(m_worker_requests_mutex);
auto batch_id = num % _device.batchForDevice; auto batch_id = num % m_device_info.batch_for_device;
if (!batch_id) { // need new request if (!batch_id) { // need new request
_workerRequests.push_back(std::make_shared<WorkerInferRequest>()); m_worker_requests.push_back(std::make_shared<WorkerInferRequest>());
auto workerRequestPtr = _workerRequests.back().get(); auto workerRequestPtr = m_worker_requests.back().get();
workerRequestPtr->_inferRequestBatched = {_network->CreateInferRequest(), _network._so}; workerRequestPtr->_inferRequestBatched = {m_model_with_batch->CreateInferRequest(), m_model_with_batch._so};
workerRequestPtr->_batchSize = _device.batchForDevice; workerRequestPtr->_batchSize = m_device_info.batch_for_device;
workerRequestPtr->_completionTasks.resize(workerRequestPtr->_batchSize); workerRequestPtr->_completionTasks.resize(workerRequestPtr->_batchSize);
workerRequestPtr->_inferRequestBatched->SetCallback( workerRequestPtr->_inferRequestBatched->SetCallback(
[workerRequestPtr](std::exception_ptr exceptionPtr) mutable { [workerRequestPtr](std::exception_ptr exceptionPtr) mutable {
if (exceptionPtr) if (exceptionPtr)
workerRequestPtr->_exceptionPtr = exceptionPtr; workerRequestPtr->m_exceptionPtr = exceptionPtr;
IE_ASSERT(workerRequestPtr->_completionTasks.size() == (size_t)workerRequestPtr->_batchSize); IE_ASSERT(workerRequestPtr->_completionTasks.size() == (size_t)workerRequestPtr->_batchSize);
// notify the individual requests on the completion // notify the individual requests on the completion
for (int c = 0; c < workerRequestPtr->_batchSize; c++) { for (int c = 0; c < workerRequestPtr->_batchSize; c++) {
@@ -107,45 +106,46 @@ std::pair<AutoBatchExecutableNetwork::WorkerInferRequest&, int> AutoBatchExecuta
std::cv_status status; std::cv_status status;
{ {
std::unique_lock<std::mutex> lock(workerRequestPtr->_mutex); std::unique_lock<std::mutex> lock(workerRequestPtr->_mutex);
status = workerRequestPtr->_cond.wait_for(lock, std::chrono::milliseconds(_timeOut)); status = workerRequestPtr->_cond.wait_for(lock, std::chrono::milliseconds(m_timeout));
} }
if (_terminate) { if (m_terminate) {
break; break;
} else { } else {
// as we pop the tasks from the queue only here // as we pop the tasks from the queue only here
// it is ok to call size() (as the _tasks can only grow in parallel) // it is ok to call size() (as the _tasks can only grow in parallel)
const int sz = static_cast<int>(workerRequestPtr->_tasks.size()); const int sz = static_cast<int>(workerRequestPtr->_tasks.size());
if (sz == workerRequestPtr->_batchSize) { if (sz == workerRequestPtr->_batchSize) {
std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task> t; std::pair<AsyncInferRequest*, InferenceEngine::Task> t;
for (int n = 0; n < sz; n++) { for (int n = 0; n < sz; n++) {
IE_ASSERT(workerRequestPtr->_tasks.try_pop(t)); IE_ASSERT(workerRequestPtr->_tasks.try_pop(t));
workerRequestPtr->_completionTasks[n] = std::move(t.second); workerRequestPtr->_completionTasks[n] = std::move(t.second);
t.first->_inferRequest->CopyInputsIfNeeded(); t.first->m_sync_infer_request->CopyInputsIfNeeded();
t.first->_inferRequest->_wasBatchedRequestUsed = t.first->m_sync_infer_request->m_batched_request_status =
AutoBatchInferRequest::eExecutionFlavor::BATCH_EXECUTED; SyncInferRequest::eExecutionFlavor::BATCH_EXECUTED;
} }
workerRequestPtr->_inferRequestBatched->StartAsync(); workerRequestPtr->_inferRequestBatched->StartAsync();
} else if ((status == std::cv_status::timeout) && sz) { } else if ((status == std::cv_status::timeout) && sz) {
// timeout to collect the batch is over, have to execute the requests in the batch1 mode // timeout to collect the batch is over, have to execute the requests in the batch1 mode
std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task> t; std::pair<AsyncInferRequest*, InferenceEngine::Task> t;
// popping all tasks collected by the moment of the time-out and execute each with batch1 // popping all tasks collected by the moment of the time-out and execute each with batch1
std::atomic<int> arrived = {0}; std::atomic<int> arrived = {0};
std::promise<void> all_completed; std::promise<void> all_completed;
auto all_completed_future = all_completed.get_future(); auto all_completed_future = all_completed.get_future();
for (int n = 0; n < sz; n++) { for (int n = 0; n < sz; n++) {
IE_ASSERT(workerRequestPtr->_tasks.try_pop(t)); IE_ASSERT(workerRequestPtr->_tasks.try_pop(t));
t.first->_inferRequestWithoutBatch->SetCallback( t.first->m_infer_request_without_batch->SetCallback(
[t, sz, &arrived, &all_completed](std::exception_ptr p) { [t, sz, &arrived, &all_completed](std::exception_ptr p) {
if (p) if (p)
t.first->_inferRequest->_exceptionPtr = p; t.first->m_sync_infer_request->m_exceptionPtr = p;
t.second(); t.second();
if (sz == ++arrived) if (sz == ++arrived)
all_completed.set_value(); all_completed.set_value();
}); });
t.first->_inferRequest->_wasBatchedRequestUsed = t.first->m_sync_infer_request->m_batched_request_status =
AutoBatchInferRequest::eExecutionFlavor::TIMEOUT_EXECUTED; SyncInferRequest::eExecutionFlavor::TIMEOUT_EXECUTED;
t.first->_inferRequest->SetBlobsToAnotherRequest(t.first->_inferRequestWithoutBatch); t.first->m_sync_infer_request->SetBlobsToAnotherRequest(
t.first->_inferRequestWithoutBatch->StartAsync(); t.first->m_infer_request_without_batch);
t.first->m_infer_request_without_batch->StartAsync();
} }
all_completed_future.get(); all_completed_future.get();
// now when all the tasks for this batch are completed, start waiting for the timeout again // now when all the tasks for this batch are completed, start waiting for the timeout again
@@ -154,77 +154,78 @@ std::pair<AutoBatchExecutableNetwork::WorkerInferRequest&, int> AutoBatchExecuta
} }
}); });
} }
return {*_workerRequests.back(), static_cast<int>(batch_id)}; return {*m_worker_requests.back(), static_cast<int>(batch_id)};
} }
InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateInferRequest() { InferenceEngine::IInferRequestInternal::Ptr CompiledModel::CreateInferRequest() {
if (!_network) { if (!m_model_with_batch) {
auto res = _networkWithoutBatch->CreateInferRequest(); auto res = m_model_without_batch->CreateInferRequest();
res->setPointerToExecutableNetworkInternal(shared_from_this()); res->setPointerToExecutableNetworkInternal(shared_from_this());
res->setPointerToSo(_networkWithoutBatch._so); res->setPointerToSo(m_model_without_batch._so);
_so = _networkWithoutBatch._so; _so = m_model_without_batch._so;
return res; return res;
} }
// trying to create the new API request first // trying to create the new API request first
IInferRequestInternal::Ptr syncRequestImpl = CreateInferRequestImpl(_parameters, _results); InferenceEngine::IInferRequestInternal::Ptr syncRequestImpl = CreateInferRequestImpl(_parameters, _results);
if (!syncRequestImpl) if (!syncRequestImpl)
syncRequestImpl = CreateInferRequestImpl(_networkInputs, _networkOutputs); syncRequestImpl = CreateInferRequestImpl(_networkInputs, _networkOutputs);
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this()); syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
InferenceEngine::SoIInferRequestInternal inferRequestWithoutBatch = {_networkWithoutBatch->CreateInferRequest(), InferenceEngine::SoIInferRequestInternal inferRequestWithoutBatch = {m_model_without_batch->CreateInferRequest(),
_networkWithoutBatch._so}; m_model_without_batch._so};
return std::make_shared<AutoBatchAsyncInferRequest>( return std::make_shared<AsyncInferRequest>(std::static_pointer_cast<SyncInferRequest>(syncRequestImpl),
std::static_pointer_cast<AutoBatchInferRequest>(syncRequestImpl), inferRequestWithoutBatch,
inferRequestWithoutBatch, _callbackExecutor);
_callbackExecutor);
} }
std::shared_ptr<ngraph::Function> AutoBatchExecutableNetwork::GetExecGraphInfo() { std::shared_ptr<ngraph::Function> CompiledModel::GetExecGraphInfo() {
return _network && _network->GetExecGraphInfo() ? _network->GetExecGraphInfo() return m_model_with_batch && m_model_with_batch->GetExecGraphInfo() ? m_model_with_batch->GetExecGraphInfo()
: _networkWithoutBatch->GetExecGraphInfo(); : m_model_without_batch->GetExecGraphInfo();
} }
void AutoBatchExecutableNetwork::SetConfig(const std::map<std::string, InferenceEngine::Parameter>& user_config) { void CompiledModel::SetConfig(const std::map<std::string, InferenceEngine::Parameter>& user_config) {
auto timeout = user_config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT)); auto timeout = user_config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT));
if (timeout == user_config.end() || user_config.size() > 1) { if (timeout == user_config.end() || user_config.size() > 1) {
IE_THROW() << "The only config that can be changed on the fly for the AutoBatching the is the " IE_THROW() << "The only config that can be changed on the fly for the AutoBatching the is the "
<< CONFIG_KEY(AUTO_BATCH_TIMEOUT); << CONFIG_KEY(AUTO_BATCH_TIMEOUT);
} else { } else {
_timeOut = ParseTimeoutValue(timeout->second.as<std::string>()); m_timeout = ParseTimeoutValue(timeout->second.as<std::string>());
} }
} }
InferenceEngine::Parameter AutoBatchExecutableNetwork::GetConfig(const std::string& name) const { InferenceEngine::Parameter CompiledModel::GetConfig(const std::string& name) const {
auto it = _config.find(name); auto it = m_config.find(name);
if (it != _config.end()) { if (it != m_config.end()) {
return it->second; return it->second;
} else { } else {
// find config key among networks config keys // find config key among networks config keys
auto param = _networkWithoutBatch->GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS)); auto param = m_model_without_batch->GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS));
for (auto&& configKey : param.as<std::vector<std::string>>()) { for (auto&& configKey : param.as<std::vector<std::string>>()) {
if (configKey == name) { if (configKey == name) {
return _networkWithoutBatch->GetConfig(configKey); return m_model_without_batch->GetConfig(configKey);
} }
} }
IE_THROW(NotFound) << name << " not found in the ExecutableNetwork config"; IE_THROW(NotFound) << name << " not found in the ExecutableNetwork config";
} }
} }
InferenceEngine::Parameter AutoBatchExecutableNetwork::GetMetric(const std::string& name) const { InferenceEngine::Parameter CompiledModel::GetMetric(const std::string& name) const {
if (name == METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)) { if (name == METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)) {
auto reqs = 0; auto reqs = 0;
try { try {
auto hint = _networkWithoutBatch->GetConfig(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>(); auto hint = m_model_without_batch->GetConfig(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>();
reqs = InferenceEngine::PerfHintsConfig::CheckPerformanceHintRequestValue(hint); reqs = InferenceEngine::PerfHintsConfig::CheckPerformanceHintRequestValue(hint);
if (!reqs) // no limitations from user, let's deduce the full blown #requests if (!reqs) // no limitations from user, let's deduce the full blown #requests
// (multiplied by the devices capabilities to run multiple <batched> requests for further perf) // (multiplied by the devices capabilities to run multiple <batched> requests for further perf)
reqs = _device.batchForDevice * reqs =
_networkWithoutBatch->GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)).as<unsigned int>(); m_device_info.batch_for_device *
m_model_without_batch->GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)).as<unsigned int>();
} catch (const InferenceEngine::Exception&) { } catch (const InferenceEngine::Exception&) {
} }
reqs = std::max(reqs, _device.batchForDevice); // round up to the possible user's value reqs = std::max(reqs, m_device_info.batch_for_device); // round up to the possible user's value
IE_SET_METRIC_RETURN(OPTIMAL_NUMBER_OF_INFER_REQUESTS, reqs); IE_SET_METRIC_RETURN(OPTIMAL_NUMBER_OF_INFER_REQUESTS, reqs);
} else if (name == METRIC_KEY(NETWORK_NAME)) { } else if (name == METRIC_KEY(NETWORK_NAME)) {
IE_SET_METRIC_RETURN(NETWORK_NAME, _networkWithoutBatch->GetMetric(METRIC_KEY(NETWORK_NAME)).as<std::string>()); IE_SET_METRIC_RETURN(NETWORK_NAME,
m_model_without_batch->GetMetric(METRIC_KEY(NETWORK_NAME)).as<std::string>());
} else if (name == METRIC_KEY(SUPPORTED_METRICS)) { } else if (name == METRIC_KEY(SUPPORTED_METRICS)) {
IE_SET_METRIC_RETURN(SUPPORTED_METRICS, IE_SET_METRIC_RETURN(SUPPORTED_METRICS,
{METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS), {METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS),
@@ -236,9 +237,10 @@ InferenceEngine::Parameter AutoBatchExecutableNetwork::GetMetric(const std::stri
IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS,
{CONFIG_KEY(AUTO_BATCH_TIMEOUT)}); // only timeout can be changed on the fly {CONFIG_KEY(AUTO_BATCH_TIMEOUT)}); // only timeout can be changed on the fly
} else if (name == ov::execution_devices) { } else if (name == ov::execution_devices) {
return _networkWithoutBatch->GetMetric(name); return m_model_without_batch->GetMetric(name);
} else { } else {
IE_THROW() << "Unsupported Network metric: " << name; IE_THROW() << "Unsupported Network metric: " << name;
} }
} }
} // namespace AutoBatchPlugin } // namespace autobatch_plugin
} // namespace ov
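As the SetConfig/GetConfig code above shows, AUTO_BATCH_TIMEOUT remains the only option the compiled model accepts after loading. A minimal, hedged usage sketch through the public OpenVINO 2.0 API follows (the model path and target device are placeholders; this flow is not part of the diff itself):

// Sketch: adjusting the Auto Batch timeout on an already compiled model.
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");                 // placeholder model path
    // "BATCH:GPU" routes inference through this plugin on top of the GPU device.
    auto compiled = core.compile_model(model, "BATCH:GPU");
    // Only the collection timeout can be changed after compilation (see SetConfig above).
    compiled.set_property({ov::auto_batch_timeout(100)});
    return 0;
}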


@@ -12,64 +12,72 @@
#include "plugin.hpp" #include "plugin.hpp"
#include "threading/ie_thread_safe_containers.hpp" #include "threading/ie_thread_safe_containers.hpp"
namespace AutoBatchPlugin { namespace ov {
namespace autobatch_plugin {
class AutoBatchAsyncInferRequest; class AsyncInferRequest;
class AutoBatchExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafeDefault { class CompiledModel : public InferenceEngine::ExecutableNetworkThreadSafeDefault {
public: public:
using Ptr = std::shared_ptr<AutoBatchExecutableNetwork>; using Ptr = std::shared_ptr<CompiledModel>;
struct WorkerInferRequest { struct WorkerInferRequest {
using Ptr = std::shared_ptr<WorkerInferRequest>; using Ptr = std::shared_ptr<WorkerInferRequest>;
InferenceEngine::SoIInferRequestInternal _inferRequestBatched; InferenceEngine::SoIInferRequestInternal _inferRequestBatched;
int _batchSize; int _batchSize;
InferenceEngine::ThreadSafeQueueWithSize<std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task>> _tasks; InferenceEngine::ThreadSafeQueueWithSize<std::pair<AsyncInferRequest*, InferenceEngine::Task>> _tasks;
std::vector<InferenceEngine::Task> _completionTasks; std::vector<InferenceEngine::Task> _completionTasks;
std::thread _thread; std::thread _thread;
std::condition_variable _cond; std::condition_variable _cond;
std::mutex _mutex; std::mutex _mutex;
std::exception_ptr _exceptionPtr; std::exception_ptr m_exceptionPtr;
}; };
explicit AutoBatchExecutableNetwork( CompiledModel(const InferenceEngine::SoExecutableNetworkInternal& networkForDevice,
const InferenceEngine::SoExecutableNetworkInternal& networkForDevice, const InferenceEngine::SoExecutableNetworkInternal& networkForDeviceWithoutBatch,
const InferenceEngine::SoExecutableNetworkInternal& networkForDeviceWithoutBatch, const DeviceInformation& networkDevices,
const DeviceInformation& networkDevices, const std::unordered_map<std::string, InferenceEngine::Parameter>& config,
const std::unordered_map<std::string, InferenceEngine::Parameter>& config, const std::set<std::string>& batchedIntputs,
const std::set<std::string>& batchedIntputs, const std::set<std::string>& batchedOutputs);
const std::set<std::string>& batchedOutputs);
void SetConfig(const std::map<std::string, InferenceEngine::Parameter>& config) override; void SetConfig(const std::map<std::string, InferenceEngine::Parameter>& config) override;
InferenceEngine::Parameter GetConfig(const std::string& name) const override; InferenceEngine::Parameter GetConfig(const std::string& name) const override;
InferenceEngine::Parameter GetMetric(const std::string& name) const override; InferenceEngine::Parameter GetMetric(const std::string& name) const override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override; InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl( InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(
InferenceEngine::InputsDataMap networkInputs, InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override; InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl( InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(
const std::vector<std::shared_ptr<const ov::Node>>& inputs, const std::vector<std::shared_ptr<const ov::Node>>& inputs,
const std::vector<std::shared_ptr<const ov::Node>>& outputs) override; const std::vector<std::shared_ptr<const ov::Node>>& outputs) override;
std::shared_ptr<InferenceEngine::RemoteContext> GetContext() const override; std::shared_ptr<InferenceEngine::RemoteContext> GetContext() const override;
std::shared_ptr<ngraph::Function> GetExecGraphInfo() override; std::shared_ptr<ngraph::Function> GetExecGraphInfo() override;
virtual ~AutoBatchExecutableNetwork();
virtual ~CompiledModel();
protected: protected:
static unsigned int ParseTimeoutValue(const std::string&); static unsigned int ParseTimeoutValue(const std::string&);
std::atomic_bool _terminate = {false}; std::atomic_bool m_terminate = {false};
DeviceInformation _device; DeviceInformation m_device_info;
InferenceEngine::SoExecutableNetworkInternal _network; InferenceEngine::SoExecutableNetworkInternal m_model_with_batch;
InferenceEngine::SoExecutableNetworkInternal _networkWithoutBatch; InferenceEngine::SoExecutableNetworkInternal m_model_without_batch;
std::pair<WorkerInferRequest&, int> GetWorkerInferRequest(); std::pair<WorkerInferRequest&, int> GetWorkerInferRequest();
std::vector<WorkerInferRequest::Ptr> _workerRequests; std::vector<WorkerInferRequest::Ptr> m_worker_requests;
std::mutex _workerRequestsMutex; std::mutex m_worker_requests_mutex;
std::unordered_map<std::string, InferenceEngine::Parameter> _config; std::unordered_map<std::string, InferenceEngine::Parameter> m_config;
bool _needPerfCounters = false; std::atomic_size_t m_num_requests_created = {0};
std::atomic_size_t _numRequestsCreated = {0}; std::atomic_int m_timeout = {0}; // in ms
std::atomic_int _timeOut = {0}; // in ms
const std::set<std::string> _batchedInputs; const std::set<std::string> m_batched_inputs;
const std::set<std::string> _batchedOutputs; const std::set<std::string> m_batched_outputs;
}; };
} // namespace AutoBatchPlugin } // namespace autobatch_plugin
} // namespace ov


@@ -19,8 +19,8 @@
#include "transformations/init_node_info.hpp" #include "transformations/init_node_info.hpp"
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
namespace AutoBatchPlugin { namespace ov {
using namespace InferenceEngine; namespace autobatch_plugin {
std::vector<std::string> supported_configKeys = {CONFIG_KEY(AUTO_BATCH_DEVICE_CONFIG), std::vector<std::string> supported_configKeys = {CONFIG_KEY(AUTO_BATCH_DEVICE_CONFIG),
ov::device::priorities.name(), ov::device::priorities.name(),
@@ -38,7 +38,7 @@ std::map<std::string, std::string> mergeConfigs(std::map<std::string, std::strin
} // namespace } // namespace
DeviceInformation AutoBatchInferencePlugin::ParseBatchDevice(const std::string& deviceWithBatch) { DeviceInformation Plugin::ParseBatchDevice(const std::string& deviceWithBatch) {
auto&& d = deviceWithBatch; auto&& d = deviceWithBatch;
auto openingBracket = d.find_first_of('('); auto openingBracket = d.find_first_of('(');
auto closingBracket = d.find_first_of(')', openingBracket); auto closingBracket = d.find_first_of(')', openingBracket);
@@ -55,11 +55,10 @@ DeviceInformation AutoBatchInferencePlugin::ParseBatchDevice(const std::string&
return {deviceName, {{}}, batch}; return {deviceName, {{}}, batch};
} }
DeviceInformation AutoBatchInferencePlugin::ParseMetaDevice( DeviceInformation Plugin::ParseMetaDevice(const std::string& devicesBatchCfg,
const std::string& devicesBatchCfg, const std::map<std::string, std::string>& user_config) const {
const std::map<std::string, std::string>& user_config) const {
auto metaDevice = ParseBatchDevice(devicesBatchCfg); auto metaDevice = ParseBatchDevice(devicesBatchCfg);
metaDevice.config = GetCore()->GetSupportedConfig(metaDevice.deviceName, user_config); metaDevice.config = GetCore()->GetSupportedConfig(metaDevice.device_name, user_config);
// check that no irrelevant config-keys left // check that no irrelevant config-keys left
for (const auto& k : user_config) { for (const auto& k : user_config) {
@@ -72,7 +71,7 @@ DeviceInformation AutoBatchInferencePlugin::ParseMetaDevice(
return metaDevice; return metaDevice;
} }
RemoteContext::Ptr AutoBatchInferencePlugin::CreateContext(const InferenceEngine::ParamMap& remote_properties) { InferenceEngine::RemoteContext::Ptr Plugin::CreateContext(const InferenceEngine::ParamMap& remote_properties) {
auto cfg = remote_properties; auto cfg = remote_properties;
auto it = cfg.find(CONFIG_KEY(AUTO_BATCH_DEVICE_CONFIG)); auto it = cfg.find(CONFIG_KEY(AUTO_BATCH_DEVICE_CONFIG));
if (it == cfg.end()) if (it == cfg.end())
@@ -86,11 +85,12 @@ RemoteContext::Ptr AutoBatchInferencePlugin::CreateContext(const InferenceEngine
return nullptr; return nullptr;
auto metaDevice = ParseMetaDevice(val, std::map<std::string, std::string>()); auto metaDevice = ParseMetaDevice(val, std::map<std::string, std::string>());
cfg.erase(it); cfg.erase(it);
return core->CreateContext(metaDevice.deviceName, cfg); return core->CreateContext(metaDevice.device_name, cfg);
} }
Parameter AutoBatchInferencePlugin::GetConfig(const std::string& name, InferenceEngine::Parameter Plugin::GetConfig(
const std::map<std::string, Parameter>& user_options) const { const std::string& name,
const std::map<std::string, InferenceEngine::Parameter>& user_options) const {
if (supported_configKeys.end() != std::find(supported_configKeys.begin(), supported_configKeys.end(), name)) { if (supported_configKeys.end() != std::find(supported_configKeys.begin(), supported_configKeys.end(), name)) {
auto it = _config.find(name); auto it = _config.find(name);
if (it == _config.end()) { if (it == _config.end()) {
@@ -103,7 +103,7 @@ Parameter AutoBatchInferencePlugin::GetConfig(const std::string& name,
} }
} }
void AutoBatchInferencePlugin::CheckConfig(const std::map<std::string, std::string>& user_config) { void Plugin::CheckConfig(const std::map<std::string, std::string>& user_config) {
for (auto&& kvp : user_config) { for (auto&& kvp : user_config) {
const auto name = kvp.first; const auto name = kvp.first;
const auto val = kvp.second; const auto val = kvp.second;
@@ -124,22 +124,22 @@ void AutoBatchInferencePlugin::CheckConfig(const std::map<std::string, std::stri
} }
} }
void AutoBatchInferencePlugin::SetConfig(const std::map<std::string, std::string>& user_config) { void Plugin::SetConfig(const std::map<std::string, std::string>& user_config) {
CheckConfig(user_config); CheckConfig(user_config);
for (auto&& kvp : user_config) { for (auto&& kvp : user_config) {
_config[kvp.first] = kvp.second; _config[kvp.first] = kvp.second;
} }
} }
static const Version version = {{2, 1}, CI_BUILD_NUMBER, "AutoBatchPlugin"}; static const InferenceEngine::Version version = {{2, 1}, CI_BUILD_NUMBER, "AutoBatchPlugin"};
IE_DEFINE_PLUGIN_CREATE_FUNCTION(AutoBatchInferencePlugin, version) IE_DEFINE_PLUGIN_CREATE_FUNCTION(Plugin, version)
AutoBatchInferencePlugin::AutoBatchInferencePlugin() { Plugin::Plugin() {
_pluginName = "BATCH"; _pluginName = "BATCH";
_config[CONFIG_KEY(AUTO_BATCH_TIMEOUT)] = "1000"; // default value, in ms _config[CONFIG_KEY(AUTO_BATCH_TIMEOUT)] = "1000"; // default value, in ms
} }
InferenceEngine::Parameter AutoBatchInferencePlugin::GetMetric( InferenceEngine::Parameter Plugin::GetMetric(
const std::string& name, const std::string& name,
const std::map<std::string, InferenceEngine::Parameter>& user_options) const { const std::map<std::string, InferenceEngine::Parameter>& user_options) const {
if (name == METRIC_KEY(SUPPORTED_METRICS)) { if (name == METRIC_KEY(SUPPORTED_METRICS)) {
@@ -157,13 +157,13 @@ InferenceEngine::Parameter AutoBatchInferencePlugin::GetMetric(
} }
} }
IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadExeNetworkImpl( InferenceEngine::IExecutableNetworkInternal::Ptr Plugin::LoadExeNetworkImpl(
const InferenceEngine::CNNNetwork& network, const InferenceEngine::CNNNetwork& network,
const std::map<std::string, std::string>& user_config) { const std::map<std::string, std::string>& user_config) {
return LoadNetworkImpl(network, nullptr, user_config); return LoadNetworkImpl(network, nullptr, user_config);
} }
InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadNetworkImpl( InferenceEngine::IExecutableNetworkInternal::Ptr Plugin::LoadNetworkImpl(
const InferenceEngine::CNNNetwork& network, const InferenceEngine::CNNNetwork& network,
const std::shared_ptr<InferenceEngine::RemoteContext> ctx, const std::shared_ptr<InferenceEngine::RemoteContext> ctx,
const std::map<std::string, std::string>& user_config) { const std::map<std::string, std::string>& user_config) {
@@ -179,7 +179,7 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
IE_THROW() << "KEY_AUTO_BATCH key is not set for BATCH device"; IE_THROW() << "KEY_AUTO_BATCH key is not set for BATCH device";
} }
auto metaDevice = ParseMetaDevice(device_batch->second, user_config); auto metaDevice = ParseMetaDevice(device_batch->second, user_config);
const auto& deviceName = metaDevice.deviceName; const auto& deviceName = metaDevice.device_name;
const auto& deviceConfig = metaDevice.config; const auto& deviceConfig = metaDevice.config;
auto deviceConfigNoAutoBatch = deviceConfig; auto deviceConfigNoAutoBatch = deviceConfig;
// avoid recursive auto-batching // avoid recursive auto-batching
@@ -196,7 +196,7 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
const bool bTputInLoadCfg = (mode != deviceConfig.end() && mode->second == tput); const bool bTputInLoadCfg = (mode != deviceConfig.end() && mode->second == tput);
// if the auto-batching is enabled implicitly, check the dims carefully, to avoid outstanding failures // if the auto-batching is enabled implicitly, check the dims carefully, to avoid outstanding failures
const bool check_dims = (bTputInPlg || bTputInLoadCfg); const bool check_dims = (bTputInPlg || bTputInLoadCfg);
CNNNetwork clonedNetwork(InferenceEngine::details::cloneNetwork(network)); InferenceEngine::CNNNetwork clonedNetwork(InferenceEngine::details::cloneNetwork(network));
auto function = clonedNetwork.getFunction(); auto function = clonedNetwork.getFunction();
// find the batch dim // find the batch dim
ov::pass::Manager m; ov::pass::Manager m;
@@ -252,10 +252,10 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
IE_THROW(NotImplemented) IE_THROW(NotImplemented)
<< "Auto-batching supports only networks with inputs/outputs featuring batched dim!"; << "Auto-batching supports only networks with inputs/outputs featuring batched dim!";
} catch (const InferenceEngine::Exception&) { } catch (const InferenceEngine::Exception&) {
metaDevice.batchForDevice = 1; metaDevice.batch_for_device = 1;
} }
if (!metaDevice.batchForDevice) { if (!metaDevice.batch_for_device) {
unsigned int requests = 0; unsigned int requests = 0;
// batch size is not set explicitly via device name e.g. BATCH:GPU(4) // batch size is not set explicitly via device name e.g. BATCH:GPU(4)
// let's query the optimal batch size // let's query the optimal batch size
@@ -263,19 +263,20 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
options["MODEL_PTR"] = std::const_pointer_cast<ngraph::Function>(network.getFunction()); options["MODEL_PTR"] = std::const_pointer_cast<ngraph::Function>(network.getFunction());
auto optBatchSize = core->GetMetric(deviceName, METRIC_KEY(OPTIMAL_BATCH_SIZE), options).as<unsigned int>(); auto optBatchSize = core->GetMetric(deviceName, METRIC_KEY(OPTIMAL_BATCH_SIZE), options).as<unsigned int>();
auto res = core->GetConfig(deviceName, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>(); auto res = core->GetConfig(deviceName, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>();
requests = PerfHintsConfig::CheckPerformanceHintRequestValue(res); requests = InferenceEngine::PerfHintsConfig::CheckPerformanceHintRequestValue(res);
const auto& reqs = user_config.find(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)); const auto& reqs = user_config.find(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS));
if (reqs != user_config.end()) if (reqs != user_config.end())
requests = static_cast<unsigned int>(PerfHintsConfig::CheckPerformanceHintRequestValue(reqs->second)); requests = static_cast<unsigned int>(
InferenceEngine::PerfHintsConfig::CheckPerformanceHintRequestValue(reqs->second));
if (requests) if (requests)
optBatchSize = std::max(1u, std::min(requests, optBatchSize)); optBatchSize = std::max(1u, std::min(requests, optBatchSize));
if (optBatchSize > 2) // batching is usually in-efficient for batch<4 (as batch1 kernels are heavily optimized) if (optBatchSize > 2) // batching is usually in-efficient for batch<4 (as batch1 kernels are heavily optimized)
metaDevice.batchForDevice = optBatchSize; metaDevice.batch_for_device = optBatchSize;
else else
metaDevice.batchForDevice = 1; metaDevice.batch_for_device = 1;
} }
auto report_footprint = [](std::shared_ptr<ICore> pCore, std::string device) -> size_t { auto report_footprint = [](std::shared_ptr<InferenceEngine::ICore> pCore, std::string device) -> size_t {
size_t footprint = 0; size_t footprint = 0;
// TODO: use the per-network metric (22.2) rather than plugin-level // TODO: use the per-network metric (22.2) rather than plugin-level
auto stats = auto stats =
@@ -296,9 +297,9 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
const auto total_mem = const auto total_mem =
GetCore()->GetMetric(deviceName, GPU_METRIC_KEY(DEVICE_TOTAL_MEM_SIZE)).as<uint64_t>(); GetCore()->GetMetric(deviceName, GPU_METRIC_KEY(DEVICE_TOTAL_MEM_SIZE)).as<uint64_t>();
const int estimated_batch = static_cast<int>((total_mem - batch1_footprint) / batch1_footprint); const int estimated_batch = static_cast<int>((total_mem - batch1_footprint) / batch1_footprint);
int closest = static_cast<int>(pow(2, floor(log(estimated_batch) / log(2)))); int closest = static_cast<int>(pow(2, floor(std::log(estimated_batch) / std::log(2))));
closest = std::max(1, closest); closest = std::max(1, closest);
metaDevice.batchForDevice = std::min(metaDevice.batchForDevice, closest); metaDevice.batch_for_device = std::min(metaDevice.batch_for_device, closest);
} }
} }
// auto-batch settings // auto-batch settings
@@ -309,38 +310,37 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadN
} }
InferenceEngine::SoExecutableNetworkInternal executableNetworkWithBatch; InferenceEngine::SoExecutableNetworkInternal executableNetworkWithBatch;
if (metaDevice.batchForDevice > 1 && batched_inputs.size()) { if (metaDevice.batch_for_device > 1 && batched_inputs.size()) {
try { try {
CNNNetwork reshaped(InferenceEngine::details::cloneNetwork(network)); InferenceEngine::CNNNetwork reshaped(InferenceEngine::details::cloneNetwork(network));
ICNNNetwork::InputShapes shapes = reshaped.getInputShapes(); InferenceEngine::ICNNNetwork::InputShapes shapes = reshaped.getInputShapes();
for (const auto& input : batched_inputs) for (const auto& input : batched_inputs)
shapes[input][0] = metaDevice.batchForDevice; shapes[input][0] = metaDevice.batch_for_device;
reshaped.reshape(shapes); reshaped.reshape(shapes);
executableNetworkWithBatch = ctx ? core->LoadNetwork(reshaped, ctx, deviceConfigNoAutoBatch) executableNetworkWithBatch = ctx ? core->LoadNetwork(reshaped, ctx, deviceConfigNoAutoBatch)
: core->LoadNetwork(reshaped, deviceName, deviceConfigNoAutoBatch); : core->LoadNetwork(reshaped, deviceName, deviceConfigNoAutoBatch);
} catch (const InferenceEngine::Exception&) { } catch (const InferenceEngine::Exception&) {
metaDevice.batchForDevice = 1; metaDevice.batch_for_device = 1;
} }
} }
return std::make_shared<AutoBatchExecutableNetwork>(executableNetworkWithBatch, return std::make_shared<CompiledModel>(executableNetworkWithBatch,
executableNetworkWithoutBatch, executableNetworkWithoutBatch,
metaDevice, metaDevice,
networkConfig, networkConfig,
batched_inputs, batched_inputs,
batched_outputs); batched_outputs);
} }
InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadExeNetworkImpl( InferenceEngine::IExecutableNetworkInternal::Ptr Plugin::LoadExeNetworkImpl(
const InferenceEngine::CNNNetwork& network, const InferenceEngine::CNNNetwork& network,
const std::shared_ptr<InferenceEngine::RemoteContext>& context, const std::shared_ptr<InferenceEngine::RemoteContext>& context,
const std::map<std::string, std::string>& user_config) { const std::map<std::string, std::string>& user_config) {
return LoadNetworkImpl(network, context, user_config); return LoadNetworkImpl(network, context, user_config);
} }
InferenceEngine::QueryNetworkResult AutoBatchInferencePlugin::QueryNetwork( InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::CNNNetwork& network,
const InferenceEngine::CNNNetwork& network, const std::map<std::string, std::string>& user_config) const {
const std::map<std::string, std::string>& user_config) const {
auto core = GetCore(); auto core = GetCore();
if (!core) if (!core)
return InferenceEngine::QueryNetworkResult(); return InferenceEngine::QueryNetworkResult();
@@ -350,9 +350,10 @@ InferenceEngine::QueryNetworkResult AutoBatchInferencePlugin::QueryNetwork(
auto val = c.second; auto val = c.second;
cfg.erase(c.first); cfg.erase(c.first);
auto metaDevice = ParseMetaDevice(val, cfg); auto metaDevice = ParseMetaDevice(val, cfg);
return core->QueryNetwork(network, metaDevice.deviceName, cfg); return core->QueryNetwork(network, metaDevice.device_name, cfg);
} }
} }
IE_THROW() << "Value for KEY_AUTO_BATCH_DEVICE_CONFIG is not set"; IE_THROW() << "Value for KEY_AUTO_BATCH_DEVICE_CONFIG is not set";
} }
} // namespace AutoBatchPlugin } // namespace autobatch_plugin
} // namespace ov


@@ -11,43 +11,49 @@
 #include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
 
 #ifdef AUTOBATCH_UNITTEST
-# define AutoBatchPlugin MockAutoBatchPlugin
+# define autobatch_plugin mock_autobatch_plugin
 #endif
 
-namespace AutoBatchPlugin {
+namespace ov {
+namespace autobatch_plugin {
 
-using DeviceName = std::string;
-
 struct DeviceInformation {
-    DeviceName deviceName;
+    std::string device_name;
     std::map<std::string, std::string> config;
-    int batchForDevice;
+    int batch_for_device;
 };
 
-class AutoBatchInferencePlugin : public InferenceEngine::IInferencePlugin {
+class Plugin : public InferenceEngine::IInferencePlugin {
 public:
-    AutoBatchInferencePlugin();
-    virtual ~AutoBatchInferencePlugin() = default;
+    Plugin();
+
+    virtual ~Plugin() = default;
 
     InferenceEngine::IExecutableNetworkInternal::Ptr LoadExeNetworkImpl(
         const InferenceEngine::CNNNetwork& network,
         const std::map<std::string, std::string>& config) override;
 
     InferenceEngine::IExecutableNetworkInternal::Ptr LoadExeNetworkImpl(
         const InferenceEngine::CNNNetwork& network,
         const std::shared_ptr<InferenceEngine::RemoteContext>& context,
         const std::map<std::string, std::string>& config) override;
 
     void SetConfig(const std::map<std::string, std::string>& config) override;
 
     void CheckConfig(const std::map<std::string, std::string>& config);
 
     InferenceEngine::Parameter GetConfig(
         const std::string& name,
         const std::map<std::string, InferenceEngine::Parameter>& options) const override;
 
     InferenceEngine::QueryNetworkResult QueryNetwork(const InferenceEngine::CNNNetwork& network,
                                                      const std::map<std::string, std::string>& config) const override;
 
     InferenceEngine::Parameter GetMetric(
         const std::string& name,
         const std::map<std::string, InferenceEngine::Parameter>& options) const override;
 
     InferenceEngine::RemoteContext::Ptr CreateContext(const InferenceEngine::ParamMap&) override;
 
 #ifdef AUTOBATCH_UNITTEST
 
 public:
@@ -65,4 +71,5 @@ protected:
         const std::shared_ptr<InferenceEngine::RemoteContext> context,
         const std::map<std::string, std::string>& config);
 };
-}  // namespace AutoBatchPlugin
+
+}  // namespace autobatch_plugin
+}  // namespace ov
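The AUTOBATCH_UNITTEST guard above now rewrites the namespace token instead of the old class-name macro, so the same sources compile into a mock namespace for unit tests. A minimal sketch of the intended effect (the include path is an assumption for illustration, not taken from this diff):

// Hypothetical unit-test translation unit.
#define AUTOBATCH_UNITTEST
#include "plugin.hpp"  // "# define autobatch_plugin mock_autobatch_plugin" takes effect here

// With the macro active, the plugin classes are visible under the mock namespace.
using TestedPlugin = ov::mock_autobatch_plugin::Plugin;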


@@ -5,170 +5,171 @@
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
#include "sync_infer_request.hpp" #include "sync_infer_request.hpp"
namespace AutoBatchPlugin { namespace ov {
using namespace InferenceEngine; namespace autobatch_plugin {
template <Precision::ePrecision precision> template <InferenceEngine::Precision::ePrecision precision>
Blob::Ptr create_shared_blob_on_top_of_batched_blob(Blob::Ptr batched_blob, InferenceEngine::Blob::Ptr create_shared_blob_on_top_of_batched_blob(InferenceEngine::Blob::Ptr batched_blob,
std::string name, std::string name,
const std::set<std::string>& batched_names, const std::set<std::string>& batched_names,
size_t batch_id, size_t batch_id,
size_t batch_num) { size_t batch_num) {
typedef typename PrecisionTrait<precision>::value_type TYPE; typedef typename InferenceEngine::PrecisionTrait<precision>::value_type TYPE;
typedef typename std::add_pointer<TYPE>::type TYPEPTR; typedef typename std::add_pointer<TYPE>::type TYPEPTR;
auto ptr = batched_blob->buffer().as<TYPEPTR>(); auto ptr = batched_blob->buffer().as<TYPEPTR>();
auto sizePerBatch = batched_blob->size() / batch_num; auto sizePerBatch = batched_blob->size() / batch_num;
SizeVector dims = batched_blob->getTensorDesc().getDims(); InferenceEngine::SizeVector dims = batched_blob->getTensorDesc().getDims();
// for performance reason (copy avoidance) current impl of the auto-batching supports only batching by 0th dim // for performance reason (copy avoidance) current impl of the auto-batching supports only batching by 0th dim
if (batched_names.count(name)) { if (batched_names.count(name)) {
dims[0] = 1; dims[0] = 1;
return make_shared_blob<TYPE>({precision, dims, batched_blob->getTensorDesc().getLayout()}, return InferenceEngine::make_shared_blob<TYPE>({precision, dims, batched_blob->getTensorDesc().getLayout()},
ptr + sizePerBatch * batch_id, ptr + sizePerBatch * batch_id,
sizePerBatch); sizePerBatch);
} else { } else {
// same blob for all requests (e.g. constants) // same blob for all requests (e.g. constants)
return make_shared_blob<TYPE>({precision, dims, batched_blob->getTensorDesc().getLayout()}, ptr); return InferenceEngine::make_shared_blob<TYPE>({precision, dims, batched_blob->getTensorDesc().getLayout()},
ptr);
} }
} }
AutoBatchInferRequest::AutoBatchInferRequest(const std::vector<std::shared_ptr<const ov::Node>>& inputs, SyncInferRequest::SyncInferRequest(const std::vector<std::shared_ptr<const ov::Node>>& inputs,
const std::vector<std::shared_ptr<const ov::Node>>& outputs, const std::vector<std::shared_ptr<const ov::Node>>& outputs,
AutoBatchExecutableNetwork::WorkerInferRequest& workerRequest, CompiledModel::WorkerInferRequest& workerRequest,
int batch_id, int batch_id,
int num_batch, int num_batch,
const std::set<std::string>& batchedInputs, const std::set<std::string>& batchedInputs,
const std::set<std::string>& batchedOutputs) const std::set<std::string>& batchedOutputs)
: IInferRequestInternal(inputs, outputs), : IInferRequestInternal(inputs, outputs),
_myBatchedRequestWrapper(workerRequest), m_batched_request_wrapper(workerRequest),
_batchId(batch_id), m_batch_id(batch_id),
_batchSize(num_batch) { m_batch_size(num_batch) {
ShareBlobsWithBatchRequest(batchedInputs, batchedOutputs); ShareBlobsWithBatchRequest(batchedInputs, batchedOutputs);
} }
AutoBatchInferRequest::AutoBatchInferRequest(const InputsDataMap& networkInputs, SyncInferRequest::SyncInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
const OutputsDataMap& networkOutputs, const InferenceEngine::OutputsDataMap& networkOutputs,
AutoBatchExecutableNetwork::WorkerInferRequest& workerRequest, CompiledModel::WorkerInferRequest& workerRequest,
int batch_id, int batch_id,
int num_batch, int num_batch,
const std::set<std::string>& batchedInputs, const std::set<std::string>& batchedInputs,
const std::set<std::string>& batchedOutputs) const std::set<std::string>& batchedOutputs)
: IInferRequestInternal(networkInputs, networkOutputs), : IInferRequestInternal(networkInputs, networkOutputs),
_myBatchedRequestWrapper(workerRequest), m_batched_request_wrapper(workerRequest),
_batchId(batch_id), m_batch_id(batch_id),
_batchSize(num_batch) { m_batch_size(num_batch) {
ShareBlobsWithBatchRequest(batchedInputs, batchedOutputs); ShareBlobsWithBatchRequest(batchedInputs, batchedOutputs);
} }
void AutoBatchInferRequest::ShareBlobsWithBatchRequest(const std::set<std::string>& batchedInputs, void SyncInferRequest::ShareBlobsWithBatchRequest(const std::set<std::string>& batchedInputs,
const std::set<std::string>& batchedOutputs) { const std::set<std::string>& batchedOutputs) {
// Allocate all input blobs // Allocate all input blobs
for (const auto& it : _networkInputs) { for (const auto& it : _networkInputs) {
auto blob = _myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first); auto blob = m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first);
Blob::Ptr res; InferenceEngine::Blob::Ptr res;
switch (it.second->getTensorDesc().getPrecision()) { switch (it.second->getTensorDesc().getPrecision()) {
case InferenceEngine::Precision::FP32: case InferenceEngine::Precision::FP32:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP32>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP32>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I32: case InferenceEngine::Precision::I32:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I32>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I32>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I8: case InferenceEngine::Precision::I8:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I8>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I8>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I16: case InferenceEngine::Precision::I16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U16: case InferenceEngine::Precision::U16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U32: case InferenceEngine::Precision::U32:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U32>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U32>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::FP64: case InferenceEngine::Precision::FP64:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP64>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP64>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::FP16: case InferenceEngine::Precision::FP16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::BF16: case InferenceEngine::Precision::BF16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BF16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BF16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U64: case InferenceEngine::Precision::U64:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U64>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U64>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I64: case InferenceEngine::Precision::I64:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I64>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I64>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U8: case InferenceEngine::Precision::U8:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U8>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U8>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::BOOL: case InferenceEngine::Precision::BOOL:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BOOL>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BOOL>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedInputs, batchedInputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
default: default:
IE_THROW() << "Unsupported input precision " << it.second->getTensorDesc().getPrecision(); IE_THROW() << "Unsupported input precision " << it.second->getTensorDesc().getPrecision();
@@ -177,112 +178,112 @@ void AutoBatchInferRequest::ShareBlobsWithBatchRequest(const std::set<std::strin
} }
// Allocate all output blobs // Allocate all output blobs
for (const auto& it : _networkOutputs) { for (const auto& it : _networkOutputs) {
auto blob = _myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first); auto blob = m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first);
Blob::Ptr res; InferenceEngine::Blob::Ptr res;
switch (it.second->getTensorDesc().getPrecision()) { switch (it.second->getTensorDesc().getPrecision()) {
case InferenceEngine::Precision::FP32: case InferenceEngine::Precision::FP32:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP32>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP32>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I32: case InferenceEngine::Precision::I32:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I32>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I32>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I8: case InferenceEngine::Precision::I8:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I8>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I8>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I16: case InferenceEngine::Precision::I16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U16: case InferenceEngine::Precision::U16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U32: case InferenceEngine::Precision::U32:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U32>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U32>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::FP64: case InferenceEngine::Precision::FP64:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP64>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP64>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::FP16: case InferenceEngine::Precision::FP16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::FP16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::BF16: case InferenceEngine::Precision::BF16:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BF16>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BF16>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U64: case InferenceEngine::Precision::U64:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U64>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U64>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::I64: case InferenceEngine::Precision::I64:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I64>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::I64>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::U8: case InferenceEngine::Precision::U8:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U8>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::U8>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
case InferenceEngine::Precision::BOOL: case InferenceEngine::Precision::BOOL:
res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BOOL>( res = create_shared_blob_on_top_of_batched_blob<InferenceEngine::Precision::BOOL>(
_myBatchedRequestWrapper._inferRequestBatched->GetBlob(it.first), m_batched_request_wrapper._inferRequestBatched->GetBlob(it.first),
it.first, it.first,
batchedOutputs, batchedOutputs,
_batchId, m_batch_id,
_batchSize); m_batch_size);
break; break;
default: default:
IE_THROW(NotImplemented) << "Unsupported input precision " << it.second->getTensorDesc().getPrecision(); IE_THROW(NotImplemented) << "Unsupported input precision " << it.second->getTensorDesc().getPrecision();
@@ -290,7 +291,7 @@ void AutoBatchInferRequest::ShareBlobsWithBatchRequest(const std::set<std::strin
_outputs[it.first] = res; _outputs[it.first] = res;
} }
} }
void AutoBatchInferRequest::SetBlobsToAnotherRequest(SoIInferRequestInternal& req) { void SyncInferRequest::SetBlobsToAnotherRequest(InferenceEngine::SoIInferRequestInternal& req) {
for (const auto& it : _networkInputs) { for (const auto& it : _networkInputs) {
auto& name = it.first; auto& name = it.first;
// this request is already in BUSY state, so using the internal functions safely // this request is already in BUSY state, so using the internal functions safely
@@ -307,17 +308,15 @@ void AutoBatchInferRequest::SetBlobsToAnotherRequest(SoIInferRequestInternal& re
} }
} }
void AutoBatchInferRequest::CopyInputsIfNeeded() { void SyncInferRequest::CopyInputsIfNeeded() {
for (const auto& it : _networkInputs) { for (const auto& it : _networkInputs) {
auto& name = it.first; auto& name = it.first;
// this request is already in BUSY state, so using the internal functions safely // this request is already in BUSY state, so using the internal functions safely
CopyBlobIfNeeded(GetBlob(name), _myBatchedRequestWrapper._inferRequestBatched->GetBlob(name), true); CopyBlobIfNeeded(GetBlob(name), m_batched_request_wrapper._inferRequestBatched->GetBlob(name), true);
} }
} }
void AutoBatchInferRequest::CopyBlobIfNeeded(InferenceEngine::Blob::CPtr src, void SyncInferRequest::CopyBlobIfNeeded(InferenceEngine::Blob::CPtr src, InferenceEngine::Blob::Ptr dst, bool bInput) {
InferenceEngine::Blob::Ptr dst,
bool bInput) {
auto bufferDst = dst->buffer(); auto bufferDst = dst->buffer();
auto ptrDst = bufferDst.as<char*>(); auto ptrDst = bufferDst.as<char*>();
auto bufferSrc = src->cbuffer(); auto bufferSrc = src->cbuffer();
@@ -325,13 +324,13 @@ void AutoBatchInferRequest::CopyBlobIfNeeded(InferenceEngine::Blob::CPtr src,
ptrdiff_t szDst = dst->byteSize(); ptrdiff_t szDst = dst->byteSize();
ptrdiff_t szSrc = src->byteSize(); ptrdiff_t szSrc = src->byteSize();
if (bInput) { if (bInput) {
ptrdiff_t offset = szSrc != szDst ? _batchId * szDst / _batchSize : 0; ptrdiff_t offset = szSrc != szDst ? m_batch_id * szDst / m_batch_size : 0;
if ((ptrDst + offset) == ptrSrc) if ((ptrDst + offset) == ptrSrc)
return; return;
else else
memcpy(ptrDst + offset, ptrSrc, szSrc); memcpy(ptrDst + offset, ptrSrc, szSrc);
} else { } else {
ptrdiff_t offset = szSrc != szDst ? _batchId * szSrc / _batchSize : 0; ptrdiff_t offset = szSrc != szDst ? m_batch_id * szSrc / m_batch_size : 0;
if ((ptrSrc + offset) == ptrDst) if ((ptrSrc + offset) == ptrDst)
return; return;
else else
@@ -339,11 +338,12 @@ void AutoBatchInferRequest::CopyBlobIfNeeded(InferenceEngine::Blob::CPtr src,
} }
} }
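CopyBlobIfNeeded derives the slot that belongs to this request purely from byte sizes: when the two blobs differ in size, the larger one is the batched blob and the slot offset is batch_id * (batched bytes) / batch_size; when the pointers already coincide the copy is skipped, which is the usual case after ShareBlobsWithBatchRequest. A tiny self-checking sketch of that arithmetic, using assumed sizes that are not taken from the commit:

#include <cassert>
#include <cstddef>

// Hedged helper, not from the commit: byte offset of request batch_id
// inside a batched blob of batched_bytes, split evenly across batch_size.
std::size_t slot_offset(std::size_t batched_bytes, std::size_t batch_id, std::size_t batch_size) {
    return batch_id * batched_bytes / batch_size;
}

int main() {
    // Assumed numbers: four 1000-byte requests share one 4000-byte blob,
    // so request 2 owns bytes [2000, 3000) of the batched buffer.
    assert(slot_offset(4000, 2, 4) == 2000);
    return 0;
}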
void AutoBatchInferRequest::CopyOutputsIfNeeded() { void SyncInferRequest::CopyOutputsIfNeeded() {
for (const auto& it : _networkOutputs) { for (const auto& it : _networkOutputs) {
auto& name = it.first; auto& name = it.first;
// this request is already in BUSY state, so using the internal functions safely // this request is already in BUSY state, so using the internal functions safely
CopyBlobIfNeeded(_myBatchedRequestWrapper._inferRequestBatched->GetBlob(name), GetBlob(name), false); CopyBlobIfNeeded(m_batched_request_wrapper._inferRequestBatched->GetBlob(name), GetBlob(name), false);
} }
} }
} // namespace AutoBatchPlugin } // namespace autobatch_plugin
} // namespace ov
@@ -8,42 +8,53 @@
#include "compiled_model.hpp" #include "compiled_model.hpp"
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp" #include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
namespace AutoBatchPlugin { namespace ov {
class AutoBatchInferRequest : public InferenceEngine::IInferRequestInternal { namespace autobatch_plugin {
class SyncInferRequest : public InferenceEngine::IInferRequestInternal {
public: public:
using Ptr = std::shared_ptr<AutoBatchInferRequest>; using Ptr = std::shared_ptr<SyncInferRequest>;
explicit AutoBatchInferRequest(const InferenceEngine::InputsDataMap& networkInputs, explicit SyncInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
const InferenceEngine::OutputsDataMap& networkOutputs, const InferenceEngine::OutputsDataMap& networkOutputs,
AutoBatchExecutableNetwork::WorkerInferRequest& workerRequestPtr, CompiledModel::WorkerInferRequest& workerRequestPtr,
int batch_id, int batch_id,
int num_batch, int num_batch,
const std::set<std::string>& batchedIntputs, const std::set<std::string>& batchedIntputs,
const std::set<std::string>& batchedOutputs); const std::set<std::string>& batchedOutputs);
explicit AutoBatchInferRequest(const std::vector<std::shared_ptr<const ov::Node>>& inputs,
const std::vector<std::shared_ptr<const ov::Node>>& outputs, explicit SyncInferRequest(const std::vector<std::shared_ptr<const ov::Node>>& inputs,
AutoBatchExecutableNetwork::WorkerInferRequest& workerRequestPtr, const std::vector<std::shared_ptr<const ov::Node>>& outputs,
int batch_id, CompiledModel::WorkerInferRequest& workerRequestPtr,
int num_batch, int batch_id,
const std::set<std::string>& batchedIntputs, int num_batch,
const std::set<std::string>& batchedOutputs); const std::set<std::string>& batchedIntputs,
const std::set<std::string>& batchedOutputs);
// Batch-Device impl specific: sets the data (blobs from the device request to the batched device request) // Batch-Device impl specific: sets the data (blobs from the device request to the batched device request)
void SetBlobsToAnotherRequest(InferenceEngine::SoIInferRequestInternal& req); void SetBlobsToAnotherRequest(InferenceEngine::SoIInferRequestInternal& req);
void CopyInputsIfNeeded(); void CopyInputsIfNeeded();
void CopyOutputsIfNeeded(); void CopyOutputsIfNeeded();
AutoBatchExecutableNetwork::WorkerInferRequest& _myBatchedRequestWrapper;
std::exception_ptr _exceptionPtr; CompiledModel::WorkerInferRequest& m_batched_request_wrapper;
std::exception_ptr m_exceptionPtr;
enum eExecutionFlavor : uint8_t { enum eExecutionFlavor : uint8_t {
NOT_EXECUTED, NOT_EXECUTED,
BATCH_EXECUTED, BATCH_EXECUTED,
TIMEOUT_EXECUTED TIMEOUT_EXECUTED
} _wasBatchedRequestUsed = eExecutionFlavor::NOT_EXECUTED; } m_batched_request_status = eExecutionFlavor::NOT_EXECUTED;
protected: protected:
void CopyBlobIfNeeded(InferenceEngine::Blob::CPtr src, InferenceEngine::Blob::Ptr dst, bool bInput); void CopyBlobIfNeeded(InferenceEngine::Blob::CPtr src, InferenceEngine::Blob::Ptr dst, bool bInput);
void ShareBlobsWithBatchRequest(const std::set<std::string>& batchedIntputs, void ShareBlobsWithBatchRequest(const std::set<std::string>& batchedIntputs,
const std::set<std::string>& batchedOutputs); const std::set<std::string>& batchedOutputs);
size_t _batchId; size_t m_batch_id;
size_t _batchSize;
size_t m_batch_size;
}; };
} // namespace AutoBatchPlugin } // namespace autobatch_plugin
} // namespace ov
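m_batched_request_status records which path actually served the request: BATCH_EXECUTED when it rode the shared batched request, TIMEOUT_EXECUTED when the timeout fired and the per-request fallback (m_infer_request_without_batch) ran instead. A hedged sketch of how a result-collection step might branch on it; collect_results is a hypothetical free function written for illustration, not part of the commit, whose real counterpart lives in the async pipeline.

#include "sync_infer_request.hpp"

// Hypothetical illustration only.
void collect_results(ov::autobatch_plugin::SyncInferRequest& req) {
    using Flavor = ov::autobatch_plugin::SyncInferRequest::eExecutionFlavor;
    switch (req.m_batched_request_status) {
    case Flavor::BATCH_EXECUTED:
        req.CopyOutputsIfNeeded();  // results sit in the shared batched blobs
        break;
    case Flavor::TIMEOUT_EXECUTED:
        // outputs were produced by the fallback request, nothing to slice out
        break;
    case Flavor::NOT_EXECUTED:
    default:
        break;
    }
}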
@@ -4,6 +4,7 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <thread> #include <thread>
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp" #include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
@@ -30,8 +31,7 @@ using ::testing::ReturnRef;
using ::testing::StrEq; using ::testing::StrEq;
using ::testing::StrNe; using ::testing::StrNe;
using ::testing::Throw; using ::testing::Throw;
using namespace MockAutoBatchPlugin; using namespace ov::mock_autobatch_plugin;
using namespace MockAutoBatchDevice;
using namespace InferenceEngine; using namespace InferenceEngine;
using AutoBatchRequestTestParams = std::tuple<int, // batch_size using AutoBatchRequestTestParams = std::tuple<int, // batch_size
@@ -42,12 +42,12 @@ public:
// Mock inferRequest // Mock inferRequest
std::shared_ptr<NiceMock<MockIInferRequestInternal>> mockInferRequestBatched; std::shared_ptr<NiceMock<MockIInferRequestInternal>> mockInferRequestBatched;
std::vector<std::shared_ptr<AutoBatchInferRequest>> autoBatchInferRequests; std::vector<std::shared_ptr<SyncInferRequest>> autoBatchInferRequests;
std::map<std::string, InferenceEngine::Blob::Ptr> blobMap; std::map<std::string, InferenceEngine::Blob::Ptr> blobMap;
std::vector<std::shared_ptr<const ov::Node>> inputs, outputs; std::vector<std::shared_ptr<const ov::Node>> inputs, outputs;
std::set<std::string> batchedInputs, batchedOutputs; std::set<std::string> batchedInputs, batchedOutputs;
std::shared_ptr<AutoBatchExecutableNetwork::WorkerInferRequest> workerRequestPtr; std::shared_ptr<CompiledModel::WorkerInferRequest> workerRequestPtr;
public: public:
static std::string getTestCaseName(testing::TestParamInfo<AutoBatchRequestTestParams> obj) { static std::string getTestCaseName(testing::TestParamInfo<AutoBatchRequestTestParams> obj) {
@@ -80,14 +80,14 @@ public:
} }
void create_worker(int batch_size) { void create_worker(int batch_size) {
workerRequestPtr = std::make_shared<AutoBatchExecutableNetwork::WorkerInferRequest>(); workerRequestPtr = std::make_shared<CompiledModel::WorkerInferRequest>();
workerRequestPtr->_inferRequestBatched = {mockInferRequestBatched, {}}; workerRequestPtr->_inferRequestBatched = {mockInferRequestBatched, {}};
workerRequestPtr->_batchSize = batch_size; workerRequestPtr->_batchSize = batch_size;
workerRequestPtr->_completionTasks.resize(workerRequestPtr->_batchSize); workerRequestPtr->_completionTasks.resize(workerRequestPtr->_batchSize);
workerRequestPtr->_inferRequestBatched->SetCallback([this](std::exception_ptr exceptionPtr) mutable { workerRequestPtr->_inferRequestBatched->SetCallback([this](std::exception_ptr exceptionPtr) mutable {
if (exceptionPtr) if (exceptionPtr)
workerRequestPtr->_exceptionPtr = exceptionPtr; workerRequestPtr->m_exceptionPtr = exceptionPtr;
}); });
workerRequestPtr->_thread = std::thread([] { workerRequestPtr->_thread = std::thread([] {
std::this_thread::sleep_for(std::chrono::milliseconds(10)); std::this_thread::sleep_for(std::chrono::milliseconds(10));
@@ -173,13 +173,13 @@ TEST_P(AutoBatchRequestTest, AutoBatchRequestCreateTestCase) {
create_worker(batch_size); create_worker(batch_size);
for (int batch_id = 0; batch_id < batch_size; batch_id++) { for (int batch_id = 0; batch_id < batch_size; batch_id++) {
auto req = std::make_shared<AutoBatchInferRequest>(inputs, auto req = std::make_shared<SyncInferRequest>(inputs,
outputs, outputs,
*workerRequestPtr, *workerRequestPtr,
batch_id, batch_id,
batch_size, batch_size,
batchedInputs, batchedInputs,
batchedOutputs); batchedOutputs);
EXPECT_NE(req, nullptr); EXPECT_NE(req, nullptr);
autoBatchInferRequests.emplace_back(req); autoBatchInferRequests.emplace_back(req);
@@ -206,13 +206,13 @@ TEST_P(AutoBatchRequestTest, AutoBatchRequestCopyBlobTestCase) {
create_worker(batch_size); create_worker(batch_size);
for (int batch_id = 0; batch_id < batch_size; batch_id++) { for (int batch_id = 0; batch_id < batch_size; batch_id++) {
auto req = std::make_shared<AutoBatchInferRequest>(inputs, auto req = std::make_shared<SyncInferRequest>(inputs,
outputs, outputs,
*workerRequestPtr, *workerRequestPtr,
batch_id, batch_id,
batch_size, batch_size,
batchedInputs, batchedInputs,
batchedOutputs); batchedOutputs);
EXPECT_NE(req, nullptr); EXPECT_NE(req, nullptr);
autoBatchInferRequests.emplace_back(req); autoBatchInferRequests.emplace_back(req);
@@ -225,7 +225,7 @@ class AutoBatchAsyncInferRequestTest : public AutoBatchRequestTest {
public: public:
std::shared_ptr<NiceMock<MockIInferRequestInternal>> mockInferRequestWithoutBatched; std::shared_ptr<NiceMock<MockIInferRequestInternal>> mockInferRequestWithoutBatched;
MockTaskExecutor::Ptr mockTaskExecutor; MockTaskExecutor::Ptr mockTaskExecutor;
std::vector<AutoBatchAsyncInferRequest::Ptr> autoBatchAsyncInferRequestVec; std::vector<AsyncInferRequest::Ptr> autoBatchAsyncInferRequestVec;
bool terminate; bool terminate;
public: public:
@@ -245,14 +245,14 @@ public:
} }
void create_worker(int batch_size) { void create_worker(int batch_size) {
workerRequestPtr = std::make_shared<AutoBatchExecutableNetwork::WorkerInferRequest>(); workerRequestPtr = std::make_shared<CompiledModel::WorkerInferRequest>();
workerRequestPtr->_inferRequestBatched = {mockInferRequestBatched, {}}; workerRequestPtr->_inferRequestBatched = {mockInferRequestBatched, {}};
workerRequestPtr->_batchSize = batch_size; workerRequestPtr->_batchSize = batch_size;
workerRequestPtr->_completionTasks.resize(workerRequestPtr->_batchSize); workerRequestPtr->_completionTasks.resize(workerRequestPtr->_batchSize);
workerRequestPtr->_inferRequestBatched->SetCallback([this](std::exception_ptr exceptionPtr) mutable { workerRequestPtr->_inferRequestBatched->SetCallback([this](std::exception_ptr exceptionPtr) mutable {
if (exceptionPtr) if (exceptionPtr)
workerRequestPtr->_exceptionPtr = exceptionPtr; workerRequestPtr->m_exceptionPtr = exceptionPtr;
}); });
ON_CALL(*mockInferRequestBatched, StartAsync()).WillByDefault([this]() { ON_CALL(*mockInferRequestBatched, StartAsync()).WillByDefault([this]() {
@@ -275,21 +275,21 @@ public:
} else { } else {
const int sz = static_cast<int>(workerRequestPtr->_tasks.size()); const int sz = static_cast<int>(workerRequestPtr->_tasks.size());
if (sz == workerRequestPtr->_batchSize) { if (sz == workerRequestPtr->_batchSize) {
std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task> t; std::pair<AsyncInferRequest*, InferenceEngine::Task> t;
for (int n = 0; n < sz; n++) { for (int n = 0; n < sz; n++) {
IE_ASSERT(workerRequestPtr->_tasks.try_pop(t)); IE_ASSERT(workerRequestPtr->_tasks.try_pop(t));
workerRequestPtr->_completionTasks[n] = std::move(t.second); workerRequestPtr->_completionTasks[n] = std::move(t.second);
t.first->_inferRequest->_wasBatchedRequestUsed = t.first->m_sync_infer_request->m_batched_request_status =
AutoBatchInferRequest::eExecutionFlavor::BATCH_EXECUTED; SyncInferRequest::eExecutionFlavor::BATCH_EXECUTED;
} }
workerRequestPtr->_inferRequestBatched->StartAsync(); workerRequestPtr->_inferRequestBatched->StartAsync();
} else if ((status == std::cv_status::timeout) && sz) { } else if ((status == std::cv_status::timeout) && sz) {
std::pair<AutoBatchAsyncInferRequest*, InferenceEngine::Task> t; std::pair<AsyncInferRequest*, InferenceEngine::Task> t;
for (int n = 0; n < sz; n++) { for (int n = 0; n < sz; n++) {
IE_ASSERT(workerRequestPtr->_tasks.try_pop(t)); IE_ASSERT(workerRequestPtr->_tasks.try_pop(t));
t.first->_inferRequest->_wasBatchedRequestUsed = t.first->m_sync_infer_request->m_batched_request_status =
AutoBatchInferRequest::eExecutionFlavor::TIMEOUT_EXECUTED; SyncInferRequest::eExecutionFlavor::TIMEOUT_EXECUTED;
t.first->_inferRequestWithoutBatch->StartAsync(); t.first->m_infer_request_without_batch->StartAsync();
t.second(); t.second();
} }
} }
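The mocked worker above reproduces the plugin's scheduling rule: completion tasks accumulate until either batch_size of them are queued (then the single batched request runs once for all of them) or the timeout expires with a partially filled queue (then each pending request falls back to its unbatched request). Below is a compact, self-contained sketch of that wait-for-batch-or-timeout decision; the names are hypothetical and this is not the worker from the commit.

#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>

// Hedged sketch: callers enqueue one completion task per request, then the
// worker decides between the batched path and the per-request fallback.
struct BatchGate {
    std::size_t batch_size;
    std::queue<std::function<void()>> pending;
    std::mutex mtx;
    std::condition_variable cv;

    void enqueue(std::function<void()> task) {
        {
            std::lock_guard<std::mutex> lock(mtx);
            pending.push(std::move(task));
        }
        cv.notify_one();
    }

    // true  -> run the batched request once for all queued tasks
    // false -> timeout with a partial batch, run each request unbatched
    bool wait_for_full_batch(std::chrono::milliseconds timeout) {
        std::unique_lock<std::mutex> lock(mtx);
        return cv.wait_for(lock, timeout, [this] { return pending.size() >= batch_size; });
    }
};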
@@ -311,19 +311,19 @@ TEST_P(AutoBatchAsyncInferRequestTest, AutoBatchAsyncInferRequestCreateTest) {
create_worker(batch_size); create_worker(batch_size);
for (int batch_id = 0; batch_id < batch_size; batch_id++) { for (int batch_id = 0; batch_id < batch_size; batch_id++) {
auto autoRequestImpl = std::make_shared<AutoBatchInferRequest>(inputs, auto autoRequestImpl = std::make_shared<SyncInferRequest>(inputs,
outputs, outputs,
*workerRequestPtr, *workerRequestPtr,
batch_id, batch_id,
batch_size, batch_size,
batchedInputs, batchedInputs,
batchedOutputs); batchedOutputs);
EXPECT_NE(autoRequestImpl, nullptr); EXPECT_NE(autoRequestImpl, nullptr);
autoBatchInferRequests.emplace_back(autoRequestImpl); autoBatchInferRequests.emplace_back(autoRequestImpl);
InferenceEngine::SoIInferRequestInternal inferRequestWithoutBatched = {mockInferRequestWithoutBatched, {}}; InferenceEngine::SoIInferRequestInternal inferRequestWithoutBatched = {mockInferRequestWithoutBatched, {}};
auto asyncInferRequest = auto asyncInferRequest =
std::make_shared<AutoBatchAsyncInferRequest>(autoRequestImpl, inferRequestWithoutBatched, nullptr); std::make_shared<AsyncInferRequest>(autoRequestImpl, inferRequestWithoutBatched, nullptr);
EXPECT_NE(asyncInferRequest, nullptr); EXPECT_NE(asyncInferRequest, nullptr);
autoBatchAsyncInferRequestVec.emplace_back(asyncInferRequest); autoBatchAsyncInferRequestVec.emplace_back(asyncInferRequest);
} }
@@ -340,19 +340,19 @@ TEST_P(AutoBatchAsyncInferRequestTest, AutoBatchAsyncInferRequestStartAsyncTest)
create_worker(batch_size); create_worker(batch_size);
for (int batch_id = 0; batch_id < batch_size; batch_id++) { for (int batch_id = 0; batch_id < batch_size; batch_id++) {
auto autoRequestImpl = std::make_shared<AutoBatchInferRequest>(inputs, auto autoRequestImpl = std::make_shared<SyncInferRequest>(inputs,
outputs, outputs,
*workerRequestPtr, *workerRequestPtr,
batch_id, batch_id,
batch_size, batch_size,
batchedInputs, batchedInputs,
batchedOutputs); batchedOutputs);
EXPECT_NE(autoRequestImpl, nullptr); EXPECT_NE(autoRequestImpl, nullptr);
autoBatchInferRequests.emplace_back(autoRequestImpl); autoBatchInferRequests.emplace_back(autoRequestImpl);
InferenceEngine::SoIInferRequestInternal inferRequestWithoutBatched = {mockInferRequestWithoutBatched, {}}; InferenceEngine::SoIInferRequestInternal inferRequestWithoutBatched = {mockInferRequestWithoutBatched, {}};
auto asyncInferRequest = auto asyncInferRequest =
std::make_shared<AutoBatchAsyncInferRequest>(autoRequestImpl, inferRequestWithoutBatched, nullptr); std::make_shared<AsyncInferRequest>(autoRequestImpl, inferRequestWithoutBatched, nullptr);
EXPECT_NE(asyncInferRequest, nullptr); EXPECT_NE(asyncInferRequest, nullptr);
autoBatchAsyncInferRequestVec.emplace_back(asyncInferRequest); autoBatchAsyncInferRequestVec.emplace_back(asyncInferRequest);
} }
@@ -26,12 +26,11 @@ using ::testing::ReturnRef;
using ::testing::StrEq; using ::testing::StrEq;
using ::testing::StrNe; using ::testing::StrNe;
using ::testing::Throw; using ::testing::Throw;
using namespace MockAutoBatchPlugin; using namespace ov::mock_autobatch_plugin;
using namespace MockAutoBatchDevice;
using namespace InferenceEngine; using namespace InferenceEngine;
using CreateInferRequestTestParams = std::tuple<int, // batch_size using CreateInferRequestTestParams = std::tuple<int, // batch_size
int>; // inferReq number int>; // inferReq number
class CreateInferRequestTest : public ::testing::TestWithParam<CreateInferRequestTestParams> { class CreateInferRequestTest : public ::testing::TestWithParam<CreateInferRequestTestParams> {
public: public:
std::shared_ptr<NiceMock<MockICore>> core; std::shared_ptr<NiceMock<MockICore>> core;
@@ -44,7 +43,7 @@ public:
std::shared_ptr<InferenceEngine::IInferencePlugin> mockPlugin; std::shared_ptr<InferenceEngine::IInferencePlugin> mockPlugin;
ov::SoPtr<IExecutableNetworkInternal> batchedExecNetwork; ov::SoPtr<IExecutableNetworkInternal> batchedExecNetwork;
std::shared_ptr<AutoBatchExecutableNetwork> actualExecNet; std::shared_ptr<CompiledModel> actualExecNet;
std::vector<std::shared_ptr<NiceMock<MockIInferRequestInternal>>> inferRequestVec; std::vector<std::shared_ptr<NiceMock<MockIInferRequestInternal>>> inferRequestVec;
public: public:
@@ -75,7 +74,8 @@ public:
mockIPlugin = std::make_shared<NiceMock<MockIInferencePlugin>>(); mockIPlugin = std::make_shared<NiceMock<MockIInferencePlugin>>();
ON_CALL(*mockIPlugin, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockIExecNet)); ON_CALL(*mockIPlugin, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockIExecNet));
mockPlugin = mockIPlugin; mockPlugin = mockIPlugin;
mockExecNetwork = ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(mockPlugin->LoadNetwork(CNNNetwork{}, {}), {}); mockExecNetwork =
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(mockPlugin->LoadNetwork(CNNNetwork{}, {}), {});
batchedExecNetwork = {}; batchedExecNetwork = {};
core = std::shared_ptr<NiceMock<MockICore>>(new NiceMock<MockICore>()); core = std::shared_ptr<NiceMock<MockICore>>(new NiceMock<MockICore>());
@@ -90,20 +90,21 @@ public:
}); });
} }
AutoBatchExecutableNetwork::Ptr createAutoBatchExecutableNetwork(int batch_size) { CompiledModel::Ptr createAutoBatchExecutableNetwork(int batch_size) {
DeviceInformation metaDevice = {"CPU", {}, batch_size}; DeviceInformation metaDevice = {"CPU", {}, batch_size};
std::unordered_map<std::string, InferenceEngine::Parameter> config = {{CONFIG_KEY(AUTO_BATCH_TIMEOUT), "200"}}; std::unordered_map<std::string, InferenceEngine::Parameter> config = {{CONFIG_KEY(AUTO_BATCH_TIMEOUT), "200"}};
std::set<std::string> batched_inputs = {"Parameter_0"}; std::set<std::string> batched_inputs = {"Parameter_0"};
std::set<std::string> batched_outputs = {"Convolution_20"}; std::set<std::string> batched_outputs = {"Convolution_20"};
if (batch_size > 1) if (batch_size > 1)
batchedExecNetwork = ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(mockPlugin->LoadNetwork(CNNNetwork{}, {}), {}); batchedExecNetwork =
return std::make_shared<AutoBatchExecutableNetwork>(batchedExecNetwork, ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(mockPlugin->LoadNetwork(CNNNetwork{}, {}), {});
mockExecNetwork, return std::make_shared<CompiledModel>(batchedExecNetwork,
metaDevice, mockExecNetwork,
config, metaDevice,
batched_inputs, config,
batched_outputs); batched_inputs,
batched_outputs);
} }
}; };
@@ -128,7 +129,5 @@ const std::vector<int> batch_size{1, 8, 16, 32, 128, 256};
INSTANTIATE_TEST_SUITE_P(smoke_AutoBatch_BehaviorTests, INSTANTIATE_TEST_SUITE_P(smoke_AutoBatch_BehaviorTests,
CreateInferRequestTest, CreateInferRequestTest,
::testing::Combine( ::testing::Combine(::testing::ValuesIn(batch_size), ::testing::ValuesIn(requests_num)),
::testing::ValuesIn(batch_size),
::testing::ValuesIn(requests_num)),
CreateInferRequestTest::getTestCaseName); CreateInferRequestTest::getTestCaseName);
@@ -26,8 +26,7 @@ using ::testing::ReturnRef;
using ::testing::StrEq; using ::testing::StrEq;
using ::testing::StrNe; using ::testing::StrNe;
using ::testing::Throw; using ::testing::Throw;
using namespace MockAutoBatchPlugin; using namespace ov::mock_autobatch_plugin;
using namespace MockAutoBatchDevice;
using namespace InferenceEngine; using namespace InferenceEngine;
using ExecNetworkParams = std::tuple<std::string, // Key name using ExecNetworkParams = std::tuple<std::string, // Key name
@@ -89,14 +88,15 @@ public:
ON_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockIExecNet)); ON_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockIExecNet));
mockPlugin = mockIPluginPtr; mockPlugin = mockIPluginPtr;
EXPECT_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).Times(1); EXPECT_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).Times(1);
mockExecNetwork = ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(mockPlugin->LoadNetwork(CNNNetwork{}, {}), {}); mockExecNetwork =
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(mockPlugin->LoadNetwork(CNNNetwork{}, {}), {});
core = std::shared_ptr<NiceMock<MockICore>>(new NiceMock<MockICore>()); core = std::shared_ptr<NiceMock<MockICore>>(new NiceMock<MockICore>());
plugin = std::shared_ptr<NiceMock<MockAutoBatchInferencePlugin>>(new NiceMock<MockAutoBatchInferencePlugin>()); plugin = std::shared_ptr<NiceMock<MockAutoBatchInferencePlugin>>(new NiceMock<MockAutoBatchInferencePlugin>());
plugin->SetCore(core); plugin->SetCore(core);
ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) { ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) {
return plugin->AutoBatchInferencePlugin::ParseBatchDevice(batchDevice); return plugin->Plugin::ParseBatchDevice(batchDevice);
}); });
ON_CALL(*core, LoadNetwork(MatcherCast<const CNNNetwork&>(_), MatcherCast<const std::string&>(_), _)) ON_CALL(*core, LoadNetwork(MatcherCast<const CNNNetwork&>(_), MatcherCast<const std::string&>(_), _))
.WillByDefault(Return(mockExecNetwork)); .WillByDefault(Return(mockExecNetwork));
@@ -174,25 +174,25 @@ TEST_P(ExecNetworkTest, ExecNetworkGetConfigMetricTestCase) {
} }
const std::vector<ExecNetworkParams> testConfigs = { const std::vector<ExecNetworkParams> testConfigs = {
// Metric // Metric
ExecNetworkParams{METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS), 0, false}, ExecNetworkParams{METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS), 0, false},
ExecNetworkParams{METRIC_KEY(NETWORK_NAME), 0, false}, ExecNetworkParams{METRIC_KEY(NETWORK_NAME), 0, false},
ExecNetworkParams{METRIC_KEY(SUPPORTED_METRICS), 0, false}, ExecNetworkParams{METRIC_KEY(SUPPORTED_METRICS), 0, false},
ExecNetworkParams{METRIC_KEY(SUPPORTED_CONFIG_KEYS), 0, false}, ExecNetworkParams{METRIC_KEY(SUPPORTED_CONFIG_KEYS), 0, false},
ExecNetworkParams{ov::execution_devices.name(), 0, false}, ExecNetworkParams{ov::execution_devices.name(), 0, false},
// Config in autobatch // Config in autobatch
ExecNetworkParams{CONFIG_KEY(AUTO_BATCH_DEVICE_CONFIG), 1, false}, ExecNetworkParams{CONFIG_KEY(AUTO_BATCH_DEVICE_CONFIG), 1, false},
ExecNetworkParams{CONFIG_KEY(AUTO_BATCH_TIMEOUT), 1, false}, ExecNetworkParams{CONFIG_KEY(AUTO_BATCH_TIMEOUT), 1, false},
ExecNetworkParams{CONFIG_KEY(CACHE_DIR), 1, false}, ExecNetworkParams{CONFIG_KEY(CACHE_DIR), 1, false},
// Config in dependent plugin // Config in dependent plugin
ExecNetworkParams{"OPTIMAL_BATCH_SIZE", 1, false}, ExecNetworkParams{"OPTIMAL_BATCH_SIZE", 1, false},
// Incorrect Metric // Incorrect Metric
ExecNetworkParams{"INCORRECT_METRIC", 0, true}, ExecNetworkParams{"INCORRECT_METRIC", 0, true},
// Incorrect config // Incorrect config
ExecNetworkParams{"INCORRECT_CONFIG", 1, true}, ExecNetworkParams{"INCORRECT_CONFIG", 1, true},
// Set Config // Set Config
ExecNetworkParams{CONFIG_KEY(AUTO_BATCH_TIMEOUT), 2, false}, ExecNetworkParams{CONFIG_KEY(AUTO_BATCH_TIMEOUT), 2, false},
ExecNetworkParams{"INCORRECT_CONFIG", 2, true}, ExecNetworkParams{"INCORRECT_CONFIG", 2, true},
}; };
INSTANTIATE_TEST_SUITE_P(smoke_AutoBatch_BehaviorTests, INSTANTIATE_TEST_SUITE_P(smoke_AutoBatch_BehaviorTests,
@@ -28,8 +28,7 @@ using ::testing::ReturnRef;
using ::testing::StrEq; using ::testing::StrEq;
using ::testing::StrNe; using ::testing::StrNe;
using ::testing::Throw; using ::testing::Throw;
using namespace MockAutoBatchPlugin; using namespace ov::mock_autobatch_plugin;
using namespace MockAutoBatchDevice;
using namespace InferenceEngine; using namespace InferenceEngine;
using PluginLoadNetworkParams = std::tuple<std::map<std::string, std::string>, // Parameters using PluginLoadNetworkParams = std::tuple<std::map<std::string, std::string>, // Parameters
@@ -79,14 +78,15 @@ public:
.WillByDefault(Return(cpuMockIExecNet)); .WillByDefault(Return(cpuMockIExecNet));
cpuMockPlugin = cpuMockIPluginPtr; cpuMockPlugin = cpuMockIPluginPtr;
EXPECT_CALL(*cpuMockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).Times(1); EXPECT_CALL(*cpuMockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).Times(1);
cpuMockExecNetwork = ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(cpuMockPlugin->LoadNetwork(CNNNetwork{}, {}), {}); cpuMockExecNetwork =
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal>(cpuMockPlugin->LoadNetwork(CNNNetwork{}, {}), {});
core = std::shared_ptr<NiceMock<MockICore>>(new NiceMock<MockICore>()); core = std::shared_ptr<NiceMock<MockICore>>(new NiceMock<MockICore>());
plugin = std::shared_ptr<NiceMock<MockAutoBatchInferencePlugin>>(new NiceMock<MockAutoBatchInferencePlugin>()); plugin = std::shared_ptr<NiceMock<MockAutoBatchInferencePlugin>>(new NiceMock<MockAutoBatchInferencePlugin>());
plugin->SetCore(core); plugin->SetCore(core);
ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) { ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) {
return plugin->AutoBatchInferencePlugin::ParseBatchDevice(batchDevice); return plugin->Plugin::ParseBatchDevice(batchDevice);
}); });
ON_CALL(*core, LoadNetwork(MatcherCast<const CNNNetwork&>(_), MatcherCast<const std::string&>(_), _)) ON_CALL(*core, LoadNetwork(MatcherCast<const CNNNetwork&>(_), MatcherCast<const std::string&>(_), _))
.WillByDefault(Return(cpuMockExecNetwork)); .WillByDefault(Return(cpuMockExecNetwork));
@@ -257,12 +257,13 @@ const std::vector<PluginLoadNetworkParams> testConfigs = {
{"GPU_DEVICE_TOTAL_MEM_SIZE", "4096000000"}}, {"GPU_DEVICE_TOTAL_MEM_SIZE", "4096000000"}},
{{"AUTO_BATCH_TIMEOUT", "200"}, {"AUTO_BATCH_DEVICE_CONFIG", "CPU"}}, {{"AUTO_BATCH_TIMEOUT", "200"}, {"AUTO_BATCH_DEVICE_CONFIG", "CPU"}},
1}, 1},
//PluginLoadNetworkParams{{{"PERFORMANCE_HINT", "THROUGHPUT"}, // PluginLoadNetworkParams{{{"PERFORMANCE_HINT", "THROUGHPUT"},
// {"OPTIMAL_BATCH_SIZE", "32"}, // {"OPTIMAL_BATCH_SIZE", "32"},
// {"PERFORMANCE_HINT_NUM_REQUESTS", "16"}, // {"PERFORMANCE_HINT_NUM_REQUESTS", "16"},
// {"GPU_MEMORY_STATISTICS", "1024000"}, // {"GPU_MEMORY_STATISTICS", "1024000"},
// {"GPU_DEVICE_TOTAL_MEM_SIZE", "4096000000"}}, // {"GPU_DEVICE_TOTAL_MEM_SIZE", "4096000000"}},
// {{"AUTO_BATCH_TIMEOUT", "200"}, {"AUTO_BATCH_DEVICE_CONFIG", "CPU"}, {"PERFORMANCE_HINT_NUM_REQUESTS", "12"}}, // {{"AUTO_BATCH_TIMEOUT", "200"}, {"AUTO_BATCH_DEVICE_CONFIG", "CPU"},
// {"PERFORMANCE_HINT_NUM_REQUESTS", "12"}},
// 12}, // 12},
// //
// Case 3: GPU batch size is figured out by // Case 3: GPU batch size is figured out by
@@ -13,10 +13,9 @@
#include "plugin.hpp" #include "plugin.hpp"
#include "sync_infer_request.hpp" #include "sync_infer_request.hpp"
using namespace MockAutoBatchPlugin; using namespace ov::mock_autobatch_plugin;
namespace MockAutoBatchDevice {
class MockAutoBatchInferencePlugin : public AutoBatchInferencePlugin { class MockAutoBatchInferencePlugin : public Plugin {
public: public:
MOCK_METHOD((DeviceInformation), MOCK_METHOD((DeviceInformation),
ParseMetaDevices, ParseMetaDevices,
@@ -30,10 +29,8 @@ public:
(const, override)); (const, override));
}; };
class MockAutoBatchExecutableNetwork : public AutoBatchExecutableNetwork { class MockAutoBatchExecutableNetwork : public CompiledModel {
public: public:
MOCK_METHOD((InferenceEngine::Parameter), GetConfig, (const std::string&), (const, override)); MOCK_METHOD((InferenceEngine::Parameter), GetConfig, (const std::string&), (const, override));
MOCK_METHOD((InferenceEngine::Parameter), GetMetric, (const std::string&), (const, override)); MOCK_METHOD((InferenceEngine::Parameter), GetMetric, (const std::string&), (const, override));
}; };
} // namespace MockAutoBatchDevice
@@ -20,8 +20,7 @@ using ::testing::ReturnRef;
using ::testing::StrEq; using ::testing::StrEq;
using ::testing::StrNe; using ::testing::StrNe;
using ::testing::Throw; using ::testing::Throw;
using namespace MockAutoBatchPlugin; using namespace ov::mock_autobatch_plugin;
using namespace MockAutoBatchDevice;
using BatchDeviceConfigParams = std::tuple<std::string, // Batch devices using BatchDeviceConfigParams = std::tuple<std::string, // Batch devices
std::string, // Expected device name std::string, // Expected device name
int, // Expected batch size int, // Expected batch size
@@ -82,7 +81,7 @@ public:
plugin->SetCore(core); plugin->SetCore(core);
ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) { ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) {
return plugin->AutoBatchInferencePlugin::ParseBatchDevice(batchDevice); return plugin->Plugin::ParseBatchDevice(batchDevice);
}); });
} }
}; };
@@ -192,7 +191,7 @@ public:
}); });
ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) { ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) {
return plugin->AutoBatchInferencePlugin::ParseBatchDevice(batchDevice); return plugin->Plugin::ParseBatchDevice(batchDevice);
}); });
} }
@@ -223,8 +222,8 @@ TEST_P(ParseMetaDeviceTest, ParseMetaDeviceTestCase) {
ASSERT_ANY_THROW(plugin->ParseMetaDevice(batch_cfg, config)); ASSERT_ANY_THROW(plugin->ParseMetaDevice(batch_cfg, config));
} else { } else {
auto result = plugin->ParseMetaDevice(batch_cfg, config); auto result = plugin->ParseMetaDevice(batch_cfg, config);
EXPECT_EQ(result.deviceName, expected.deviceName); EXPECT_EQ(result.device_name, expected.device_name);
EXPECT_EQ(result.batchForDevice, expected.batchForDevice); EXPECT_EQ(result.batch_for_device, expected.batch_for_device);
EXPECT_TRUE(compare(result.config, expected.config)); EXPECT_TRUE(compare(result.config, expected.config));
} }
} }
@@ -255,7 +254,7 @@ public:
plugin->SetCore(core); plugin->SetCore(core);
ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) { ON_CALL(*plugin, ParseBatchDevice).WillByDefault([this](const std::string& batchDevice) {
return plugin->AutoBatchInferencePlugin::ParseBatchDevice(batchDevice); return plugin->Plugin::ParseBatchDevice(batchDevice);
}); });
} }
}; };
@@ -271,8 +270,8 @@ TEST_P(ParseBatchDeviceTest, ParseBatchDeviceTestCase) {
ASSERT_ANY_THROW(plugin->ParseBatchDevice(batchDevice)); ASSERT_ANY_THROW(plugin->ParseBatchDevice(batchDevice));
} else { } else {
auto result = plugin->ParseBatchDevice(batchDevice); auto result = plugin->ParseBatchDevice(batchDevice);
EXPECT_EQ(result.deviceName, deviceName); EXPECT_EQ(result.device_name, deviceName);
EXPECT_EQ(result.batchForDevice, batchSize); EXPECT_EQ(result.batch_for_device, batchSize);
} }
} }
@@ -303,7 +302,7 @@ public:
ON_CALL(*plugin, GetMetric) ON_CALL(*plugin, GetMetric)
.WillByDefault( .WillByDefault(
[this](const std::string& name, const std::map<std::string, InferenceEngine::Parameter>& options) { [this](const std::string& name, const std::map<std::string, InferenceEngine::Parameter>& options) {
return plugin->AutoBatchInferencePlugin::GetMetric(name, options); return plugin->Plugin::GetMetric(name, options);
}); });
} }
}; };