[GNA] Fix exception handling on wait and infer (#14578)
* [GNA] Move definition of GNAInferRequest class to source file * [GNA] Fix handling exception for infer() and wait() gna_infer_request * fixed handling of exceptions for wait and infer gna_infer_request * fixed closing ongoing subrequest of divided model in case of exception on on enqueueing request or waiting for end of request. * [GNA] Apply review comments, Removed exceptions from enqueue and wait for Worker * changed API of request worker to return: * erorr in case wait failed instead of throw, * return true/false for enqueue instead of exception for failure case * [GNA] Fix review commentd related to gna_infer_request and worker_impl * [GNA] Add tests for GANInferRequest * added tests for exception handling of Infer, Wait, StartAsync methods of GNAInferRequest class. * [GNA] Added final fixes for review comments.
This commit is contained in:
parent
bc69385093
commit
6bca87a88a
@ -475,13 +475,16 @@ GNAPluginNS::RequestStatus GNADeviceHelper::waitForRequest(uint32_t requestID, i
|
||||
if (status == Gna2StatusDriverQoSTimeoutExceeded) {
|
||||
return GNAPluginNS::RequestStatus::kAborted;
|
||||
}
|
||||
checkGna2Status(status, "Gna2RequestWait");
|
||||
|
||||
if (per_request_diagnostics) {
|
||||
dumpAllAllocations(debugLogIndexRequestWait, "AfterGna2RequestWait");
|
||||
debugLogIndexRequestWait++;
|
||||
}
|
||||
updateGnaPerfCounters();
|
||||
|
||||
// handle error case after updating statistics data.
|
||||
checkGna2Status(status, "Gna2RequestWait");
|
||||
|
||||
return GNAPluginNS::RequestStatus::kCompleted;
|
||||
}
|
||||
|
||||
|
173
src/plugins/intel_gna/src/gna_infer_request.cpp
Normal file
173
src/plugins/intel_gna/src/gna_infer_request.cpp
Normal file
@ -0,0 +1,173 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gna_infer_request.hpp"
|
||||
|
||||
#include "gna_plugin.hpp"
|
||||
|
||||
namespace GNAPluginNS {
|
||||
|
||||
GNAInferRequest::GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||
const std::vector<std::shared_ptr<const ov::Node>>& inputs,
|
||||
const std::vector<std::shared_ptr<const ov::Node>>& outputs)
|
||||
: InferenceEngine::IInferRequestInternal(inputs, outputs),
|
||||
plg(plg) {
|
||||
CreateInferRequest();
|
||||
}
|
||||
|
||||
GNAInferRequest::GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||
InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs)
|
||||
: InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs),
|
||||
plg(plg) {
|
||||
CreateInferRequest();
|
||||
}
|
||||
|
||||
void GNAInferRequest::InferImpl() {
|
||||
// execute input pre-processing.
|
||||
execDataPreprocessing(_inputs);
|
||||
// result returned from sync infer wait method
|
||||
|
||||
auto infer_call = [&]() {
|
||||
auto result = plg->Infer(_inputs, _outputs);
|
||||
// if result is false we are dealing with QoS feature and set kRequestIndexInvalid
|
||||
// if result is ok we set kRequestIndexCompleted to not execute request if it is not
|
||||
// in the queue.
|
||||
auto result_request_index = result ? kRequestIndexCompleted : kRequestIndexInvalid;
|
||||
SetRequestIndex(result_request_index);
|
||||
};
|
||||
|
||||
CallCleanupAndRethrowOnException(std::move(infer_call));
|
||||
}
|
||||
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GNAInferRequest::GetPerformanceCounts() const {
|
||||
return plg->GetPerformanceCounts();
|
||||
}
|
||||
|
||||
void GNAInferRequest::StartAsyncImpl() {
|
||||
// execute input pre-processing.
|
||||
execDataPreprocessing(_inputs);
|
||||
|
||||
auto queue_call = [&]() {
|
||||
SetRequestIndex(plg->QueueInference(_inputs, _outputs));
|
||||
};
|
||||
|
||||
CallCleanupAndRethrowOnException(std::move(queue_call));
|
||||
|
||||
// workaround to unblock callback-based flows
|
||||
if (_callback) {
|
||||
auto res = Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY);
|
||||
std::exception_ptr exceptionPtr;
|
||||
if (res != InferenceEngine::StatusCode::OK) {
|
||||
try {
|
||||
IE_EXCEPTION_SWITCH(res,
|
||||
ExceptionType,
|
||||
InferenceEngine::details::ThrowNow<ExceptionType>{} <<=
|
||||
std::stringstream{}
|
||||
<< IE_LOCATION
|
||||
<< InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
|
||||
} catch (...) {
|
||||
exceptionPtr = std::current_exception();
|
||||
}
|
||||
}
|
||||
_callback(exceptionPtr);
|
||||
}
|
||||
}
|
||||
|
||||
InferenceEngine::StatusCode GNAInferRequest::Wait(int64_t millis_timeout) {
|
||||
if (!IsRequestIndexValid()) {
|
||||
return InferenceEngine::INFER_NOT_STARTED;
|
||||
}
|
||||
|
||||
ValidateAndConfigureTimeout(millis_timeout);
|
||||
|
||||
if (IsRequestCompleted()) {
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
auto waitStatus = RequestStatus::kAborted;
|
||||
auto wait_call = [&]() {
|
||||
waitStatus = plg->WaitFor(_infer_request_idx, millis_timeout);
|
||||
};
|
||||
CallCleanupAndRethrowOnException(std::move(wait_call));
|
||||
|
||||
return HandleRequestWaitStatus(waitStatus);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> GNAInferRequest::QueryState() {
|
||||
auto pluginStates = plg->QueryState();
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
|
||||
return plg->QueryState();
|
||||
}
|
||||
|
||||
bool GNAInferRequest::IsRequestIndexValid() {
|
||||
return _infer_request_idx != kRequestIndexInvalid;
|
||||
}
|
||||
|
||||
bool GNAInferRequest::IsRequestCompleted() {
|
||||
return _infer_request_idx == kRequestIndexCompleted;
|
||||
}
|
||||
|
||||
bool GNAInferRequest::SetRequestIndex(uint32_t request_index) {
|
||||
return _infer_request_idx = request_index;
|
||||
}
|
||||
|
||||
void GNAInferRequest::ValidateAndConfigureTimeout(int64_t& millis_timeout) {
|
||||
if (millis_timeout == InferenceEngine::InferRequest::WaitMode::RESULT_READY) {
|
||||
millis_timeout = MAX_TIMEOUT;
|
||||
}
|
||||
|
||||
if (millis_timeout < 0) {
|
||||
IE_THROW(ParameterMismatch) << "Invalid timeout value in milliseconds: " << millis_timeout << "!";
|
||||
}
|
||||
}
|
||||
|
||||
InferenceEngine::StatusCode GNAInferRequest::HandleRequestWaitStatus(const RequestStatus& request_status) {
|
||||
if (request_status == RequestStatus::kPending) {
|
||||
// request is still pending so Wait() is needed once again
|
||||
return InferenceEngine::RESULT_NOT_READY;
|
||||
}
|
||||
|
||||
if (request_status == RequestStatus::kAborted) {
|
||||
// need to preserve invalid state here to avoid next Wait() from clearing it
|
||||
SetRequestIndex(kRequestIndexInvalid);
|
||||
return InferenceEngine::INFER_NOT_STARTED;
|
||||
}
|
||||
|
||||
if (request_status == RequestStatus::kCompletedWithError) {
|
||||
SetRequestIndex(kRequestIndexInvalid);
|
||||
THROW_GNA_EXCEPTION << "Error when waiting for inference results!";
|
||||
}
|
||||
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
void GNAInferRequest::CallCleanupAndRethrowOnException(std::function<void()>&& function_to_invoke) {
|
||||
try {
|
||||
function_to_invoke();
|
||||
} catch (...) {
|
||||
// need to preserve invalid state here to avoid next Wait() from clearing it
|
||||
// and next rethrow issue.
|
||||
SetRequestIndex(kRequestIndexInvalid);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
void GNAInferRequest::CreateInferRequest() {
|
||||
// TODO: internal connection API - better to generalize
|
||||
if (_networkOutputs.empty()) {
|
||||
THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
|
||||
}
|
||||
|
||||
// copy inputs blobs since we need to have them in separate address space to allow simultaneous infer requests
|
||||
for (auto output : _networkOutputs) {
|
||||
_outputs[output.first] = plg->GetOutputBlob(output.first, output.second->getTensorDesc().getPrecision());
|
||||
}
|
||||
|
||||
for (auto input : _networkInputs) {
|
||||
_inputs[input.first] = plg->GetInputBlob(input.first, input.second->getTensorDesc().getPrecision());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace GNAPluginNS
|
@ -5,137 +5,59 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
||||
#include "cpp/ie_infer_request.hpp"
|
||||
#include "gna_plugin.hpp"
|
||||
#include "request_status.hpp"
|
||||
|
||||
namespace GNAPluginNS {
|
||||
class GNAPlugin;
|
||||
|
||||
class GNAInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
private:
|
||||
void CreateInferRequest() {
|
||||
// TODO: internal connection API - better to generalize
|
||||
if (_networkOutputs.empty()) {
|
||||
THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
|
||||
}
|
||||
|
||||
// copy inputs blobs since we need to have them in separate address space to allow simultaneous infer requests
|
||||
for (auto output : _networkOutputs) {
|
||||
_outputs[output.first] =
|
||||
plg->GetOutputBlob(output.first, output.second->getTensorDesc().getPrecision());
|
||||
}
|
||||
|
||||
for (auto input : _networkInputs) {
|
||||
_inputs[input.first] =
|
||||
plg->GetInputBlob(input.first, input.second->getTensorDesc().getPrecision());
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<GNAPlugin> plg;
|
||||
uint32_t inferRequestIdx = -1;
|
||||
|
||||
public:
|
||||
public:
|
||||
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||
const std::vector<std::shared_ptr<const ov::Node>>& inputs,
|
||||
const std::vector<std::shared_ptr<const ov::Node>>& outputs)
|
||||
: InferenceEngine::IInferRequestInternal(inputs, outputs), plg(plg) {
|
||||
CreateInferRequest();
|
||||
}
|
||||
const std::vector<std::shared_ptr<const ov::Node>>& outputs);
|
||||
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||
InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs)
|
||||
: InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
|
||||
CreateInferRequest();
|
||||
}
|
||||
InferenceEngine::InputsDataMap network_inputs,
|
||||
InferenceEngine::OutputsDataMap network_outputs);
|
||||
/**
|
||||
* @brief Infers specified input(s) in synchronous mode
|
||||
* @note blocks all method of InferRequest while request is ongoing (running or waiting in queue)
|
||||
*/
|
||||
void InferImpl() override {
|
||||
// execute input pre-processing.
|
||||
execDataPreprocessing(_inputs);
|
||||
// result returned from sync infer wait method
|
||||
auto result = plg->Infer(_inputs, _outputs);
|
||||
|
||||
// if result is false we are dealing with QoS feature
|
||||
// if result is ok, next call to wait() will return Ok, if request not in gna_queue
|
||||
if (!result) {
|
||||
inferRequestIdx = -1;
|
||||
} else {
|
||||
inferRequestIdx = -2;
|
||||
}
|
||||
}
|
||||
void InferImpl() override;
|
||||
|
||||
/**
|
||||
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
|
||||
* Note: not all plugins may provide meaningful data
|
||||
* @param perfMap - a map of layer names to profiling information for that layer.
|
||||
*/
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override {
|
||||
return plg->GetPerformanceCounts();
|
||||
}
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||
|
||||
/**
|
||||
* @brief methods with _ThreadUnsafe prefix are to implement in plugins
|
||||
* or in default wrapper (e.g. AsyncInferRequestThreadSafeDefault)
|
||||
*/
|
||||
void StartAsyncImpl() override {
|
||||
// execute input pre-processing.
|
||||
execDataPreprocessing(_inputs);
|
||||
inferRequestIdx = plg->QueueInference(_inputs, _outputs);
|
||||
// workaround to unblock callback-based flows
|
||||
if (_callback) {
|
||||
auto res = Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY);
|
||||
std::exception_ptr exceptionPtr;
|
||||
if (res != InferenceEngine::StatusCode::OK) {
|
||||
try {
|
||||
IE_EXCEPTION_SWITCH(res, ExceptionType,
|
||||
InferenceEngine::details::ThrowNow<ExceptionType>{}
|
||||
<<= std::stringstream{} << IE_LOCATION
|
||||
<< InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
|
||||
} catch (...) {
|
||||
exceptionPtr = std::current_exception();
|
||||
}
|
||||
}
|
||||
_callback(exceptionPtr);
|
||||
}
|
||||
}
|
||||
void StartAsyncImpl() override;
|
||||
|
||||
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
|
||||
|
||||
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override {
|
||||
if (inferRequestIdx == -1) {
|
||||
return InferenceEngine::INFER_NOT_STARTED;
|
||||
} else if (millis_timeout < -1) {
|
||||
IE_THROW(ParameterMismatch);
|
||||
}
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override;
|
||||
|
||||
if (millis_timeout == InferenceEngine::InferRequest::WaitMode::RESULT_READY) {
|
||||
millis_timeout = MAX_TIMEOUT;
|
||||
}
|
||||
const auto waitStatus = plg->WaitFor(inferRequestIdx, millis_timeout);
|
||||
protected:
|
||||
bool SetRequestIndex(uint32_t request_index);
|
||||
bool IsRequestIndexValid();
|
||||
bool IsRequestCompleted();
|
||||
|
||||
if (waitStatus == RequestStatus::kPending) {
|
||||
// request is still pending so Wait() is needed once again
|
||||
return InferenceEngine::RESULT_NOT_READY;
|
||||
}
|
||||
if (waitStatus == RequestStatus::kAborted) {
|
||||
// need to preserve invalid state here to avoid next Wait() from clearing it
|
||||
inferRequestIdx = -1;
|
||||
return InferenceEngine::INFER_NOT_STARTED;
|
||||
}
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
private:
|
||||
void CreateInferRequest();
|
||||
InferenceEngine::StatusCode HandleRequestWaitStatus(const RequestStatus& request_status);
|
||||
void ValidateAndConfigureTimeout(int64_t& millis_timeout);
|
||||
void CallCleanupAndRethrowOnException(std::function<void()>&& function_to_invoke);
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
||||
auto pluginStates = plg->QueryState();
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
|
||||
return plg->QueryState();
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
static constexpr const uint32_t kRequestIndexInvalid = std::numeric_limits<uint32_t>::max();
|
||||
static constexpr const uint32_t kRequestIndexCompleted = std::numeric_limits<uint32_t>::max() - 1;
|
||||
|
||||
uint32_t _infer_request_idx = kRequestIndexInvalid;
|
||||
std::shared_ptr<GNAPlugin> plg;
|
||||
};
|
||||
} // namespace GNAPluginNS
|
||||
|
@ -1335,7 +1335,9 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
|
||||
++inputNum;
|
||||
}
|
||||
|
||||
freeWorker->enqueueRequest();
|
||||
if (!freeWorker->enqueueRequest()) {
|
||||
THROW_GNA_EXCEPTION << "Error with enqueueing inference request";
|
||||
}
|
||||
|
||||
freeWorker->setResult(result);
|
||||
|
||||
@ -1351,7 +1353,13 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
|
||||
}
|
||||
|
||||
bool GNAPlugin::Wait(uint32_t request_idx) {
|
||||
return RequestStatus::kCompleted == WaitFor(request_idx, MAX_TIMEOUT);
|
||||
auto result = WaitFor(request_idx, MAX_TIMEOUT);
|
||||
|
||||
if (result == RequestStatus::kCompletedWithError) {
|
||||
THROW_GNA_EXCEPTION << "Error when waiting for inference results!";
|
||||
}
|
||||
|
||||
return result == RequestStatus::kCompleted;
|
||||
}
|
||||
|
||||
RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
|
||||
@ -1368,6 +1376,10 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
|
||||
|
||||
const auto waitStatus = worker.wait(millisTimeout);
|
||||
|
||||
if (waitStatus == RequestStatus::kCompletedWithError) {
|
||||
return waitStatus;
|
||||
}
|
||||
|
||||
if (waitStatus == RequestStatus::kAborted) {
|
||||
return waitStatus;
|
||||
}
|
||||
@ -1528,7 +1540,7 @@ bool GNAPlugin::Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob
|
||||
}
|
||||
|
||||
bool GNAPlugin::Infer(const InferenceEngine::BlobMap &input, InferenceEngine::BlobMap &result) {
|
||||
return Wait(QueueInference(input, result));
|
||||
return Wait(QueueInference(input, result));
|
||||
}
|
||||
|
||||
static InferenceEngine::Layout GetLayoutForDims(const InferenceEngine::SizeVector &dims) {
|
||||
|
@ -28,6 +28,7 @@ public:
|
||||
* @param requestID id of request to be used for wait
|
||||
* @param timeoutMilliseconds timeout of wait in milliseconds
|
||||
* @return Status of subrequest @see GNAPluginNS::RequestStatus
|
||||
*
|
||||
*/
|
||||
using WaitHandler = std::function<RequestStatus(uint32_t requestID, int64_t timeoutMilliseconds)>;
|
||||
|
||||
@ -42,8 +43,14 @@ public:
|
||||
|
||||
/**
|
||||
* @brief Add subrequest to execution queue.
|
||||
* @return true in case subrequest was properly enqueued, otherwise return false
|
||||
*/
|
||||
virtual void enqueue() = 0;
|
||||
virtual bool enqueue() = 0;
|
||||
|
||||
/**
|
||||
* @brief Finalize subrequest and set it status to RequestStatus::kNone
|
||||
*/
|
||||
virtual void cleanup() = 0;
|
||||
|
||||
/**
|
||||
* @brief Return true if subrequest is pending, otherwise return false
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <gna2-inference-api.h>
|
||||
|
||||
#include "log/debug.hpp"
|
||||
#include "log/log.hpp"
|
||||
|
||||
namespace GNAPluginNS {
|
||||
namespace request {
|
||||
@ -24,14 +25,30 @@ RequestStatus SubrequestImpl::wait(int64_t timeoutMilliseconds) {
|
||||
return status_;
|
||||
}
|
||||
|
||||
status_ = waitHandler_(requestID_, timeoutMilliseconds);
|
||||
try {
|
||||
status_ = waitHandler_(requestID_, timeoutMilliseconds);
|
||||
} catch (const std::exception& e) {
|
||||
ov::intel_gna::log::error() << "Exception when executiong wait: " << e.what() << std::endl;
|
||||
status_ = RequestStatus::kCompletedWithError;
|
||||
}
|
||||
|
||||
return status_;
|
||||
}
|
||||
|
||||
void SubrequestImpl::enqueue() {
|
||||
requestID_ = enqueueHandler_();
|
||||
status_ = RequestStatus::kPending;
|
||||
bool SubrequestImpl::enqueue() {
|
||||
try {
|
||||
requestID_ = enqueueHandler_();
|
||||
status_ = RequestStatus::kPending;
|
||||
} catch (const std::exception& e) {
|
||||
ov::intel_gna::log::error() << "Exception when executiong enqueue: " << e.what() << std::endl;
|
||||
status_ = RequestStatus::kCompletedWithError;
|
||||
}
|
||||
return status_ != RequestStatus::kCompletedWithError;
|
||||
}
|
||||
|
||||
void SubrequestImpl::cleanup() {
|
||||
static_cast<void>(wait(0));
|
||||
status_ = RequestStatus::kNone;
|
||||
}
|
||||
|
||||
bool SubrequestImpl::isPending() const {
|
||||
|
@ -40,8 +40,14 @@ public:
|
||||
|
||||
/**
|
||||
* @brief Add subrequest to execution queue.
|
||||
* @return true in case subrequest was properly enqueued, otherwise return false
|
||||
*/
|
||||
void enqueue() override;
|
||||
bool enqueue() override;
|
||||
|
||||
/**
|
||||
* @brief Finalize subrequest and set it status to RequestStatus::kNone
|
||||
*/
|
||||
void cleanup() override;
|
||||
|
||||
/**
|
||||
* @brief Return true if subrequest is pending, otherwise return false
|
||||
|
@ -39,15 +39,14 @@ public:
|
||||
|
||||
/**
|
||||
* @brief Enqueue request to requests queue for contained model.
|
||||
* @throw Exception in case worker is busy or if there was an issue with enqueue.
|
||||
* @return true in case subrequest was properly enqueued, otherwise return false
|
||||
*/
|
||||
virtual void enqueueRequest() = 0;
|
||||
virtual bool enqueueRequest() = 0;
|
||||
|
||||
/**
|
||||
* @brief Wait untril request will be not finished for give timeout.
|
||||
* @param timeoutMilliseconds timeout in milliseconds
|
||||
* @return status of execution of ongoing request. @see GNAPluginNS::RequestStatus
|
||||
* @throw Exception in case worker is busy or if there was an issue with enqueue.
|
||||
*/
|
||||
virtual RequestStatus wait(int64_t timeoutMilliseconds) = 0;
|
||||
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <gna2-inference-api.h>
|
||||
|
||||
#include "log/debug.hpp"
|
||||
#include "log/log.hpp"
|
||||
#include "model_wrapper.hpp"
|
||||
#include "subrequest.hpp"
|
||||
|
||||
@ -39,12 +40,19 @@ Gna2Model* WorkerImpl::model() {
|
||||
return &fullModel_->object();
|
||||
}
|
||||
|
||||
void WorkerImpl::enqueueRequest() {
|
||||
check_if_free();
|
||||
bool WorkerImpl::enqueueRequest() {
|
||||
if (!isFree()) {
|
||||
ov::intel_gna::log::warning() << "Trying to propagate on busy request with id: " << representingIndex_;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto& subrequest : modelSubrequests_) {
|
||||
subrequest->enqueue();
|
||||
if (!subrequest->enqueue()) {
|
||||
cleanup_subrequests();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
RequestStatus WorkerImpl::wait(int64_t timeoutMilliseconds) {
|
||||
@ -56,8 +64,13 @@ RequestStatus WorkerImpl::wait(int64_t timeoutMilliseconds) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (subrequest->wait(timeoutMilliseconds) == RequestStatus::kPending) {
|
||||
auto result = subrequest->wait(timeoutMilliseconds);
|
||||
|
||||
if (result == RequestStatus::kPending) {
|
||||
pending = true;
|
||||
} else if (result == RequestStatus::kCompletedWithError) {
|
||||
cleanup_subrequests();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,9 +120,11 @@ InferenceEngine::BlobMap& WorkerImpl::result() {
|
||||
return requestResult_;
|
||||
}
|
||||
|
||||
void WorkerImpl::check_if_free() {
|
||||
if (!isFree()) {
|
||||
THROW_GNA_EXCEPTION << "Trying to propagte on busy request with id: " << representingIndex_;
|
||||
void WorkerImpl::cleanup_subrequests() {
|
||||
for (auto& subrequest : modelSubrequests_) {
|
||||
if (subrequest->isPending()) {
|
||||
subrequest->cleanup();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ public:
|
||||
/**
|
||||
* @see Worker::enqueueRequest()
|
||||
*/
|
||||
void enqueueRequest() override;
|
||||
bool enqueueRequest() override;
|
||||
|
||||
/**
|
||||
* @see Worker::wait()
|
||||
@ -93,7 +93,7 @@ public:
|
||||
void setResult(InferenceEngine::BlobMap&& result) override;
|
||||
|
||||
private:
|
||||
void check_if_free();
|
||||
void cleanup_subrequests();
|
||||
|
||||
uint32_t representingIndex_{0};
|
||||
std::shared_ptr<ModelWrapper> fullModel_;
|
||||
|
@ -10,10 +10,11 @@ namespace GNAPluginNS {
|
||||
* @brief Enum representing status of request
|
||||
*/
|
||||
enum class RequestStatus {
|
||||
kNone = 0, /// request was not initialized
|
||||
kAborted = 1, /// request was aborted
|
||||
kPending = 2, /// request was started and is onging
|
||||
kCompleted = 3 /// request was completed with success
|
||||
kNone = 0, /// request was not initialized
|
||||
kAborted = 1, /// request was aborted
|
||||
kPending = 2, /// request was started and is onging
|
||||
kCompleted = 3, /// request was completed with success
|
||||
kCompletedWithError = 4 /// request was completed with error
|
||||
};
|
||||
|
||||
} // namespace GNAPluginNS
|
||||
|
@ -14,7 +14,7 @@ endif()
|
||||
|
||||
# TODO: fix CVS-71010 and remove BUILD_SHARED_LIBS
|
||||
if(NOT BUILD_SHARED_LIBS)
|
||||
set(exclude_path EXCLUDED_SOURCE_PATHS "${CMAKE_CURRENT_SOURCE_DIR}/(gna_api_stub|gna_wait_test|gna_export_import_test).cpp")
|
||||
set(exclude_path EXCLUDED_SOURCE_PATHS "${CMAKE_CURRENT_SOURCE_DIR}/(gna_api_stub|gna_wait_test|gna_export_import_test|gna_infer_request_test).cpp")
|
||||
endif()
|
||||
|
||||
addIeTargetTest(
|
||||
|
@ -138,6 +138,9 @@ GNA2_API enum Gna2Status Gna2RequestConfigSetAccelerationMode(
|
||||
GNA2_API enum Gna2Status Gna2RequestEnqueue(
|
||||
uint32_t requestConfigId,
|
||||
uint32_t * requestId) {
|
||||
if (current != nullptr) {
|
||||
return current->Gna2RequestEnqueue(requestConfigId, requestId);
|
||||
}
|
||||
return Gna2StatusSuccess;
|
||||
}
|
||||
|
||||
|
204
src/plugins/intel_gna/tests/unit/gna_infer_request_test.cpp
Normal file
204
src/plugins/intel_gna/tests/unit/gna_infer_request_test.cpp
Normal file
@ -0,0 +1,204 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gna_infer_request.hpp"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "any_copy.hpp"
|
||||
#include "common_test_utils/data_utils.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gna_mock_api.hpp"
|
||||
#include "gna_plugin.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
using GNAPluginNS::GNAInferRequest;
|
||||
using GNAPluginNS::GNAPlugin;
|
||||
using ::testing::InSequence;
|
||||
|
||||
class GNAInferRequestTest : public ::testing::Test {
|
||||
public:
|
||||
IInferRequestInternal::Ptr CreateRequest() {
|
||||
auto function = GetFunction();
|
||||
CNNNetwork cnn_network = CNNNetwork{function};
|
||||
|
||||
SetExpectsForLoadNetworkAndShutDown(_data);
|
||||
const ov::AnyMap gna_config = {ov::intel_gna::execution_mode(ov::intel_gna::ExecutionMode::SW_EXACT)};
|
||||
auto plugin = std::make_shared<GNAPlugin>(any_copy(gna_config));
|
||||
plugin->LoadNetwork(cnn_network);
|
||||
|
||||
return std::make_shared<GNAInferRequest>(plugin, cnn_network.getInputsInfo(), cnn_network.getOutputsInfo());
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ngraph::Function> GetFunction() {
|
||||
auto ngPrc = ngraph::element::f32;
|
||||
std::vector<size_t> shape = {1, 10};
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {shape});
|
||||
auto shape_size = ov::shape_size(shape);
|
||||
auto add_const =
|
||||
ngraph::builder::makeConstant<float>(ngPrc,
|
||||
shape,
|
||||
CommonTestUtils::generate_float_numbers(shape_size, -0.5f, 0.5f),
|
||||
false);
|
||||
|
||||
auto add = std::make_shared<ngraph::opset9::Add>(params[0], add_const);
|
||||
auto res = std::make_shared<ngraph::op::Result>(add);
|
||||
auto function = std::make_shared<ngraph::Function>(res, params, "Add");
|
||||
return function;
|
||||
}
|
||||
|
||||
void SetExpectsForLoadNetworkAndShutDown(std::vector<std::vector<uint8_t>>& data) {
|
||||
EXPECT_CALL(*_mock_api, Gna2MemoryAlloc(_, _, _))
|
||||
.Times(AtLeast(1))
|
||||
.WillRepeatedly(Invoke([&data](uint32_t size_requested, uint32_t* size_granted, void** memory_address) {
|
||||
data.push_back(std::vector<uint8_t>(size_requested));
|
||||
*size_granted = size_requested;
|
||||
*memory_address = data.back().data();
|
||||
return Gna2StatusSuccess;
|
||||
}));
|
||||
|
||||
EXPECT_CALL(*_mock_api, Gna2DeviceGetVersion(_, _))
|
||||
.WillOnce(Invoke([](uint32_t deviceIndex, enum Gna2DeviceVersion* deviceVersion) {
|
||||
*deviceVersion = Gna2DeviceVersionSoftwareEmulation;
|
||||
return Gna2StatusSuccess;
|
||||
}));
|
||||
|
||||
EXPECT_CALL(*_mock_api, Gna2DeviceOpen(_)).WillOnce(Return(Gna2StatusSuccess));
|
||||
|
||||
EXPECT_CALL(*_mock_api, Gna2GetLibraryVersion(_, _))
|
||||
.Times(AtLeast(0))
|
||||
.WillRepeatedly(Return(Gna2StatusSuccess));
|
||||
|
||||
EXPECT_CALL(*_mock_api, Gna2InstrumentationConfigCreate(_, _, _, _)).WillOnce(Return(Gna2StatusSuccess));
|
||||
|
||||
EXPECT_CALL(*_mock_api, Gna2ModelCreate(_, _, _))
|
||||
.WillOnce(Invoke([](uint32_t deviceIndex, struct Gna2Model const* model, uint32_t* model_id) {
|
||||
*model_id = 0;
|
||||
return Gna2StatusSuccess;
|
||||
}));
|
||||
|
||||
EXPECT_CALL(*_mock_api, Gna2RequestConfigCreate(_, _))
|
||||
.WillOnce(Invoke([](uint32_t model_Id, uint32_t* request_config_id) {
|
||||
*request_config_id = 0;
|
||||
return Gna2StatusSuccess;
|
||||
}));
|
||||
|
||||
EXPECT_CALL(*_mock_api, Gna2InstrumentationConfigAssignToRequestConfig(_, _))
|
||||
.Times(AtLeast(1))
|
||||
.WillRepeatedly(Return(Gna2StatusSuccess));
|
||||
|
||||
InSequence seq;
|
||||
EXPECT_CALL(*_mock_api, Gna2DeviceClose(_)).WillOnce(Return(Gna2StatusSuccess));
|
||||
EXPECT_CALL(*_mock_api, Gna2MemoryFree(_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
|
||||
}
|
||||
|
||||
void SetExpectOnEnqueue(const Gna2Status& return_status = Gna2StatusSuccess) {
|
||||
EXPECT_CALL(*_mock_api, Gna2RequestEnqueue(_, _)).WillOnce(Return(return_status));
|
||||
}
|
||||
|
||||
void SetExpectOnWait(const Gna2Status& return_status = Gna2StatusSuccess) {
|
||||
EXPECT_CALL(*_mock_api, Gna2RequestWait(_, _)).WillOnce(Return(return_status));
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
_mock_api = std::make_shared<StrictMock<GNACppApi>>();
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
ASSERT_TRUE(Mock::VerifyAndClearExpectations(_mock_api.get()));
|
||||
}
|
||||
|
||||
std::shared_ptr<StrictMock<GNACppApi>> _mock_api;
|
||||
std::vector<std::vector<uint8_t>> _data;
|
||||
};
|
||||
|
||||
TEST_F(GNAInferRequestTest, start_async) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue();
|
||||
// wait on shutdown neede if request was enqueued by not waited
|
||||
SetExpectOnWait();
|
||||
EXPECT_NO_THROW(request->StartAsync());
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, start_async_with_enqueue_error) {
|
||||
auto request = CreateRequest();
|
||||
// trigger Gna2RequestEnqueue to fail
|
||||
SetExpectOnEnqueue(Gna2StatusUnknownError);
|
||||
// no wait needed due the fact there are no enqueud requests
|
||||
EXPECT_THROW(request->StartAsync(), std::exception);
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, start_async_with_wait) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue();
|
||||
// wait on shutdown needed if request was enqueued by not waited
|
||||
SetExpectOnWait();
|
||||
EXPECT_NO_THROW(request->StartAsync());
|
||||
EXPECT_EQ(OK, request->Wait(0));
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, start_async_error_with_wait) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue(Gna2StatusUnknownError);
|
||||
// wait on shutdown needed if request was enqueued by not waited
|
||||
// SetExpectOnWait();
|
||||
EXPECT_THROW(request->StartAsync(), std::exception);
|
||||
EXPECT_EQ(INFER_NOT_STARTED, request->Wait(0));
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, start_async_with_wait_error) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue();
|
||||
// wait on shutdown needed if request was enqueued by not waited
|
||||
SetExpectOnWait(Gna2StatusUnknownError);
|
||||
EXPECT_NO_THROW(request->StartAsync());
|
||||
EXPECT_THROW(request->Wait(0), std::exception);
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, start_async_wait_check_recovery_after_wait_error) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue();
|
||||
SetExpectOnWait(Gna2StatusUnknownError);
|
||||
EXPECT_NO_THROW(request->StartAsync());
|
||||
EXPECT_THROW(request->Wait(0), std::exception);
|
||||
// check that no there is exception on second wiat after first failing
|
||||
EXPECT_EQ(INFER_NOT_STARTED, request->Wait(0));
|
||||
|
||||
// start new request
|
||||
SetExpectOnEnqueue();
|
||||
SetExpectOnWait();
|
||||
EXPECT_NO_THROW(request->StartAsync());
|
||||
EXPECT_EQ(OK, request->Wait(0));
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, infer) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue();
|
||||
SetExpectOnWait();
|
||||
EXPECT_NO_THROW(request->Infer());
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, infer_enque_error) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue(Gna2StatusUnknownError);
|
||||
EXPECT_THROW(request->Infer(), std::exception);
|
||||
}
|
||||
|
||||
TEST_F(GNAInferRequestTest, infer_wait_error_check_recovery) {
|
||||
auto request = CreateRequest();
|
||||
SetExpectOnEnqueue();
|
||||
SetExpectOnWait(Gna2StatusUnknownError);
|
||||
EXPECT_THROW(request->Infer(), std::exception);
|
||||
// check if next infer will execute properly after wait throwing
|
||||
SetExpectOnEnqueue();
|
||||
SetExpectOnWait();
|
||||
EXPECT_NO_THROW(request->Infer());
|
||||
}
|
@ -47,6 +47,10 @@ public:
|
||||
uint32_t requestId,
|
||||
uint32_t timeoutMilliseconds));
|
||||
|
||||
MOCK_METHOD2(Gna2RequestEnqueue, Gna2Status(
|
||||
uint32_t requestConfigId,
|
||||
uint32_t* requestId));
|
||||
|
||||
MOCK_METHOD2(Gna2DeviceGetVersion, Gna2Status(
|
||||
uint32_t deviceIndex,
|
||||
enum Gna2DeviceVersion * deviceVersion));
|
||||
|
@ -10,14 +10,15 @@
|
||||
#define IMPLEMENT_INFERENCE_ENGINE_PLUGIN
|
||||
#include "gna_infer_request.hpp"
|
||||
#include "gna_mock_api.hpp"
|
||||
#include "gna_plugin.hpp"
|
||||
#include "request/model_wrapper_factory.hpp"
|
||||
#include "request/subrequest_impl.hpp"
|
||||
#include "request/worker_factory.hpp"
|
||||
#include "request/worker_impl.hpp"
|
||||
#include "request/worker_pool.hpp"
|
||||
|
||||
using GNAPluginNS::GNAInferRequest;
|
||||
using GNAPluginNS::GNAPlugin;
|
||||
using namespace GNAPluginNS;
|
||||
using namespace GNAPluginNS::request;
|
||||
using ::testing::_;
|
||||
using ::testing::Return;
|
||||
|
||||
@ -27,9 +28,6 @@ class GNAPluginForGNAWaitTest : public GNAPlugin {
|
||||
public:
|
||||
// Prepare underlining object to enable GNAInferRequest::Wait() working
|
||||
GNAPluginForGNAWaitTest() {
|
||||
using namespace GNAPluginNS;
|
||||
using namespace request;
|
||||
|
||||
InferenceEngine::TensorDesc td{InferenceEngine::Precision::FP32, {1, 1}, InferenceEngine::Layout::HW};
|
||||
auto fakeInfo = std::make_shared<InferenceEngine::InputInfo>();
|
||||
auto fakePtr = std::make_shared<InferenceEngine::Data>("fakeName", td);
|
||||
@ -55,20 +53,32 @@ public:
|
||||
|
||||
auto model = ModelWrapperFactory::createWithNumberOfEmptyOperations(1);
|
||||
subrequests.push_back(std::make_shared<SubrequestImpl>(std::move(enqueue), std::move(wait)));
|
||||
auto worker = std::make_shared<WorkerImpl>(model, std::move(subrequests));
|
||||
_worker = std::make_shared<WorkerImpl>(model, std::move(subrequests));
|
||||
|
||||
requestWorkerPool_->addModelWorker(worker);
|
||||
worker->enqueueRequest();
|
||||
requestWorkerPool_->addModelWorker(_worker);
|
||||
}
|
||||
|
||||
void EnqueTestRequest() {
|
||||
_worker->enqueueRequest();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Worker> _worker;
|
||||
};
|
||||
|
||||
class GNAInferRequestForGNAWaitTest : public GNAInferRequest {
|
||||
public:
|
||||
// Prepare underlining object to enable Wait() working
|
||||
GNAInferRequestForGNAWaitTest(std::shared_ptr<GNAPlugin> plugin)
|
||||
: GNAInferRequest{plugin, plugin->GetNetworkInputs(), plugin->GetNetworkOutputs()} {
|
||||
inferRequestIdx = 0;
|
||||
GNAInferRequestForGNAWaitTest(std::shared_ptr<GNAPluginForGNAWaitTest> plugin)
|
||||
: GNAInferRequest{plugin, plugin->GetNetworkInputs(), plugin->GetNetworkOutputs()},
|
||||
_plugin(plugin) {}
|
||||
|
||||
void EnqueTestRequest() {
|
||||
_plugin->EnqueTestRequest();
|
||||
SetRequestIndex(0);
|
||||
}
|
||||
|
||||
std::shared_ptr<GNAPluginForGNAWaitTest> _plugin;
|
||||
};
|
||||
|
||||
TEST_F(GNAWaitTest, ReturnsGna2StatusDriverQoSTimeoutExceeded) {
|
||||
@ -76,6 +86,7 @@ TEST_F(GNAWaitTest, ReturnsGna2StatusDriverQoSTimeoutExceeded) {
|
||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDriverQoSTimeoutExceeded));
|
||||
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||
inferRequest.EnqueTestRequest();
|
||||
ASSERT_EQ(InferenceEngine::INFER_NOT_STARTED, inferRequest.Wait(0));
|
||||
}
|
||||
|
||||
@ -84,6 +95,38 @@ TEST_F(GNAWaitTest, ReturnsGna2StatusWarningDeviceBusy) {
|
||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusWarningDeviceBusy));
|
||||
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||
|
||||
inferRequest.EnqueTestRequest();
|
||||
ASSERT_EQ(InferenceEngine::RESULT_NOT_READY, inferRequest.Wait(0));
|
||||
}
|
||||
|
||||
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange) {
|
||||
GNACppApi enableMocks;
|
||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
|
||||
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||
inferRequest.EnqueTestRequest();
|
||||
ASSERT_THROW(inferRequest.Wait(0), std::exception);
|
||||
}
|
||||
|
||||
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange_Extra_Sync) {
|
||||
GNACppApi enableMocks;
|
||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
|
||||
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||
inferRequest.EnqueTestRequest();
|
||||
ASSERT_THROW(inferRequest.Wait(0), std::exception);
|
||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(0);
|
||||
ASSERT_EQ(InferenceEngine::INFER_NOT_STARTED, inferRequest.Wait(0));
|
||||
}
|
||||
|
||||
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange_Another_Use) {
|
||||
GNACppApi enableMocks;
|
||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
|
||||
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||
inferRequest.EnqueTestRequest();
|
||||
ASSERT_THROW(inferRequest.Wait(0), std::exception);
|
||||
inferRequest.EnqueTestRequest();
|
||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusSuccess));
|
||||
ASSERT_EQ(InferenceEngine::OK, inferRequest.Wait(0));
|
||||
}
|
||||
|
@ -60,16 +60,40 @@ TEST_F(GNA_Request_WorkerImplTest, enqueueRequest) {
|
||||
subrequests.push_back(subrequestMock2);
|
||||
WorkerImpl worker(wrapper, subrequests);
|
||||
|
||||
// check if exception will be thrown if worker will be busy - at least one subrequest is pending.
|
||||
// check if false will be returned if worker will be busy - at least one subrequest is pending.
|
||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(true));
|
||||
EXPECT_THROW(worker.enqueueRequest(), std::exception);
|
||||
EXPECT_FALSE(worker.enqueueRequest());
|
||||
|
||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1);
|
||||
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1);
|
||||
EXPECT_NO_THROW(worker.enqueueRequest());
|
||||
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1).WillOnce(Return(true));
|
||||
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1).WillOnce(Return(true));
|
||||
EXPECT_TRUE(worker.enqueueRequest());
|
||||
}
|
||||
|
||||
TEST_F(GNA_Request_WorkerImplTest, enqueueRequest_second_enque_failed) {
|
||||
auto wrapper = ModelWrapperFactory::createTrivial();
|
||||
std::vector<std::shared_ptr<Subrequest>> subrequests;
|
||||
|
||||
auto subrequestMock1 = std::make_shared<MockSubrequest>();
|
||||
subrequests.push_back(subrequestMock1);
|
||||
auto subrequestMock2 = std::make_shared<MockSubrequest>();
|
||||
subrequests.push_back(subrequestMock2);
|
||||
WorkerImpl worker(wrapper, subrequests);
|
||||
|
||||
// check if exception will be thrown if worker will be busy - at least one subrequest is pending.
|
||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(true));
|
||||
EXPECT_FALSE(worker.enqueueRequest());
|
||||
|
||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(2).WillOnce(Return(false)).WillOnce(Return(true));
|
||||
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(2).WillOnce(Return(false)).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1).WillOnce(Return(true));
|
||||
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock1.get(), cleanup()).Times(1);
|
||||
EXPECT_CALL(*subrequestMock2.get(), cleanup()).Times(0);
|
||||
EXPECT_FALSE(worker.enqueueRequest());
|
||||
}
|
||||
|
||||
TEST_F(GNA_Request_WorkerImplTest, wait) {
|
||||
@ -107,6 +131,13 @@ TEST_F(GNA_Request_WorkerImplTest, wait) {
|
||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock1.get(), isAborted()).Times(1).WillOnce(Return(true));
|
||||
EXPECT_EQ(RequestStatus::kAborted, worker.wait(referenceTimeout));
|
||||
|
||||
// subrequest enuqued and completed with error
|
||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(2).WillOnce(Return(true)).WillOnce(Return(false));
|
||||
EXPECT_CALL(*subrequestMock1.get(), wait(referenceTimeout))
|
||||
.Times(1)
|
||||
.WillOnce(Return(RequestStatus::kCompletedWithError));
|
||||
EXPECT_EQ(RequestStatus::kCompletedWithError, worker.wait(referenceTimeout));
|
||||
}
|
||||
|
||||
TEST_F(GNA_Request_WorkerImplTest, isFree) {
|
||||
|
@ -13,7 +13,8 @@ namespace request {
|
||||
class MockSubrequest : public Subrequest {
|
||||
public:
|
||||
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
|
||||
MOCK_METHOD(void, enqueue, (), (override));
|
||||
MOCK_METHOD(bool, enqueue, (), (override));
|
||||
MOCK_METHOD(void, cleanup, (), (override));
|
||||
MOCK_METHOD(bool, isPending, (), (const, override));
|
||||
MOCK_METHOD(bool, isAborted, (), (const, override));
|
||||
MOCK_METHOD(bool, isCompleted, (), (const, override));
|
||||
|
@ -14,7 +14,7 @@ class MockWorker : public Worker {
|
||||
public:
|
||||
MOCK_METHOD(Gna2Model*, model, (), (override));
|
||||
MOCK_METHOD(const Gna2Model*, model, (), (const, override));
|
||||
MOCK_METHOD(void, enqueueRequest, (), (override));
|
||||
MOCK_METHOD(bool, enqueueRequest, (), (override));
|
||||
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
|
||||
MOCK_METHOD(bool, isFree, (), (const, override));
|
||||
MOCK_METHOD(uint32_t, representingIndex, (), (const, override));
|
||||
|
@ -36,7 +36,6 @@ public:
|
||||
if (!legacy_str.empty()) {
|
||||
name += "," + legacy_str;
|
||||
}
|
||||
std::cout << "name: " << name << std::endl;
|
||||
return name;
|
||||
}
|
||||
void SetUp() override {
|
||||
|
Loading…
Reference in New Issue
Block a user