[GNA] Fix exception handling on wait and infer (#14578)

* [GNA] Move definition of GNAInferRequest class to source file

* [GNA] Fix handling exception for infer() and wait() gna_infer_request

* fixed handling of exceptions for wait and infer gna_infer_request
* fixed closing ongoing subrequest of divided model in case of exception
    on enqueueing request or waiting for end of request.

* [GNA] Apply review comments, Removed exceptions from enqueue and wait for Worker

* changed API of request worker to return:
* error in case wait failed instead of
        throw,
    * return true/false for enqueue instead of exception for failure case

* [GNA] Fix review comments related to gna_infer_request and worker_impl

* [GNA] Add tests for GNAInferRequest

* added tests for exception handling of Infer, Wait, StartAsync methods of
    GNAInferRequest class.

* [GNA] Added final fixes for review comments.
This commit is contained in:
Marcin Kusmierski 2022-12-20 11:15:15 +01:00 committed by GitHub
parent bc69385093
commit 6bca87a88a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 590 additions and 150 deletions

View File

@ -475,13 +475,16 @@ GNAPluginNS::RequestStatus GNADeviceHelper::waitForRequest(uint32_t requestID, i
if (status == Gna2StatusDriverQoSTimeoutExceeded) {
return GNAPluginNS::RequestStatus::kAborted;
}
checkGna2Status(status, "Gna2RequestWait");
if (per_request_diagnostics) {
dumpAllAllocations(debugLogIndexRequestWait, "AfterGna2RequestWait");
debugLogIndexRequestWait++;
}
updateGnaPerfCounters();
// handle error case after updating statistics data.
checkGna2Status(status, "Gna2RequestWait");
return GNAPluginNS::RequestStatus::kCompleted;
}

View File

@ -0,0 +1,173 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gna_infer_request.hpp"
#include "gna_plugin.hpp"
namespace GNAPluginNS {
// Constructor for the ov::Node based API: stores the plugin handle and
// allocates input/output blobs via CreateInferRequest().
GNAInferRequest::GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
                                 const std::vector<std::shared_ptr<const ov::Node>>& inputs,
                                 const std::vector<std::shared_ptr<const ov::Node>>& outputs)
    : InferenceEngine::IInferRequestInternal(inputs, outputs),
      plg(plg) {
    CreateInferRequest();
}
// Constructor for the legacy InputsDataMap/OutputsDataMap API: stores the
// plugin handle and allocates input/output blobs via CreateInferRequest().
GNAInferRequest::GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
                                 InferenceEngine::InputsDataMap networkInputs,
                                 InferenceEngine::OutputsDataMap networkOutputs)
    : InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs),
      plg(plg) {
    CreateInferRequest();
}
void GNAInferRequest::InferImpl() {
// execute input pre-processing.
execDataPreprocessing(_inputs);
// result returned from sync infer wait method
auto infer_call = [&]() {
auto result = plg->Infer(_inputs, _outputs);
// if result is false we are dealing with QoS feature and set kRequestIndexInvalid
// if result is ok we set kRequestIndexCompleted to not execute request if it is not
// in the queue.
auto result_request_index = result ? kRequestIndexCompleted : kRequestIndexInvalid;
SetRequestIndex(result_request_index);
};
CallCleanupAndRethrowOnException(std::move(infer_call));
}
// Forward per-layer profiling information gathered by the plugin.
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GNAInferRequest::GetPerformanceCounts() const {
    return plg->GetPerformanceCounts();
}
// Asynchronous start: pre-process inputs and enqueue the request. If the
// enqueue throws, the request index is reset to invalid (see
// CallCleanupAndRethrowOnException) and the exception is rethrown. When a
// completion callback is installed, the request is waited on immediately and
// the callback receives a typed exception_ptr for any non-OK status.
void GNAInferRequest::StartAsyncImpl() {
    // execute input pre-processing.
    execDataPreprocessing(_inputs);
    auto queue_call = [&]() {
        SetRequestIndex(plg->QueueInference(_inputs, _outputs));
    };
    CallCleanupAndRethrowOnException(std::move(queue_call));
    // workaround to unblock callback-based flows
    if (_callback) {
        auto res = Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY);
        std::exception_ptr exceptionPtr;
        if (res != InferenceEngine::StatusCode::OK) {
            try {
                // convert the status code into the matching IE exception type
                // so the callback receives a properly typed exception_ptr
                IE_EXCEPTION_SWITCH(res,
                                    ExceptionType,
                                    InferenceEngine::details::ThrowNow<ExceptionType>{} <<=
                                    std::stringstream{}
                                    << IE_LOCATION
                                    << InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
            } catch (...) {
                exceptionPtr = std::current_exception();
            }
        }
        _callback(exceptionPtr);
    }
}
// Wait for the enqueued request to finish.
// Returns INFER_NOT_STARTED when no valid request exists (never started,
// failed on enqueue, or aborted earlier), OK when already completed by a
// synchronous Infer(), otherwise the status mapped by HandleRequestWaitStatus.
InferenceEngine::StatusCode GNAInferRequest::Wait(int64_t millis_timeout) {
    if (!IsRequestIndexValid()) {
        return InferenceEngine::INFER_NOT_STARTED;
    }
    // maps RESULT_READY to MAX_TIMEOUT and rejects other negative timeouts
    ValidateAndConfigureTimeout(millis_timeout);
    if (IsRequestCompleted()) {
        return InferenceEngine::OK;
    }
    auto waitStatus = RequestStatus::kAborted;
    auto wait_call = [&]() {
        waitStatus = plg->WaitFor(_infer_request_idx, millis_timeout);
    };
    // on exception the request index is invalidated before rethrowing, so the
    // next Wait() returns INFER_NOT_STARTED instead of failing again
    CallCleanupAndRethrowOnException(std::move(wait_call));
    return HandleRequestWaitStatus(waitStatus);
}
// Return the plugin's variable (memory) states.
// Fix: the previous implementation built a local copy of the states and then
// queried the plugin a second time, discarding the copy — dead code and a
// redundant call. Query the plugin exactly once.
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> GNAInferRequest::QueryState() {
    return plg->QueryState();
}
// True when a request was successfully enqueued or completed; false when the
// index was reset (never started, enqueue failure, abort, or wait error).
bool GNAInferRequest::IsRequestIndexValid() {
    return _infer_request_idx != kRequestIndexInvalid;
}
// True when the request already finished via synchronous Infer(), so Wait()
// has nothing to poll for.
bool GNAInferRequest::IsRequestCompleted() {
    return _infer_request_idx == kRequestIndexCompleted;
}
// Store the new request index.
// Fix: the previous body was `return _infer_request_idx = request_index;`,
// an implicit uint32_t->bool conversion of an assignment expression that
// compilers warn about and readers misparse as `==`. The explicit form below
// preserves the exact original result: true for any non-zero index.
// NOTE(review): callers in this file ignore the return value; consider
// changing the return type to void.
bool GNAInferRequest::SetRequestIndex(uint32_t request_index) {
    _infer_request_idx = request_index;
    return request_index != 0;
}
// Normalize and validate the wait timeout (in-out parameter).
// RESULT_READY is translated to MAX_TIMEOUT first; any remaining negative
// value is rejected. The order of the two checks is significant because
// RESULT_READY itself is a negative sentinel.
void GNAInferRequest::ValidateAndConfigureTimeout(int64_t& millis_timeout) {
    if (millis_timeout == InferenceEngine::InferRequest::WaitMode::RESULT_READY) {
        millis_timeout = MAX_TIMEOUT;
    }
    if (millis_timeout < 0) {
        IE_THROW(ParameterMismatch) << "Invalid timeout value in milliseconds: " << millis_timeout << "!";
    }
}
// Translate a plugin RequestStatus into an InferenceEngine status code.
// kPending -> RESULT_NOT_READY; kAborted -> INFER_NOT_STARTED (index reset);
// kCompletedWithError -> throws (index reset); anything else -> OK.
InferenceEngine::StatusCode GNAInferRequest::HandleRequestWaitStatus(const RequestStatus& request_status) {
    switch (request_status) {
    case RequestStatus::kPending:
        // request is still pending so Wait() is needed once again
        return InferenceEngine::RESULT_NOT_READY;
    case RequestStatus::kAborted:
        // need to preserve invalid state here to avoid next Wait() from clearing it
        SetRequestIndex(kRequestIndexInvalid);
        return InferenceEngine::INFER_NOT_STARTED;
    case RequestStatus::kCompletedWithError:
        SetRequestIndex(kRequestIndexInvalid);
        THROW_GNA_EXCEPTION << "Error when waiting for inference results!";
    default:
        return InferenceEngine::OK;
    }
}
// Invoke function_to_invoke; if it throws, reset the request index to invalid
// so a subsequent Wait() reports INFER_NOT_STARTED instead of re-failing,
// then rethrow the original exception to the caller.
void GNAInferRequest::CallCleanupAndRethrowOnException(std::function<void()>&& function_to_invoke) {
    try {
        function_to_invoke();
    } catch (...) {
        // need to preserve invalid state here to avoid next Wait() from clearing it
        // and next rethrow issue.
        SetRequestIndex(kRequestIndexInvalid);
        throw;
    }
}
// Allocate the request's input and output blobs from the plugin.
// Throws when the network exposes no outputs.
// Fix: iterate the data maps by const reference — `for (auto output : ...)`
// copied each std::pair<const std::string, DataPtr> on every iteration
// (clang-tidy: performance-for-range-copy).
void GNAInferRequest::CreateInferRequest() {
    // TODO: internal connection API - better to generalize
    if (_networkOutputs.empty()) {
        THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
    }
    // copy inputs blobs since we need to have them in separate address space to allow simultaneous infer requests
    for (const auto& output : _networkOutputs) {
        _outputs[output.first] = plg->GetOutputBlob(output.first, output.second->getTensorDesc().getPrecision());
    }
    for (const auto& input : _networkInputs) {
        _inputs[input.first] = plg->GetInputBlob(input.first, input.second->getTensorDesc().getPrecision());
    }
}
} // namespace GNAPluginNS

View File

@ -5,137 +5,59 @@
#pragma once
#include <memory>
#include <string>
#include <map>
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
#include "cpp/ie_infer_request.hpp"
#include "gna_plugin.hpp"
#include "request_status.hpp"
namespace GNAPluginNS {
class GNAPlugin;
class GNAInferRequest : public InferenceEngine::IInferRequestInternal {
private:
void CreateInferRequest() {
// TODO: internal connection API - better to generalize
if (_networkOutputs.empty()) {
THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
}
// copy inputs blobs since we need to have them in separate address space to allow simultaneous infer requests
for (auto output : _networkOutputs) {
_outputs[output.first] =
plg->GetOutputBlob(output.first, output.second->getTensorDesc().getPrecision());
}
for (auto input : _networkInputs) {
_inputs[input.first] =
plg->GetInputBlob(input.first, input.second->getTensorDesc().getPrecision());
}
}
protected:
std::shared_ptr<GNAPlugin> plg;
uint32_t inferRequestIdx = -1;
public:
public:
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
const std::vector<std::shared_ptr<const ov::Node>>& inputs,
const std::vector<std::shared_ptr<const ov::Node>>& outputs)
: InferenceEngine::IInferRequestInternal(inputs, outputs), plg(plg) {
CreateInferRequest();
}
const std::vector<std::shared_ptr<const ov::Node>>& outputs);
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs)
: InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
CreateInferRequest();
}
InferenceEngine::InputsDataMap network_inputs,
InferenceEngine::OutputsDataMap network_outputs);
/**
* @brief Infers specified input(s) in synchronous mode
* @note blocks all method of InferRequest while request is ongoing (running or waiting in queue)
*/
void InferImpl() override {
// execute input pre-processing.
execDataPreprocessing(_inputs);
// result returned from sync infer wait method
auto result = plg->Infer(_inputs, _outputs);
// if result is false we are dealing with QoS feature
// if result is ok, next call to wait() will return Ok, if request not in gna_queue
if (!result) {
inferRequestIdx = -1;
} else {
inferRequestIdx = -2;
}
}
void InferImpl() override;
/**
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
* Note: not all plugins may provide meaningful data
* @param perfMap - a map of layer names to profiling information for that layer.
*/
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override {
return plg->GetPerformanceCounts();
}
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
/**
* @brief methods with _ThreadUnsafe prefix are to implement in plugins
* or in default wrapper (e.g. AsyncInferRequestThreadSafeDefault)
*/
void StartAsyncImpl() override {
// execute input pre-processing.
execDataPreprocessing(_inputs);
inferRequestIdx = plg->QueueInference(_inputs, _outputs);
// workaround to unblock callback-based flows
if (_callback) {
auto res = Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY);
std::exception_ptr exceptionPtr;
if (res != InferenceEngine::StatusCode::OK) {
try {
IE_EXCEPTION_SWITCH(res, ExceptionType,
InferenceEngine::details::ThrowNow<ExceptionType>{}
<<= std::stringstream{} << IE_LOCATION
<< InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
} catch (...) {
exceptionPtr = std::current_exception();
}
}
_callback(exceptionPtr);
}
}
void StartAsyncImpl() override;
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override {
if (inferRequestIdx == -1) {
return InferenceEngine::INFER_NOT_STARTED;
} else if (millis_timeout < -1) {
IE_THROW(ParameterMismatch);
}
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override;
if (millis_timeout == InferenceEngine::InferRequest::WaitMode::RESULT_READY) {
millis_timeout = MAX_TIMEOUT;
}
const auto waitStatus = plg->WaitFor(inferRequestIdx, millis_timeout);
protected:
bool SetRequestIndex(uint32_t request_index);
bool IsRequestIndexValid();
bool IsRequestCompleted();
if (waitStatus == RequestStatus::kPending) {
// request is still pending so Wait() is needed once again
return InferenceEngine::RESULT_NOT_READY;
}
if (waitStatus == RequestStatus::kAborted) {
// need to preserve invalid state here to avoid next Wait() from clearing it
inferRequestIdx = -1;
return InferenceEngine::INFER_NOT_STARTED;
}
return InferenceEngine::OK;
}
private:
void CreateInferRequest();
InferenceEngine::StatusCode HandleRequestWaitStatus(const RequestStatus& request_status);
void ValidateAndConfigureTimeout(int64_t& millis_timeout);
void CallCleanupAndRethrowOnException(std::function<void()>&& function_to_invoke);
IE_SUPPRESS_DEPRECATED_START
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
auto pluginStates = plg->QueryState();
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
return plg->QueryState();
}
IE_SUPPRESS_DEPRECATED_END
static constexpr const uint32_t kRequestIndexInvalid = std::numeric_limits<uint32_t>::max();
static constexpr const uint32_t kRequestIndexCompleted = std::numeric_limits<uint32_t>::max() - 1;
uint32_t _infer_request_idx = kRequestIndexInvalid;
std::shared_ptr<GNAPlugin> plg;
};
} // namespace GNAPluginNS

View File

@ -1335,7 +1335,9 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
++inputNum;
}
freeWorker->enqueueRequest();
if (!freeWorker->enqueueRequest()) {
THROW_GNA_EXCEPTION << "Error with enqueueing inference request";
}
freeWorker->setResult(result);
@ -1351,7 +1353,13 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
}
bool GNAPlugin::Wait(uint32_t request_idx) {
return RequestStatus::kCompleted == WaitFor(request_idx, MAX_TIMEOUT);
auto result = WaitFor(request_idx, MAX_TIMEOUT);
if (result == RequestStatus::kCompletedWithError) {
THROW_GNA_EXCEPTION << "Error when waiting for inference results!";
}
return result == RequestStatus::kCompleted;
}
RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
@ -1368,6 +1376,10 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
const auto waitStatus = worker.wait(millisTimeout);
if (waitStatus == RequestStatus::kCompletedWithError) {
return waitStatus;
}
if (waitStatus == RequestStatus::kAborted) {
return waitStatus;
}
@ -1528,7 +1540,7 @@ bool GNAPlugin::Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob
}
bool GNAPlugin::Infer(const InferenceEngine::BlobMap &input, InferenceEngine::BlobMap &result) {
return Wait(QueueInference(input, result));
return Wait(QueueInference(input, result));
}
static InferenceEngine::Layout GetLayoutForDims(const InferenceEngine::SizeVector &dims) {

View File

@ -28,6 +28,7 @@ public:
* @param requestID id of request to be used for wait
* @param timeoutMilliseconds timeout of wait in milliseconds
* @return Status of subrequest @see GNAPluginNS::RequestStatus
*
*/
using WaitHandler = std::function<RequestStatus(uint32_t requestID, int64_t timeoutMilliseconds)>;
@ -42,8 +43,14 @@ public:
/**
* @brief Add subrequest to execution queue.
* @return true in case subrequest was properly enqueued, otherwise return false
*/
virtual void enqueue() = 0;
virtual bool enqueue() = 0;
/**
* @brief Finalize subrequest and set it status to RequestStatus::kNone
*/
virtual void cleanup() = 0;
/**
* @brief Return true if subrequest is pending, otherwise return false

View File

@ -7,6 +7,7 @@
#include <gna2-inference-api.h>
#include "log/debug.hpp"
#include "log/log.hpp"
namespace GNAPluginNS {
namespace request {
@ -24,14 +25,30 @@ RequestStatus SubrequestImpl::wait(int64_t timeoutMilliseconds) {
return status_;
}
status_ = waitHandler_(requestID_, timeoutMilliseconds);
try {
status_ = waitHandler_(requestID_, timeoutMilliseconds);
} catch (const std::exception& e) {
ov::intel_gna::log::error() << "Exception when executiong wait: " << e.what() << std::endl;
status_ = RequestStatus::kCompletedWithError;
}
return status_;
}
void SubrequestImpl::enqueue() {
requestID_ = enqueueHandler_();
status_ = RequestStatus::kPending;
bool SubrequestImpl::enqueue() {
try {
requestID_ = enqueueHandler_();
status_ = RequestStatus::kPending;
} catch (const std::exception& e) {
ov::intel_gna::log::error() << "Exception when executiong enqueue: " << e.what() << std::endl;
status_ = RequestStatus::kCompletedWithError;
}
return status_ != RequestStatus::kCompletedWithError;
}
void SubrequestImpl::cleanup() {
static_cast<void>(wait(0));
status_ = RequestStatus::kNone;
}
bool SubrequestImpl::isPending() const {

View File

@ -40,8 +40,14 @@ public:
/**
* @brief Add subrequest to execution queue.
* @return true in case subrequest was properly enqueued, otherwise return false
*/
void enqueue() override;
bool enqueue() override;
/**
* @brief Finalize subrequest and set it status to RequestStatus::kNone
*/
void cleanup() override;
/**
* @brief Return true if subrequest is pending, otherwise return false

View File

@ -39,15 +39,14 @@ public:
/**
* @brief Enqueue request to requests queue for contained model.
* @throw Exception in case worker is busy or if there was an issue with enqueue.
* @return true in case subrequest was properly enqueued, otherwise return false
*/
virtual void enqueueRequest() = 0;
virtual bool enqueueRequest() = 0;
/**
* @brief Wait until the request is finished or the given timeout elapses.
* @param timeoutMilliseconds timeout in milliseconds
* @return status of execution of ongoing request. @see GNAPluginNS::RequestStatus
* @throw Exception in case worker is busy or if there was an issue with enqueue.
*/
virtual RequestStatus wait(int64_t timeoutMilliseconds) = 0;

View File

@ -7,6 +7,7 @@
#include <gna2-inference-api.h>
#include "log/debug.hpp"
#include "log/log.hpp"
#include "model_wrapper.hpp"
#include "subrequest.hpp"
@ -39,12 +40,19 @@ Gna2Model* WorkerImpl::model() {
return &fullModel_->object();
}
void WorkerImpl::enqueueRequest() {
check_if_free();
bool WorkerImpl::enqueueRequest() {
if (!isFree()) {
ov::intel_gna::log::warning() << "Trying to propagate on busy request with id: " << representingIndex_;
return false;
}
for (auto& subrequest : modelSubrequests_) {
subrequest->enqueue();
if (!subrequest->enqueue()) {
cleanup_subrequests();
return false;
}
}
return true;
}
RequestStatus WorkerImpl::wait(int64_t timeoutMilliseconds) {
@ -56,8 +64,13 @@ RequestStatus WorkerImpl::wait(int64_t timeoutMilliseconds) {
continue;
}
if (subrequest->wait(timeoutMilliseconds) == RequestStatus::kPending) {
auto result = subrequest->wait(timeoutMilliseconds);
if (result == RequestStatus::kPending) {
pending = true;
} else if (result == RequestStatus::kCompletedWithError) {
cleanup_subrequests();
return result;
}
}
@ -107,9 +120,11 @@ InferenceEngine::BlobMap& WorkerImpl::result() {
return requestResult_;
}
void WorkerImpl::check_if_free() {
if (!isFree()) {
THROW_GNA_EXCEPTION << "Trying to propagte on busy request with id: " << representingIndex_;
void WorkerImpl::cleanup_subrequests() {
for (auto& subrequest : modelSubrequests_) {
if (subrequest->isPending()) {
subrequest->cleanup();
}
}
}

View File

@ -55,7 +55,7 @@ public:
/**
* @see Worker::enqueueRequest()
*/
void enqueueRequest() override;
bool enqueueRequest() override;
/**
* @see Worker::wait()
@ -93,7 +93,7 @@ public:
void setResult(InferenceEngine::BlobMap&& result) override;
private:
void check_if_free();
void cleanup_subrequests();
uint32_t representingIndex_{0};
std::shared_ptr<ModelWrapper> fullModel_;

View File

@ -10,10 +10,11 @@ namespace GNAPluginNS {
* @brief Enum representing status of request
*/
enum class RequestStatus {
kNone = 0, /// request was not initialized
kAborted = 1, /// request was aborted
kPending = 2, /// request was started and is onging
kCompleted = 3 /// request was completed with success
kNone = 0, /// request was not initialized
kAborted = 1, /// request was aborted
kPending = 2, /// request was started and is onging
kCompleted = 3, /// request was completed with success
kCompletedWithError = 4 /// request was completed with error
};
} // namespace GNAPluginNS

View File

@ -14,7 +14,7 @@ endif()
# TODO: fix CVS-71010 and remove BUILD_SHARED_LIBS
if(NOT BUILD_SHARED_LIBS)
set(exclude_path EXCLUDED_SOURCE_PATHS "${CMAKE_CURRENT_SOURCE_DIR}/(gna_api_stub|gna_wait_test|gna_export_import_test).cpp")
set(exclude_path EXCLUDED_SOURCE_PATHS "${CMAKE_CURRENT_SOURCE_DIR}/(gna_api_stub|gna_wait_test|gna_export_import_test|gna_infer_request_test).cpp")
endif()
addIeTargetTest(

View File

@ -138,6 +138,9 @@ GNA2_API enum Gna2Status Gna2RequestConfigSetAccelerationMode(
GNA2_API enum Gna2Status Gna2RequestEnqueue(
uint32_t requestConfigId,
uint32_t * requestId) {
if (current != nullptr) {
return current->Gna2RequestEnqueue(requestConfigId, requestId);
}
return Gna2StatusSuccess;
}

View File

@ -0,0 +1,204 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gna_infer_request.hpp"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <vector>
#include "any_copy.hpp"
#include "common_test_utils/data_utils.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gna_mock_api.hpp"
#include "gna_plugin.hpp"
#include "ngraph_functions/builders.hpp"
using namespace ::testing;
using namespace InferenceEngine;
using GNAPluginNS::GNAInferRequest;
using GNAPluginNS::GNAPlugin;
using ::testing::InSequence;
// Fixture that builds a tiny 1x10 "parameter + constant" Add network, loads it
// on the GNA plugin backed by a strict GNA2 C-API mock, and creates a
// GNAInferRequest whose enqueue/wait behavior the tests drive via the mock.
class GNAInferRequestTest : public ::testing::Test {
public:
    // Build the network, set mock expectations for load/shutdown, load it on a
    // fresh plugin instance (SW_EXACT mode) and wrap it in a GNAInferRequest.
    IInferRequestInternal::Ptr CreateRequest() {
        auto function = GetFunction();
        CNNNetwork cnn_network = CNNNetwork{function};
        SetExpectsForLoadNetworkAndShutDown(_data);
        const ov::AnyMap gna_config = {ov::intel_gna::execution_mode(ov::intel_gna::ExecutionMode::SW_EXACT)};
        auto plugin = std::make_shared<GNAPlugin>(any_copy(gna_config));
        plugin->LoadNetwork(cnn_network);
        return std::make_shared<GNAInferRequest>(plugin, cnn_network.getInputsInfo(), cnn_network.getOutputsInfo());
    }

protected:
    // Minimal f32 function: Add(parameter[1,10], random constant[1,10]).
    std::shared_ptr<ngraph::Function> GetFunction() {
        auto ngPrc = ngraph::element::f32;
        std::vector<size_t> shape = {1, 10};
        auto params = ngraph::builder::makeParams(ngPrc, {shape});
        auto shape_size = ov::shape_size(shape);
        auto add_const =
            ngraph::builder::makeConstant<float>(ngPrc,
                                                 shape,
                                                 CommonTestUtils::generate_float_numbers(shape_size, -0.5f, 0.5f),
                                                 false);
        auto add = std::make_shared<ngraph::opset9::Add>(params[0], add_const);
        auto res = std::make_shared<ngraph::op::Result>(add);
        auto function = std::make_shared<ngraph::Function>(res, params, "Add");
        return function;
    }

    // Mock expectations covering plugin LoadNetwork and destruction.
    // `data` backs Gna2MemoryAlloc so returned pointers stay valid for the
    // fixture's lifetime.
    void SetExpectsForLoadNetworkAndShutDown(std::vector<std::vector<uint8_t>>& data) {
        EXPECT_CALL(*_mock_api, Gna2MemoryAlloc(_, _, _))
            .Times(AtLeast(1))
            .WillRepeatedly(Invoke([&data](uint32_t size_requested, uint32_t* size_granted, void** memory_address) {
                data.push_back(std::vector<uint8_t>(size_requested));
                *size_granted = size_requested;
                *memory_address = data.back().data();
                return Gna2StatusSuccess;
            }));
        EXPECT_CALL(*_mock_api, Gna2DeviceGetVersion(_, _))
            .WillOnce(Invoke([](uint32_t deviceIndex, enum Gna2DeviceVersion* deviceVersion) {
                *deviceVersion = Gna2DeviceVersionSoftwareEmulation;
                return Gna2StatusSuccess;
            }));
        EXPECT_CALL(*_mock_api, Gna2DeviceOpen(_)).WillOnce(Return(Gna2StatusSuccess));
        EXPECT_CALL(*_mock_api, Gna2GetLibraryVersion(_, _))
            .Times(AtLeast(0))
            .WillRepeatedly(Return(Gna2StatusSuccess));
        EXPECT_CALL(*_mock_api, Gna2InstrumentationConfigCreate(_, _, _, _)).WillOnce(Return(Gna2StatusSuccess));
        EXPECT_CALL(*_mock_api, Gna2ModelCreate(_, _, _))
            .WillOnce(Invoke([](uint32_t deviceIndex, struct Gna2Model const* model, uint32_t* model_id) {
                *model_id = 0;
                return Gna2StatusSuccess;
            }));
        EXPECT_CALL(*_mock_api, Gna2RequestConfigCreate(_, _))
            .WillOnce(Invoke([](uint32_t model_Id, uint32_t* request_config_id) {
                *request_config_id = 0;
                return Gna2StatusSuccess;
            }));
        EXPECT_CALL(*_mock_api, Gna2InstrumentationConfigAssignToRequestConfig(_, _))
            .Times(AtLeast(1))
            .WillRepeatedly(Return(Gna2StatusSuccess));
        // shutdown must close the device before freeing its memory
        InSequence seq;
        EXPECT_CALL(*_mock_api, Gna2DeviceClose(_)).WillOnce(Return(Gna2StatusSuccess));
        EXPECT_CALL(*_mock_api, Gna2MemoryFree(_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
    }

    // Expect exactly one enqueue returning the given status.
    void SetExpectOnEnqueue(const Gna2Status& return_status = Gna2StatusSuccess) {
        EXPECT_CALL(*_mock_api, Gna2RequestEnqueue(_, _)).WillOnce(Return(return_status));
    }

    // Expect exactly one wait returning the given status.
    void SetExpectOnWait(const Gna2Status& return_status = Gna2StatusSuccess) {
        EXPECT_CALL(*_mock_api, Gna2RequestWait(_, _)).WillOnce(Return(return_status));
    }

    void SetUp() override {
        _mock_api = std::make_shared<StrictMock<GNACppApi>>();
    }

    void TearDown() override {
        ASSERT_TRUE(Mock::VerifyAndClearExpectations(_mock_api.get()));
    }

    std::shared_ptr<StrictMock<GNACppApi>> _mock_api;
    // backing storage for mocked Gna2MemoryAlloc allocations
    std::vector<std::vector<uint8_t>> _data;
};
// StartAsync with a successful enqueue must not throw.
TEST_F(GNAInferRequestTest, start_async) {
    auto request = CreateRequest();
    SetExpectOnEnqueue();
    // wait on shutdown needed if request was enqueued but not waited on
    SetExpectOnWait();
    EXPECT_NO_THROW(request->StartAsync());
}
// StartAsync must throw when the underlying enqueue fails.
TEST_F(GNAInferRequestTest, start_async_with_enqueue_error) {
    auto request = CreateRequest();
    // trigger Gna2RequestEnqueue to fail
    SetExpectOnEnqueue(Gna2StatusUnknownError);
    // no wait needed because no request was successfully enqueued
    EXPECT_THROW(request->StartAsync(), std::exception);
}
// StartAsync followed by Wait succeeds when both enqueue and wait succeed.
TEST_F(GNAInferRequestTest, start_async_with_wait) {
    auto request = CreateRequest();
    SetExpectOnEnqueue();
    // wait on shutdown needed if request was enqueued but not waited on
    SetExpectOnWait();
    EXPECT_NO_THROW(request->StartAsync());
    EXPECT_EQ(OK, request->Wait(0));
}
// After a failed enqueue, Wait must report INFER_NOT_STARTED.
TEST_F(GNAInferRequestTest, start_async_error_with_wait) {
    auto request = CreateRequest();
    SetExpectOnEnqueue(Gna2StatusUnknownError);
    // no wait expected on shutdown: the enqueue failed, so nothing is pending
    // SetExpectOnWait();
    EXPECT_THROW(request->StartAsync(), std::exception);
    EXPECT_EQ(INFER_NOT_STARTED, request->Wait(0));
}
// A wait error after a successful enqueue surfaces as an exception from Wait.
TEST_F(GNAInferRequestTest, start_async_with_wait_error) {
    auto request = CreateRequest();
    SetExpectOnEnqueue();
    // wait on shutdown needed if request was enqueued but not waited on
    SetExpectOnWait(Gna2StatusUnknownError);
    EXPECT_NO_THROW(request->StartAsync());
    EXPECT_THROW(request->Wait(0), std::exception);
}
// A wait error must not poison the request object: subsequent Wait returns
// INFER_NOT_STARTED and a fresh StartAsync/Wait cycle works normally.
TEST_F(GNAInferRequestTest, start_async_wait_check_recovery_after_wait_error) {
    auto request = CreateRequest();
    SetExpectOnEnqueue();
    SetExpectOnWait(Gna2StatusUnknownError);
    EXPECT_NO_THROW(request->StartAsync());
    EXPECT_THROW(request->Wait(0), std::exception);
    // check that there is no exception on a second wait after the first failed
    EXPECT_EQ(INFER_NOT_STARTED, request->Wait(0));
    // start new request
    SetExpectOnEnqueue();
    SetExpectOnWait();
    EXPECT_NO_THROW(request->StartAsync());
    EXPECT_EQ(OK, request->Wait(0));
}
// Synchronous Infer succeeds when enqueue and wait both succeed.
TEST_F(GNAInferRequestTest, infer) {
    auto request = CreateRequest();
    SetExpectOnEnqueue();
    SetExpectOnWait();
    EXPECT_NO_THROW(request->Infer());
}
// Synchronous Infer must throw when the enqueue fails.
TEST_F(GNAInferRequestTest, infer_enque_error) {
    auto request = CreateRequest();
    SetExpectOnEnqueue(Gna2StatusUnknownError);
    EXPECT_THROW(request->Infer(), std::exception);
}
// A wait error during Infer must not prevent a later Infer from succeeding.
TEST_F(GNAInferRequestTest, infer_wait_error_check_recovery) {
    auto request = CreateRequest();
    SetExpectOnEnqueue();
    SetExpectOnWait(Gna2StatusUnknownError);
    EXPECT_THROW(request->Infer(), std::exception);
    // check if the next infer executes properly after the wait that threw
    SetExpectOnEnqueue();
    SetExpectOnWait();
    EXPECT_NO_THROW(request->Infer());
}

View File

@ -47,6 +47,10 @@ public:
uint32_t requestId,
uint32_t timeoutMilliseconds));
MOCK_METHOD2(Gna2RequestEnqueue, Gna2Status(
uint32_t requestConfigId,
uint32_t* requestId));
MOCK_METHOD2(Gna2DeviceGetVersion, Gna2Status(
uint32_t deviceIndex,
enum Gna2DeviceVersion * deviceVersion));

View File

@ -10,14 +10,15 @@
#define IMPLEMENT_INFERENCE_ENGINE_PLUGIN
#include "gna_infer_request.hpp"
#include "gna_mock_api.hpp"
#include "gna_plugin.hpp"
#include "request/model_wrapper_factory.hpp"
#include "request/subrequest_impl.hpp"
#include "request/worker_factory.hpp"
#include "request/worker_impl.hpp"
#include "request/worker_pool.hpp"
using GNAPluginNS::GNAInferRequest;
using GNAPluginNS::GNAPlugin;
using namespace GNAPluginNS;
using namespace GNAPluginNS::request;
using ::testing::_;
using ::testing::Return;
@ -27,9 +28,6 @@ class GNAPluginForGNAWaitTest : public GNAPlugin {
public:
// Prepare underlining object to enable GNAInferRequest::Wait() working
GNAPluginForGNAWaitTest() {
using namespace GNAPluginNS;
using namespace request;
InferenceEngine::TensorDesc td{InferenceEngine::Precision::FP32, {1, 1}, InferenceEngine::Layout::HW};
auto fakeInfo = std::make_shared<InferenceEngine::InputInfo>();
auto fakePtr = std::make_shared<InferenceEngine::Data>("fakeName", td);
@ -55,20 +53,32 @@ public:
auto model = ModelWrapperFactory::createWithNumberOfEmptyOperations(1);
subrequests.push_back(std::make_shared<SubrequestImpl>(std::move(enqueue), std::move(wait)));
auto worker = std::make_shared<WorkerImpl>(model, std::move(subrequests));
_worker = std::make_shared<WorkerImpl>(model, std::move(subrequests));
requestWorkerPool_->addModelWorker(worker);
worker->enqueueRequest();
requestWorkerPool_->addModelWorker(_worker);
}
void EnqueTestRequest() {
_worker->enqueueRequest();
}
private:
std::shared_ptr<Worker> _worker;
};
class GNAInferRequestForGNAWaitTest : public GNAInferRequest {
public:
// Prepare underlining object to enable Wait() working
GNAInferRequestForGNAWaitTest(std::shared_ptr<GNAPlugin> plugin)
: GNAInferRequest{plugin, plugin->GetNetworkInputs(), plugin->GetNetworkOutputs()} {
inferRequestIdx = 0;
GNAInferRequestForGNAWaitTest(std::shared_ptr<GNAPluginForGNAWaitTest> plugin)
: GNAInferRequest{plugin, plugin->GetNetworkInputs(), plugin->GetNetworkOutputs()},
_plugin(plugin) {}
void EnqueTestRequest() {
_plugin->EnqueTestRequest();
SetRequestIndex(0);
}
std::shared_ptr<GNAPluginForGNAWaitTest> _plugin;
};
TEST_F(GNAWaitTest, ReturnsGna2StatusDriverQoSTimeoutExceeded) {
@ -76,6 +86,7 @@ TEST_F(GNAWaitTest, ReturnsGna2StatusDriverQoSTimeoutExceeded) {
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDriverQoSTimeoutExceeded));
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
GNAInferRequestForGNAWaitTest inferRequest{plugin};
inferRequest.EnqueTestRequest();
ASSERT_EQ(InferenceEngine::INFER_NOT_STARTED, inferRequest.Wait(0));
}
@ -84,6 +95,38 @@ TEST_F(GNAWaitTest, ReturnsGna2StatusWarningDeviceBusy) {
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusWarningDeviceBusy));
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
GNAInferRequestForGNAWaitTest inferRequest{plugin};
inferRequest.EnqueTestRequest();
ASSERT_EQ(InferenceEngine::RESULT_NOT_READY, inferRequest.Wait(0));
}
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange) {
GNACppApi enableMocks;
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
GNAInferRequestForGNAWaitTest inferRequest{plugin};
inferRequest.EnqueTestRequest();
ASSERT_THROW(inferRequest.Wait(0), std::exception);
}
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange_Extra_Sync) {
GNACppApi enableMocks;
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
GNAInferRequestForGNAWaitTest inferRequest{plugin};
inferRequest.EnqueTestRequest();
ASSERT_THROW(inferRequest.Wait(0), std::exception);
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(0);
ASSERT_EQ(InferenceEngine::INFER_NOT_STARTED, inferRequest.Wait(0));
}
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange_Another_Use) {
GNACppApi enableMocks;
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
GNAInferRequestForGNAWaitTest inferRequest{plugin};
inferRequest.EnqueTestRequest();
ASSERT_THROW(inferRequest.Wait(0), std::exception);
inferRequest.EnqueTestRequest();
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusSuccess));
ASSERT_EQ(InferenceEngine::OK, inferRequest.Wait(0));
}

View File

@ -60,16 +60,40 @@ TEST_F(GNA_Request_WorkerImplTest, enqueueRequest) {
subrequests.push_back(subrequestMock2);
WorkerImpl worker(wrapper, subrequests);
// check if exception will be thrown if worker will be busy - at least one subrequest is pending.
// check if false will be returned if worker will be busy - at least one subrequest is pending.
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(true));
EXPECT_THROW(worker.enqueueRequest(), std::exception);
EXPECT_FALSE(worker.enqueueRequest());
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1);
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1);
EXPECT_NO_THROW(worker.enqueueRequest());
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1).WillOnce(Return(true));
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1).WillOnce(Return(true));
EXPECT_TRUE(worker.enqueueRequest());
}
TEST_F(GNA_Request_WorkerImplTest, enqueueRequest_second_enque_failed) {
auto wrapper = ModelWrapperFactory::createTrivial();
std::vector<std::shared_ptr<Subrequest>> subrequests;
auto subrequestMock1 = std::make_shared<MockSubrequest>();
subrequests.push_back(subrequestMock1);
auto subrequestMock2 = std::make_shared<MockSubrequest>();
subrequests.push_back(subrequestMock2);
WorkerImpl worker(wrapper, subrequests);
// check if exception will be thrown if worker will be busy - at least one subrequest is pending.
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(true));
EXPECT_FALSE(worker.enqueueRequest());
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(2).WillOnce(Return(false)).WillOnce(Return(true));
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(2).WillOnce(Return(false)).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1).WillOnce(Return(true));
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock1.get(), cleanup()).Times(1);
EXPECT_CALL(*subrequestMock2.get(), cleanup()).Times(0);
EXPECT_FALSE(worker.enqueueRequest());
}
TEST_F(GNA_Request_WorkerImplTest, wait) {
@ -107,6 +131,13 @@ TEST_F(GNA_Request_WorkerImplTest, wait) {
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock1.get(), isAborted()).Times(1).WillOnce(Return(true));
EXPECT_EQ(RequestStatus::kAborted, worker.wait(referenceTimeout));
// subrequest enuqued and completed with error
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(2).WillOnce(Return(true)).WillOnce(Return(false));
EXPECT_CALL(*subrequestMock1.get(), wait(referenceTimeout))
.Times(1)
.WillOnce(Return(RequestStatus::kCompletedWithError));
EXPECT_EQ(RequestStatus::kCompletedWithError, worker.wait(referenceTimeout));
}
TEST_F(GNA_Request_WorkerImplTest, isFree) {

View File

@ -13,7 +13,8 @@ namespace request {
class MockSubrequest : public Subrequest {
public:
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
MOCK_METHOD(void, enqueue, (), (override));
MOCK_METHOD(bool, enqueue, (), (override));
MOCK_METHOD(void, cleanup, (), (override));
MOCK_METHOD(bool, isPending, (), (const, override));
MOCK_METHOD(bool, isAborted, (), (const, override));
MOCK_METHOD(bool, isCompleted, (), (const, override));

View File

@ -14,7 +14,7 @@ class MockWorker : public Worker {
public:
MOCK_METHOD(Gna2Model*, model, (), (override));
MOCK_METHOD(const Gna2Model*, model, (), (const, override));
MOCK_METHOD(void, enqueueRequest, (), (override));
MOCK_METHOD(bool, enqueueRequest, (), (override));
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
MOCK_METHOD(bool, isFree, (), (const, override));
MOCK_METHOD(uint32_t, representingIndex, (), (const, override));

View File

@ -36,7 +36,6 @@ public:
if (!legacy_str.empty()) {
name += "," + legacy_str;
}
std::cout << "name: " << name << std::endl;
return name;
}
void SetUp() override {