[GNA] Fix exception handling on wait and infer (#14578)
* [GNA] Move definition of GNAInferRequest class to source file * [GNA] Fix handling exception for infer() and wait() gna_infer_request * fixed handling of exceptions for wait and infer gna_infer_request * fixed closing ongoing subrequest of divided model in case of exception on on enqueueing request or waiting for end of request. * [GNA] Apply review comments, Removed exceptions from enqueue and wait for Worker * changed API of request worker to return: * erorr in case wait failed instead of throw, * return true/false for enqueue instead of exception for failure case * [GNA] Fix review commentd related to gna_infer_request and worker_impl * [GNA] Add tests for GANInferRequest * added tests for exception handling of Infer, Wait, StartAsync methods of GNAInferRequest class. * [GNA] Added final fixes for review comments.
This commit is contained in:
parent
bc69385093
commit
6bca87a88a
@ -475,13 +475,16 @@ GNAPluginNS::RequestStatus GNADeviceHelper::waitForRequest(uint32_t requestID, i
|
|||||||
if (status == Gna2StatusDriverQoSTimeoutExceeded) {
|
if (status == Gna2StatusDriverQoSTimeoutExceeded) {
|
||||||
return GNAPluginNS::RequestStatus::kAborted;
|
return GNAPluginNS::RequestStatus::kAborted;
|
||||||
}
|
}
|
||||||
checkGna2Status(status, "Gna2RequestWait");
|
|
||||||
|
|
||||||
if (per_request_diagnostics) {
|
if (per_request_diagnostics) {
|
||||||
dumpAllAllocations(debugLogIndexRequestWait, "AfterGna2RequestWait");
|
dumpAllAllocations(debugLogIndexRequestWait, "AfterGna2RequestWait");
|
||||||
debugLogIndexRequestWait++;
|
debugLogIndexRequestWait++;
|
||||||
}
|
}
|
||||||
updateGnaPerfCounters();
|
updateGnaPerfCounters();
|
||||||
|
|
||||||
|
// handle error case after updating statistics data.
|
||||||
|
checkGna2Status(status, "Gna2RequestWait");
|
||||||
|
|
||||||
return GNAPluginNS::RequestStatus::kCompleted;
|
return GNAPluginNS::RequestStatus::kCompleted;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
173
src/plugins/intel_gna/src/gna_infer_request.cpp
Normal file
173
src/plugins/intel_gna/src/gna_infer_request.cpp
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "gna_infer_request.hpp"
|
||||||
|
|
||||||
|
#include "gna_plugin.hpp"
|
||||||
|
|
||||||
|
namespace GNAPluginNS {
|
||||||
|
|
||||||
|
GNAInferRequest::GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||||
|
const std::vector<std::shared_ptr<const ov::Node>>& inputs,
|
||||||
|
const std::vector<std::shared_ptr<const ov::Node>>& outputs)
|
||||||
|
: InferenceEngine::IInferRequestInternal(inputs, outputs),
|
||||||
|
plg(plg) {
|
||||||
|
CreateInferRequest();
|
||||||
|
}
|
||||||
|
|
||||||
|
GNAInferRequest::GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||||
|
InferenceEngine::InputsDataMap networkInputs,
|
||||||
|
InferenceEngine::OutputsDataMap networkOutputs)
|
||||||
|
: InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs),
|
||||||
|
plg(plg) {
|
||||||
|
CreateInferRequest();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GNAInferRequest::InferImpl() {
|
||||||
|
// execute input pre-processing.
|
||||||
|
execDataPreprocessing(_inputs);
|
||||||
|
// result returned from sync infer wait method
|
||||||
|
|
||||||
|
auto infer_call = [&]() {
|
||||||
|
auto result = plg->Infer(_inputs, _outputs);
|
||||||
|
// if result is false we are dealing with QoS feature and set kRequestIndexInvalid
|
||||||
|
// if result is ok we set kRequestIndexCompleted to not execute request if it is not
|
||||||
|
// in the queue.
|
||||||
|
auto result_request_index = result ? kRequestIndexCompleted : kRequestIndexInvalid;
|
||||||
|
SetRequestIndex(result_request_index);
|
||||||
|
};
|
||||||
|
|
||||||
|
CallCleanupAndRethrowOnException(std::move(infer_call));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GNAInferRequest::GetPerformanceCounts() const {
|
||||||
|
return plg->GetPerformanceCounts();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GNAInferRequest::StartAsyncImpl() {
|
||||||
|
// execute input pre-processing.
|
||||||
|
execDataPreprocessing(_inputs);
|
||||||
|
|
||||||
|
auto queue_call = [&]() {
|
||||||
|
SetRequestIndex(plg->QueueInference(_inputs, _outputs));
|
||||||
|
};
|
||||||
|
|
||||||
|
CallCleanupAndRethrowOnException(std::move(queue_call));
|
||||||
|
|
||||||
|
// workaround to unblock callback-based flows
|
||||||
|
if (_callback) {
|
||||||
|
auto res = Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY);
|
||||||
|
std::exception_ptr exceptionPtr;
|
||||||
|
if (res != InferenceEngine::StatusCode::OK) {
|
||||||
|
try {
|
||||||
|
IE_EXCEPTION_SWITCH(res,
|
||||||
|
ExceptionType,
|
||||||
|
InferenceEngine::details::ThrowNow<ExceptionType>{} <<=
|
||||||
|
std::stringstream{}
|
||||||
|
<< IE_LOCATION
|
||||||
|
<< InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
|
||||||
|
} catch (...) {
|
||||||
|
exceptionPtr = std::current_exception();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_callback(exceptionPtr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InferenceEngine::StatusCode GNAInferRequest::Wait(int64_t millis_timeout) {
|
||||||
|
if (!IsRequestIndexValid()) {
|
||||||
|
return InferenceEngine::INFER_NOT_STARTED;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValidateAndConfigureTimeout(millis_timeout);
|
||||||
|
|
||||||
|
if (IsRequestCompleted()) {
|
||||||
|
return InferenceEngine::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto waitStatus = RequestStatus::kAborted;
|
||||||
|
auto wait_call = [&]() {
|
||||||
|
waitStatus = plg->WaitFor(_infer_request_idx, millis_timeout);
|
||||||
|
};
|
||||||
|
CallCleanupAndRethrowOnException(std::move(wait_call));
|
||||||
|
|
||||||
|
return HandleRequestWaitStatus(waitStatus);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> GNAInferRequest::QueryState() {
|
||||||
|
auto pluginStates = plg->QueryState();
|
||||||
|
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
|
||||||
|
return plg->QueryState();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GNAInferRequest::IsRequestIndexValid() {
|
||||||
|
return _infer_request_idx != kRequestIndexInvalid;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GNAInferRequest::IsRequestCompleted() {
|
||||||
|
return _infer_request_idx == kRequestIndexCompleted;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GNAInferRequest::SetRequestIndex(uint32_t request_index) {
|
||||||
|
return _infer_request_idx = request_index;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GNAInferRequest::ValidateAndConfigureTimeout(int64_t& millis_timeout) {
|
||||||
|
if (millis_timeout == InferenceEngine::InferRequest::WaitMode::RESULT_READY) {
|
||||||
|
millis_timeout = MAX_TIMEOUT;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (millis_timeout < 0) {
|
||||||
|
IE_THROW(ParameterMismatch) << "Invalid timeout value in milliseconds: " << millis_timeout << "!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InferenceEngine::StatusCode GNAInferRequest::HandleRequestWaitStatus(const RequestStatus& request_status) {
|
||||||
|
if (request_status == RequestStatus::kPending) {
|
||||||
|
// request is still pending so Wait() is needed once again
|
||||||
|
return InferenceEngine::RESULT_NOT_READY;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (request_status == RequestStatus::kAborted) {
|
||||||
|
// need to preserve invalid state here to avoid next Wait() from clearing it
|
||||||
|
SetRequestIndex(kRequestIndexInvalid);
|
||||||
|
return InferenceEngine::INFER_NOT_STARTED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (request_status == RequestStatus::kCompletedWithError) {
|
||||||
|
SetRequestIndex(kRequestIndexInvalid);
|
||||||
|
THROW_GNA_EXCEPTION << "Error when waiting for inference results!";
|
||||||
|
}
|
||||||
|
|
||||||
|
return InferenceEngine::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GNAInferRequest::CallCleanupAndRethrowOnException(std::function<void()>&& function_to_invoke) {
|
||||||
|
try {
|
||||||
|
function_to_invoke();
|
||||||
|
} catch (...) {
|
||||||
|
// need to preserve invalid state here to avoid next Wait() from clearing it
|
||||||
|
// and next rethrow issue.
|
||||||
|
SetRequestIndex(kRequestIndexInvalid);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GNAInferRequest::CreateInferRequest() {
|
||||||
|
// TODO: internal connection API - better to generalize
|
||||||
|
if (_networkOutputs.empty()) {
|
||||||
|
THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy inputs blobs since we need to have them in separate address space to allow simultaneous infer requests
|
||||||
|
for (auto output : _networkOutputs) {
|
||||||
|
_outputs[output.first] = plg->GetOutputBlob(output.first, output.second->getTensorDesc().getPrecision());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto input : _networkInputs) {
|
||||||
|
_inputs[input.first] = plg->GetInputBlob(input.first, input.second->getTensorDesc().getPrecision());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace GNAPluginNS
|
@ -5,137 +5,59 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
||||||
#include "cpp/ie_infer_request.hpp"
|
#include "request_status.hpp"
|
||||||
#include "gna_plugin.hpp"
|
|
||||||
|
|
||||||
namespace GNAPluginNS {
|
namespace GNAPluginNS {
|
||||||
|
class GNAPlugin;
|
||||||
|
|
||||||
class GNAInferRequest : public InferenceEngine::IInferRequestInternal {
|
class GNAInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||||
private:
|
public:
|
||||||
void CreateInferRequest() {
|
|
||||||
// TODO: internal connection API - better to generalize
|
|
||||||
if (_networkOutputs.empty()) {
|
|
||||||
THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
|
|
||||||
}
|
|
||||||
|
|
||||||
// copy inputs blobs since we need to have them in separate address space to allow simultaneous infer requests
|
|
||||||
for (auto output : _networkOutputs) {
|
|
||||||
_outputs[output.first] =
|
|
||||||
plg->GetOutputBlob(output.first, output.second->getTensorDesc().getPrecision());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto input : _networkInputs) {
|
|
||||||
_inputs[input.first] =
|
|
||||||
plg->GetInputBlob(input.first, input.second->getTensorDesc().getPrecision());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
std::shared_ptr<GNAPlugin> plg;
|
|
||||||
uint32_t inferRequestIdx = -1;
|
|
||||||
|
|
||||||
public:
|
|
||||||
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||||
const std::vector<std::shared_ptr<const ov::Node>>& inputs,
|
const std::vector<std::shared_ptr<const ov::Node>>& inputs,
|
||||||
const std::vector<std::shared_ptr<const ov::Node>>& outputs)
|
const std::vector<std::shared_ptr<const ov::Node>>& outputs);
|
||||||
: InferenceEngine::IInferRequestInternal(inputs, outputs), plg(plg) {
|
|
||||||
CreateInferRequest();
|
|
||||||
}
|
|
||||||
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||||
InferenceEngine::InputsDataMap networkInputs,
|
InferenceEngine::InputsDataMap network_inputs,
|
||||||
InferenceEngine::OutputsDataMap networkOutputs)
|
InferenceEngine::OutputsDataMap network_outputs);
|
||||||
: InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
|
|
||||||
CreateInferRequest();
|
|
||||||
}
|
|
||||||
/**
|
/**
|
||||||
* @brief Infers specified input(s) in synchronous mode
|
* @brief Infers specified input(s) in synchronous mode
|
||||||
* @note blocks all method of InferRequest while request is ongoing (running or waiting in queue)
|
* @note blocks all method of InferRequest while request is ongoing (running or waiting in queue)
|
||||||
*/
|
*/
|
||||||
void InferImpl() override {
|
void InferImpl() override;
|
||||||
// execute input pre-processing.
|
|
||||||
execDataPreprocessing(_inputs);
|
|
||||||
// result returned from sync infer wait method
|
|
||||||
auto result = plg->Infer(_inputs, _outputs);
|
|
||||||
|
|
||||||
// if result is false we are dealing with QoS feature
|
|
||||||
// if result is ok, next call to wait() will return Ok, if request not in gna_queue
|
|
||||||
if (!result) {
|
|
||||||
inferRequestIdx = -1;
|
|
||||||
} else {
|
|
||||||
inferRequestIdx = -2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
|
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
|
||||||
* Note: not all plugins may provide meaningful data
|
* Note: not all plugins may provide meaningful data
|
||||||
* @param perfMap - a map of layer names to profiling information for that layer.
|
* @param perfMap - a map of layer names to profiling information for that layer.
|
||||||
*/
|
*/
|
||||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override {
|
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||||
return plg->GetPerformanceCounts();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief methods with _ThreadUnsafe prefix are to implement in plugins
|
* @brief methods with _ThreadUnsafe prefix are to implement in plugins
|
||||||
* or in default wrapper (e.g. AsyncInferRequestThreadSafeDefault)
|
* or in default wrapper (e.g. AsyncInferRequestThreadSafeDefault)
|
||||||
*/
|
*/
|
||||||
void StartAsyncImpl() override {
|
void StartAsyncImpl() override;
|
||||||
// execute input pre-processing.
|
|
||||||
execDataPreprocessing(_inputs);
|
|
||||||
inferRequestIdx = plg->QueueInference(_inputs, _outputs);
|
|
||||||
// workaround to unblock callback-based flows
|
|
||||||
if (_callback) {
|
|
||||||
auto res = Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY);
|
|
||||||
std::exception_ptr exceptionPtr;
|
|
||||||
if (res != InferenceEngine::StatusCode::OK) {
|
|
||||||
try {
|
|
||||||
IE_EXCEPTION_SWITCH(res, ExceptionType,
|
|
||||||
InferenceEngine::details::ThrowNow<ExceptionType>{}
|
|
||||||
<<= std::stringstream{} << IE_LOCATION
|
|
||||||
<< InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
|
|
||||||
} catch (...) {
|
|
||||||
exceptionPtr = std::current_exception();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_callback(exceptionPtr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
|
||||||
|
|
||||||
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override {
|
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override;
|
||||||
if (inferRequestIdx == -1) {
|
|
||||||
return InferenceEngine::INFER_NOT_STARTED;
|
|
||||||
} else if (millis_timeout < -1) {
|
|
||||||
IE_THROW(ParameterMismatch);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (millis_timeout == InferenceEngine::InferRequest::WaitMode::RESULT_READY) {
|
protected:
|
||||||
millis_timeout = MAX_TIMEOUT;
|
bool SetRequestIndex(uint32_t request_index);
|
||||||
}
|
bool IsRequestIndexValid();
|
||||||
const auto waitStatus = plg->WaitFor(inferRequestIdx, millis_timeout);
|
bool IsRequestCompleted();
|
||||||
|
|
||||||
if (waitStatus == RequestStatus::kPending) {
|
private:
|
||||||
// request is still pending so Wait() is needed once again
|
void CreateInferRequest();
|
||||||
return InferenceEngine::RESULT_NOT_READY;
|
InferenceEngine::StatusCode HandleRequestWaitStatus(const RequestStatus& request_status);
|
||||||
}
|
void ValidateAndConfigureTimeout(int64_t& millis_timeout);
|
||||||
if (waitStatus == RequestStatus::kAborted) {
|
void CallCleanupAndRethrowOnException(std::function<void()>&& function_to_invoke);
|
||||||
// need to preserve invalid state here to avoid next Wait() from clearing it
|
|
||||||
inferRequestIdx = -1;
|
|
||||||
return InferenceEngine::INFER_NOT_STARTED;
|
|
||||||
}
|
|
||||||
return InferenceEngine::OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
static constexpr const uint32_t kRequestIndexInvalid = std::numeric_limits<uint32_t>::max();
|
||||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
static constexpr const uint32_t kRequestIndexCompleted = std::numeric_limits<uint32_t>::max() - 1;
|
||||||
auto pluginStates = plg->QueryState();
|
|
||||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
|
uint32_t _infer_request_idx = kRequestIndexInvalid;
|
||||||
return plg->QueryState();
|
std::shared_ptr<GNAPlugin> plg;
|
||||||
}
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
};
|
};
|
||||||
} // namespace GNAPluginNS
|
} // namespace GNAPluginNS
|
||||||
|
@ -1335,7 +1335,9 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
|
|||||||
++inputNum;
|
++inputNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
freeWorker->enqueueRequest();
|
if (!freeWorker->enqueueRequest()) {
|
||||||
|
THROW_GNA_EXCEPTION << "Error with enqueueing inference request";
|
||||||
|
}
|
||||||
|
|
||||||
freeWorker->setResult(result);
|
freeWorker->setResult(result);
|
||||||
|
|
||||||
@ -1351,7 +1353,13 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool GNAPlugin::Wait(uint32_t request_idx) {
|
bool GNAPlugin::Wait(uint32_t request_idx) {
|
||||||
return RequestStatus::kCompleted == WaitFor(request_idx, MAX_TIMEOUT);
|
auto result = WaitFor(request_idx, MAX_TIMEOUT);
|
||||||
|
|
||||||
|
if (result == RequestStatus::kCompletedWithError) {
|
||||||
|
THROW_GNA_EXCEPTION << "Error when waiting for inference results!";
|
||||||
|
}
|
||||||
|
|
||||||
|
return result == RequestStatus::kCompleted;
|
||||||
}
|
}
|
||||||
|
|
||||||
RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
|
RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
|
||||||
@ -1368,6 +1376,10 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
|
|||||||
|
|
||||||
const auto waitStatus = worker.wait(millisTimeout);
|
const auto waitStatus = worker.wait(millisTimeout);
|
||||||
|
|
||||||
|
if (waitStatus == RequestStatus::kCompletedWithError) {
|
||||||
|
return waitStatus;
|
||||||
|
}
|
||||||
|
|
||||||
if (waitStatus == RequestStatus::kAborted) {
|
if (waitStatus == RequestStatus::kAborted) {
|
||||||
return waitStatus;
|
return waitStatus;
|
||||||
}
|
}
|
||||||
|
@ -28,6 +28,7 @@ public:
|
|||||||
* @param requestID id of request to be used for wait
|
* @param requestID id of request to be used for wait
|
||||||
* @param timeoutMilliseconds timeout of wait in milliseconds
|
* @param timeoutMilliseconds timeout of wait in milliseconds
|
||||||
* @return Status of subrequest @see GNAPluginNS::RequestStatus
|
* @return Status of subrequest @see GNAPluginNS::RequestStatus
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
using WaitHandler = std::function<RequestStatus(uint32_t requestID, int64_t timeoutMilliseconds)>;
|
using WaitHandler = std::function<RequestStatus(uint32_t requestID, int64_t timeoutMilliseconds)>;
|
||||||
|
|
||||||
@ -42,8 +43,14 @@ public:
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Add subrequest to execution queue.
|
* @brief Add subrequest to execution queue.
|
||||||
|
* @return true in case subrequest was properly enqueued, otherwise return false
|
||||||
*/
|
*/
|
||||||
virtual void enqueue() = 0;
|
virtual bool enqueue() = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Finalize subrequest and set it status to RequestStatus::kNone
|
||||||
|
*/
|
||||||
|
virtual void cleanup() = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return true if subrequest is pending, otherwise return false
|
* @brief Return true if subrequest is pending, otherwise return false
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include <gna2-inference-api.h>
|
#include <gna2-inference-api.h>
|
||||||
|
|
||||||
#include "log/debug.hpp"
|
#include "log/debug.hpp"
|
||||||
|
#include "log/log.hpp"
|
||||||
|
|
||||||
namespace GNAPluginNS {
|
namespace GNAPluginNS {
|
||||||
namespace request {
|
namespace request {
|
||||||
@ -24,14 +25,30 @@ RequestStatus SubrequestImpl::wait(int64_t timeoutMilliseconds) {
|
|||||||
return status_;
|
return status_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
status_ = waitHandler_(requestID_, timeoutMilliseconds);
|
status_ = waitHandler_(requestID_, timeoutMilliseconds);
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
ov::intel_gna::log::error() << "Exception when executiong wait: " << e.what() << std::endl;
|
||||||
|
status_ = RequestStatus::kCompletedWithError;
|
||||||
|
}
|
||||||
|
|
||||||
return status_;
|
return status_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SubrequestImpl::enqueue() {
|
bool SubrequestImpl::enqueue() {
|
||||||
|
try {
|
||||||
requestID_ = enqueueHandler_();
|
requestID_ = enqueueHandler_();
|
||||||
status_ = RequestStatus::kPending;
|
status_ = RequestStatus::kPending;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
ov::intel_gna::log::error() << "Exception when executiong enqueue: " << e.what() << std::endl;
|
||||||
|
status_ = RequestStatus::kCompletedWithError;
|
||||||
|
}
|
||||||
|
return status_ != RequestStatus::kCompletedWithError;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SubrequestImpl::cleanup() {
|
||||||
|
static_cast<void>(wait(0));
|
||||||
|
status_ = RequestStatus::kNone;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SubrequestImpl::isPending() const {
|
bool SubrequestImpl::isPending() const {
|
||||||
|
@ -40,8 +40,14 @@ public:
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Add subrequest to execution queue.
|
* @brief Add subrequest to execution queue.
|
||||||
|
* @return true in case subrequest was properly enqueued, otherwise return false
|
||||||
*/
|
*/
|
||||||
void enqueue() override;
|
bool enqueue() override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Finalize subrequest and set it status to RequestStatus::kNone
|
||||||
|
*/
|
||||||
|
void cleanup() override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return true if subrequest is pending, otherwise return false
|
* @brief Return true if subrequest is pending, otherwise return false
|
||||||
|
@ -39,15 +39,14 @@ public:
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Enqueue request to requests queue for contained model.
|
* @brief Enqueue request to requests queue for contained model.
|
||||||
* @throw Exception in case worker is busy or if there was an issue with enqueue.
|
* @return true in case subrequest was properly enqueued, otherwise return false
|
||||||
*/
|
*/
|
||||||
virtual void enqueueRequest() = 0;
|
virtual bool enqueueRequest() = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Wait untril request will be not finished for give timeout.
|
* @brief Wait untril request will be not finished for give timeout.
|
||||||
* @param timeoutMilliseconds timeout in milliseconds
|
* @param timeoutMilliseconds timeout in milliseconds
|
||||||
* @return status of execution of ongoing request. @see GNAPluginNS::RequestStatus
|
* @return status of execution of ongoing request. @see GNAPluginNS::RequestStatus
|
||||||
* @throw Exception in case worker is busy or if there was an issue with enqueue.
|
|
||||||
*/
|
*/
|
||||||
virtual RequestStatus wait(int64_t timeoutMilliseconds) = 0;
|
virtual RequestStatus wait(int64_t timeoutMilliseconds) = 0;
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include <gna2-inference-api.h>
|
#include <gna2-inference-api.h>
|
||||||
|
|
||||||
#include "log/debug.hpp"
|
#include "log/debug.hpp"
|
||||||
|
#include "log/log.hpp"
|
||||||
#include "model_wrapper.hpp"
|
#include "model_wrapper.hpp"
|
||||||
#include "subrequest.hpp"
|
#include "subrequest.hpp"
|
||||||
|
|
||||||
@ -39,12 +40,19 @@ Gna2Model* WorkerImpl::model() {
|
|||||||
return &fullModel_->object();
|
return &fullModel_->object();
|
||||||
}
|
}
|
||||||
|
|
||||||
void WorkerImpl::enqueueRequest() {
|
bool WorkerImpl::enqueueRequest() {
|
||||||
check_if_free();
|
if (!isFree()) {
|
||||||
|
ov::intel_gna::log::warning() << "Trying to propagate on busy request with id: " << representingIndex_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto& subrequest : modelSubrequests_) {
|
for (auto& subrequest : modelSubrequests_) {
|
||||||
subrequest->enqueue();
|
if (!subrequest->enqueue()) {
|
||||||
|
cleanup_subrequests();
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
RequestStatus WorkerImpl::wait(int64_t timeoutMilliseconds) {
|
RequestStatus WorkerImpl::wait(int64_t timeoutMilliseconds) {
|
||||||
@ -56,8 +64,13 @@ RequestStatus WorkerImpl::wait(int64_t timeoutMilliseconds) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (subrequest->wait(timeoutMilliseconds) == RequestStatus::kPending) {
|
auto result = subrequest->wait(timeoutMilliseconds);
|
||||||
|
|
||||||
|
if (result == RequestStatus::kPending) {
|
||||||
pending = true;
|
pending = true;
|
||||||
|
} else if (result == RequestStatus::kCompletedWithError) {
|
||||||
|
cleanup_subrequests();
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,9 +120,11 @@ InferenceEngine::BlobMap& WorkerImpl::result() {
|
|||||||
return requestResult_;
|
return requestResult_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void WorkerImpl::check_if_free() {
|
void WorkerImpl::cleanup_subrequests() {
|
||||||
if (!isFree()) {
|
for (auto& subrequest : modelSubrequests_) {
|
||||||
THROW_GNA_EXCEPTION << "Trying to propagte on busy request with id: " << representingIndex_;
|
if (subrequest->isPending()) {
|
||||||
|
subrequest->cleanup();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ public:
|
|||||||
/**
|
/**
|
||||||
* @see Worker::enqueueRequest()
|
* @see Worker::enqueueRequest()
|
||||||
*/
|
*/
|
||||||
void enqueueRequest() override;
|
bool enqueueRequest() override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @see Worker::wait()
|
* @see Worker::wait()
|
||||||
@ -93,7 +93,7 @@ public:
|
|||||||
void setResult(InferenceEngine::BlobMap&& result) override;
|
void setResult(InferenceEngine::BlobMap&& result) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void check_if_free();
|
void cleanup_subrequests();
|
||||||
|
|
||||||
uint32_t representingIndex_{0};
|
uint32_t representingIndex_{0};
|
||||||
std::shared_ptr<ModelWrapper> fullModel_;
|
std::shared_ptr<ModelWrapper> fullModel_;
|
||||||
|
@ -13,7 +13,8 @@ enum class RequestStatus {
|
|||||||
kNone = 0, /// request was not initialized
|
kNone = 0, /// request was not initialized
|
||||||
kAborted = 1, /// request was aborted
|
kAborted = 1, /// request was aborted
|
||||||
kPending = 2, /// request was started and is onging
|
kPending = 2, /// request was started and is onging
|
||||||
kCompleted = 3 /// request was completed with success
|
kCompleted = 3, /// request was completed with success
|
||||||
|
kCompletedWithError = 4 /// request was completed with error
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace GNAPluginNS
|
} // namespace GNAPluginNS
|
||||||
|
@ -14,7 +14,7 @@ endif()
|
|||||||
|
|
||||||
# TODO: fix CVS-71010 and remove BUILD_SHARED_LIBS
|
# TODO: fix CVS-71010 and remove BUILD_SHARED_LIBS
|
||||||
if(NOT BUILD_SHARED_LIBS)
|
if(NOT BUILD_SHARED_LIBS)
|
||||||
set(exclude_path EXCLUDED_SOURCE_PATHS "${CMAKE_CURRENT_SOURCE_DIR}/(gna_api_stub|gna_wait_test|gna_export_import_test).cpp")
|
set(exclude_path EXCLUDED_SOURCE_PATHS "${CMAKE_CURRENT_SOURCE_DIR}/(gna_api_stub|gna_wait_test|gna_export_import_test|gna_infer_request_test).cpp")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
addIeTargetTest(
|
addIeTargetTest(
|
||||||
|
@ -138,6 +138,9 @@ GNA2_API enum Gna2Status Gna2RequestConfigSetAccelerationMode(
|
|||||||
GNA2_API enum Gna2Status Gna2RequestEnqueue(
|
GNA2_API enum Gna2Status Gna2RequestEnqueue(
|
||||||
uint32_t requestConfigId,
|
uint32_t requestConfigId,
|
||||||
uint32_t * requestId) {
|
uint32_t * requestId) {
|
||||||
|
if (current != nullptr) {
|
||||||
|
return current->Gna2RequestEnqueue(requestConfigId, requestId);
|
||||||
|
}
|
||||||
return Gna2StatusSuccess;
|
return Gna2StatusSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
204
src/plugins/intel_gna/tests/unit/gna_infer_request_test.cpp
Normal file
204
src/plugins/intel_gna/tests/unit/gna_infer_request_test.cpp
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "gna_infer_request.hpp"
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "any_copy.hpp"
|
||||||
|
#include "common_test_utils/data_utils.hpp"
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include "gna_mock_api.hpp"
|
||||||
|
#include "gna_plugin.hpp"
|
||||||
|
#include "ngraph_functions/builders.hpp"
|
||||||
|
|
||||||
|
using namespace ::testing;
|
||||||
|
using namespace InferenceEngine;
|
||||||
|
|
||||||
|
using GNAPluginNS::GNAInferRequest;
|
||||||
|
using GNAPluginNS::GNAPlugin;
|
||||||
|
using ::testing::InSequence;
|
||||||
|
|
||||||
|
class GNAInferRequestTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
IInferRequestInternal::Ptr CreateRequest() {
|
||||||
|
auto function = GetFunction();
|
||||||
|
CNNNetwork cnn_network = CNNNetwork{function};
|
||||||
|
|
||||||
|
SetExpectsForLoadNetworkAndShutDown(_data);
|
||||||
|
const ov::AnyMap gna_config = {ov::intel_gna::execution_mode(ov::intel_gna::ExecutionMode::SW_EXACT)};
|
||||||
|
auto plugin = std::make_shared<GNAPlugin>(any_copy(gna_config));
|
||||||
|
plugin->LoadNetwork(cnn_network);
|
||||||
|
|
||||||
|
return std::make_shared<GNAInferRequest>(plugin, cnn_network.getInputsInfo(), cnn_network.getOutputsInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<ngraph::Function> GetFunction() {
|
||||||
|
auto ngPrc = ngraph::element::f32;
|
||||||
|
std::vector<size_t> shape = {1, 10};
|
||||||
|
auto params = ngraph::builder::makeParams(ngPrc, {shape});
|
||||||
|
auto shape_size = ov::shape_size(shape);
|
||||||
|
auto add_const =
|
||||||
|
ngraph::builder::makeConstant<float>(ngPrc,
|
||||||
|
shape,
|
||||||
|
CommonTestUtils::generate_float_numbers(shape_size, -0.5f, 0.5f),
|
||||||
|
false);
|
||||||
|
|
||||||
|
auto add = std::make_shared<ngraph::opset9::Add>(params[0], add_const);
|
||||||
|
auto res = std::make_shared<ngraph::op::Result>(add);
|
||||||
|
auto function = std::make_shared<ngraph::Function>(res, params, "Add");
|
||||||
|
return function;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetExpectsForLoadNetworkAndShutDown(std::vector<std::vector<uint8_t>>& data) {
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2MemoryAlloc(_, _, _))
|
||||||
|
.Times(AtLeast(1))
|
||||||
|
.WillRepeatedly(Invoke([&data](uint32_t size_requested, uint32_t* size_granted, void** memory_address) {
|
||||||
|
data.push_back(std::vector<uint8_t>(size_requested));
|
||||||
|
*size_granted = size_requested;
|
||||||
|
*memory_address = data.back().data();
|
||||||
|
return Gna2StatusSuccess;
|
||||||
|
}));
|
||||||
|
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2DeviceGetVersion(_, _))
|
||||||
|
.WillOnce(Invoke([](uint32_t deviceIndex, enum Gna2DeviceVersion* deviceVersion) {
|
||||||
|
*deviceVersion = Gna2DeviceVersionSoftwareEmulation;
|
||||||
|
return Gna2StatusSuccess;
|
||||||
|
}));
|
||||||
|
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2DeviceOpen(_)).WillOnce(Return(Gna2StatusSuccess));
|
||||||
|
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2GetLibraryVersion(_, _))
|
||||||
|
.Times(AtLeast(0))
|
||||||
|
.WillRepeatedly(Return(Gna2StatusSuccess));
|
||||||
|
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2InstrumentationConfigCreate(_, _, _, _)).WillOnce(Return(Gna2StatusSuccess));
|
||||||
|
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2ModelCreate(_, _, _))
|
||||||
|
.WillOnce(Invoke([](uint32_t deviceIndex, struct Gna2Model const* model, uint32_t* model_id) {
|
||||||
|
*model_id = 0;
|
||||||
|
return Gna2StatusSuccess;
|
||||||
|
}));
|
||||||
|
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2RequestConfigCreate(_, _))
|
||||||
|
.WillOnce(Invoke([](uint32_t model_Id, uint32_t* request_config_id) {
|
||||||
|
*request_config_id = 0;
|
||||||
|
return Gna2StatusSuccess;
|
||||||
|
}));
|
||||||
|
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2InstrumentationConfigAssignToRequestConfig(_, _))
|
||||||
|
.Times(AtLeast(1))
|
||||||
|
.WillRepeatedly(Return(Gna2StatusSuccess));
|
||||||
|
|
||||||
|
InSequence seq;
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2DeviceClose(_)).WillOnce(Return(Gna2StatusSuccess));
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2MemoryFree(_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetExpectOnEnqueue(const Gna2Status& return_status = Gna2StatusSuccess) {
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2RequestEnqueue(_, _)).WillOnce(Return(return_status));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetExpectOnWait(const Gna2Status& return_status = Gna2StatusSuccess) {
|
||||||
|
EXPECT_CALL(*_mock_api, Gna2RequestWait(_, _)).WillOnce(Return(return_status));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetUp() override {
|
||||||
|
_mock_api = std::make_shared<StrictMock<GNACppApi>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
ASSERT_TRUE(Mock::VerifyAndClearExpectations(_mock_api.get()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<StrictMock<GNACppApi>> _mock_api;
|
||||||
|
std::vector<std::vector<uint8_t>> _data;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, start_async) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
// wait on shutdown neede if request was enqueued by not waited
|
||||||
|
SetExpectOnWait();
|
||||||
|
EXPECT_NO_THROW(request->StartAsync());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, start_async_with_enqueue_error) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
// trigger Gna2RequestEnqueue to fail
|
||||||
|
SetExpectOnEnqueue(Gna2StatusUnknownError);
|
||||||
|
// no wait needed due the fact there are no enqueud requests
|
||||||
|
EXPECT_THROW(request->StartAsync(), std::exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, start_async_with_wait) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
// wait on shutdown needed if request was enqueued by not waited
|
||||||
|
SetExpectOnWait();
|
||||||
|
EXPECT_NO_THROW(request->StartAsync());
|
||||||
|
EXPECT_EQ(OK, request->Wait(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, start_async_error_with_wait) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue(Gna2StatusUnknownError);
|
||||||
|
// wait on shutdown needed if request was enqueued by not waited
|
||||||
|
// SetExpectOnWait();
|
||||||
|
EXPECT_THROW(request->StartAsync(), std::exception);
|
||||||
|
EXPECT_EQ(INFER_NOT_STARTED, request->Wait(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, start_async_with_wait_error) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
// wait on shutdown needed if request was enqueued by not waited
|
||||||
|
SetExpectOnWait(Gna2StatusUnknownError);
|
||||||
|
EXPECT_NO_THROW(request->StartAsync());
|
||||||
|
EXPECT_THROW(request->Wait(0), std::exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, start_async_wait_check_recovery_after_wait_error) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
SetExpectOnWait(Gna2StatusUnknownError);
|
||||||
|
EXPECT_NO_THROW(request->StartAsync());
|
||||||
|
EXPECT_THROW(request->Wait(0), std::exception);
|
||||||
|
// check that no there is exception on second wiat after first failing
|
||||||
|
EXPECT_EQ(INFER_NOT_STARTED, request->Wait(0));
|
||||||
|
|
||||||
|
// start new request
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
SetExpectOnWait();
|
||||||
|
EXPECT_NO_THROW(request->StartAsync());
|
||||||
|
EXPECT_EQ(OK, request->Wait(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, infer) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
SetExpectOnWait();
|
||||||
|
EXPECT_NO_THROW(request->Infer());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, infer_enque_error) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue(Gna2StatusUnknownError);
|
||||||
|
EXPECT_THROW(request->Infer(), std::exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAInferRequestTest, infer_wait_error_check_recovery) {
|
||||||
|
auto request = CreateRequest();
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
SetExpectOnWait(Gna2StatusUnknownError);
|
||||||
|
EXPECT_THROW(request->Infer(), std::exception);
|
||||||
|
// check if next infer will execute properly after wait throwing
|
||||||
|
SetExpectOnEnqueue();
|
||||||
|
SetExpectOnWait();
|
||||||
|
EXPECT_NO_THROW(request->Infer());
|
||||||
|
}
|
@ -47,6 +47,10 @@ public:
|
|||||||
uint32_t requestId,
|
uint32_t requestId,
|
||||||
uint32_t timeoutMilliseconds));
|
uint32_t timeoutMilliseconds));
|
||||||
|
|
||||||
|
MOCK_METHOD2(Gna2RequestEnqueue, Gna2Status(
|
||||||
|
uint32_t requestConfigId,
|
||||||
|
uint32_t* requestId));
|
||||||
|
|
||||||
MOCK_METHOD2(Gna2DeviceGetVersion, Gna2Status(
|
MOCK_METHOD2(Gna2DeviceGetVersion, Gna2Status(
|
||||||
uint32_t deviceIndex,
|
uint32_t deviceIndex,
|
||||||
enum Gna2DeviceVersion * deviceVersion));
|
enum Gna2DeviceVersion * deviceVersion));
|
||||||
|
@ -10,14 +10,15 @@
|
|||||||
#define IMPLEMENT_INFERENCE_ENGINE_PLUGIN
|
#define IMPLEMENT_INFERENCE_ENGINE_PLUGIN
|
||||||
#include "gna_infer_request.hpp"
|
#include "gna_infer_request.hpp"
|
||||||
#include "gna_mock_api.hpp"
|
#include "gna_mock_api.hpp"
|
||||||
|
#include "gna_plugin.hpp"
|
||||||
#include "request/model_wrapper_factory.hpp"
|
#include "request/model_wrapper_factory.hpp"
|
||||||
#include "request/subrequest_impl.hpp"
|
#include "request/subrequest_impl.hpp"
|
||||||
#include "request/worker_factory.hpp"
|
#include "request/worker_factory.hpp"
|
||||||
#include "request/worker_impl.hpp"
|
#include "request/worker_impl.hpp"
|
||||||
#include "request/worker_pool.hpp"
|
#include "request/worker_pool.hpp"
|
||||||
|
|
||||||
using GNAPluginNS::GNAInferRequest;
|
using namespace GNAPluginNS;
|
||||||
using GNAPluginNS::GNAPlugin;
|
using namespace GNAPluginNS::request;
|
||||||
using ::testing::_;
|
using ::testing::_;
|
||||||
using ::testing::Return;
|
using ::testing::Return;
|
||||||
|
|
||||||
@ -27,9 +28,6 @@ class GNAPluginForGNAWaitTest : public GNAPlugin {
|
|||||||
public:
|
public:
|
||||||
// Prepare underlining object to enable GNAInferRequest::Wait() working
|
// Prepare underlining object to enable GNAInferRequest::Wait() working
|
||||||
GNAPluginForGNAWaitTest() {
|
GNAPluginForGNAWaitTest() {
|
||||||
using namespace GNAPluginNS;
|
|
||||||
using namespace request;
|
|
||||||
|
|
||||||
InferenceEngine::TensorDesc td{InferenceEngine::Precision::FP32, {1, 1}, InferenceEngine::Layout::HW};
|
InferenceEngine::TensorDesc td{InferenceEngine::Precision::FP32, {1, 1}, InferenceEngine::Layout::HW};
|
||||||
auto fakeInfo = std::make_shared<InferenceEngine::InputInfo>();
|
auto fakeInfo = std::make_shared<InferenceEngine::InputInfo>();
|
||||||
auto fakePtr = std::make_shared<InferenceEngine::Data>("fakeName", td);
|
auto fakePtr = std::make_shared<InferenceEngine::Data>("fakeName", td);
|
||||||
@ -55,20 +53,32 @@ public:
|
|||||||
|
|
||||||
auto model = ModelWrapperFactory::createWithNumberOfEmptyOperations(1);
|
auto model = ModelWrapperFactory::createWithNumberOfEmptyOperations(1);
|
||||||
subrequests.push_back(std::make_shared<SubrequestImpl>(std::move(enqueue), std::move(wait)));
|
subrequests.push_back(std::make_shared<SubrequestImpl>(std::move(enqueue), std::move(wait)));
|
||||||
auto worker = std::make_shared<WorkerImpl>(model, std::move(subrequests));
|
_worker = std::make_shared<WorkerImpl>(model, std::move(subrequests));
|
||||||
|
|
||||||
requestWorkerPool_->addModelWorker(worker);
|
requestWorkerPool_->addModelWorker(_worker);
|
||||||
worker->enqueueRequest();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EnqueTestRequest() {
|
||||||
|
_worker->enqueueRequest();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Worker> _worker;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GNAInferRequestForGNAWaitTest : public GNAInferRequest {
|
class GNAInferRequestForGNAWaitTest : public GNAInferRequest {
|
||||||
public:
|
public:
|
||||||
// Prepare underlining object to enable Wait() working
|
// Prepare underlining object to enable Wait() working
|
||||||
GNAInferRequestForGNAWaitTest(std::shared_ptr<GNAPlugin> plugin)
|
GNAInferRequestForGNAWaitTest(std::shared_ptr<GNAPluginForGNAWaitTest> plugin)
|
||||||
: GNAInferRequest{plugin, plugin->GetNetworkInputs(), plugin->GetNetworkOutputs()} {
|
: GNAInferRequest{plugin, plugin->GetNetworkInputs(), plugin->GetNetworkOutputs()},
|
||||||
inferRequestIdx = 0;
|
_plugin(plugin) {}
|
||||||
|
|
||||||
|
void EnqueTestRequest() {
|
||||||
|
_plugin->EnqueTestRequest();
|
||||||
|
SetRequestIndex(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<GNAPluginForGNAWaitTest> _plugin;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(GNAWaitTest, ReturnsGna2StatusDriverQoSTimeoutExceeded) {
|
TEST_F(GNAWaitTest, ReturnsGna2StatusDriverQoSTimeoutExceeded) {
|
||||||
@ -76,6 +86,7 @@ TEST_F(GNAWaitTest, ReturnsGna2StatusDriverQoSTimeoutExceeded) {
|
|||||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDriverQoSTimeoutExceeded));
|
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDriverQoSTimeoutExceeded));
|
||||||
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||||
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||||
|
inferRequest.EnqueTestRequest();
|
||||||
ASSERT_EQ(InferenceEngine::INFER_NOT_STARTED, inferRequest.Wait(0));
|
ASSERT_EQ(InferenceEngine::INFER_NOT_STARTED, inferRequest.Wait(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,6 +95,38 @@ TEST_F(GNAWaitTest, ReturnsGna2StatusWarningDeviceBusy) {
|
|||||||
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusWarningDeviceBusy));
|
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusWarningDeviceBusy));
|
||||||
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||||
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||||
|
inferRequest.EnqueTestRequest();
|
||||||
ASSERT_EQ(InferenceEngine::RESULT_NOT_READY, inferRequest.Wait(0));
|
ASSERT_EQ(InferenceEngine::RESULT_NOT_READY, inferRequest.Wait(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange) {
|
||||||
|
GNACppApi enableMocks;
|
||||||
|
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
|
||||||
|
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||||
|
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||||
|
inferRequest.EnqueTestRequest();
|
||||||
|
ASSERT_THROW(inferRequest.Wait(0), std::exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange_Extra_Sync) {
|
||||||
|
GNACppApi enableMocks;
|
||||||
|
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
|
||||||
|
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||||
|
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||||
|
inferRequest.EnqueTestRequest();
|
||||||
|
ASSERT_THROW(inferRequest.Wait(0), std::exception);
|
||||||
|
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(0);
|
||||||
|
ASSERT_EQ(InferenceEngine::INFER_NOT_STARTED, inferRequest.Wait(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNAWaitTest, ReturnsGna2StatusDeviceParameterOutOfRange_Another_Use) {
|
||||||
|
GNACppApi enableMocks;
|
||||||
|
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusDeviceParameterOutOfRange));
|
||||||
|
auto plugin = std::make_shared<GNAPluginForGNAWaitTest>();
|
||||||
|
GNAInferRequestForGNAWaitTest inferRequest{plugin};
|
||||||
|
inferRequest.EnqueTestRequest();
|
||||||
|
ASSERT_THROW(inferRequest.Wait(0), std::exception);
|
||||||
|
inferRequest.EnqueTestRequest();
|
||||||
|
EXPECT_CALL(enableMocks, Gna2RequestWait(_, _)).Times(1).WillOnce(Return(Gna2StatusSuccess));
|
||||||
|
ASSERT_EQ(InferenceEngine::OK, inferRequest.Wait(0));
|
||||||
|
}
|
||||||
|
@ -60,16 +60,40 @@ TEST_F(GNA_Request_WorkerImplTest, enqueueRequest) {
|
|||||||
subrequests.push_back(subrequestMock2);
|
subrequests.push_back(subrequestMock2);
|
||||||
WorkerImpl worker(wrapper, subrequests);
|
WorkerImpl worker(wrapper, subrequests);
|
||||||
|
|
||||||
// check if exception will be thrown if worker will be busy - at least one subrequest is pending.
|
// check if false will be returned if worker will be busy - at least one subrequest is pending.
|
||||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||||
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(true));
|
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(true));
|
||||||
EXPECT_THROW(worker.enqueueRequest(), std::exception);
|
EXPECT_FALSE(worker.enqueueRequest());
|
||||||
|
|
||||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||||
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(false));
|
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||||
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1);
|
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1).WillOnce(Return(true));
|
||||||
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1);
|
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1).WillOnce(Return(true));
|
||||||
EXPECT_NO_THROW(worker.enqueueRequest());
|
EXPECT_TRUE(worker.enqueueRequest());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GNA_Request_WorkerImplTest, enqueueRequest_second_enque_failed) {
|
||||||
|
auto wrapper = ModelWrapperFactory::createTrivial();
|
||||||
|
std::vector<std::shared_ptr<Subrequest>> subrequests;
|
||||||
|
|
||||||
|
auto subrequestMock1 = std::make_shared<MockSubrequest>();
|
||||||
|
subrequests.push_back(subrequestMock1);
|
||||||
|
auto subrequestMock2 = std::make_shared<MockSubrequest>();
|
||||||
|
subrequests.push_back(subrequestMock2);
|
||||||
|
WorkerImpl worker(wrapper, subrequests);
|
||||||
|
|
||||||
|
// check if exception will be thrown if worker will be busy - at least one subrequest is pending.
|
||||||
|
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||||
|
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(1).WillOnce(Return(true));
|
||||||
|
EXPECT_FALSE(worker.enqueueRequest());
|
||||||
|
|
||||||
|
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(2).WillOnce(Return(false)).WillOnce(Return(true));
|
||||||
|
EXPECT_CALL(*subrequestMock2.get(), isPending()).Times(2).WillOnce(Return(false)).WillOnce(Return(false));
|
||||||
|
EXPECT_CALL(*subrequestMock1.get(), enqueue()).Times(1).WillOnce(Return(true));
|
||||||
|
EXPECT_CALL(*subrequestMock2.get(), enqueue()).Times(1).WillOnce(Return(false));
|
||||||
|
EXPECT_CALL(*subrequestMock1.get(), cleanup()).Times(1);
|
||||||
|
EXPECT_CALL(*subrequestMock2.get(), cleanup()).Times(0);
|
||||||
|
EXPECT_FALSE(worker.enqueueRequest());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GNA_Request_WorkerImplTest, wait) {
|
TEST_F(GNA_Request_WorkerImplTest, wait) {
|
||||||
@ -107,6 +131,13 @@ TEST_F(GNA_Request_WorkerImplTest, wait) {
|
|||||||
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(1).WillOnce(Return(false));
|
||||||
EXPECT_CALL(*subrequestMock1.get(), isAborted()).Times(1).WillOnce(Return(true));
|
EXPECT_CALL(*subrequestMock1.get(), isAborted()).Times(1).WillOnce(Return(true));
|
||||||
EXPECT_EQ(RequestStatus::kAborted, worker.wait(referenceTimeout));
|
EXPECT_EQ(RequestStatus::kAborted, worker.wait(referenceTimeout));
|
||||||
|
|
||||||
|
// subrequest enuqued and completed with error
|
||||||
|
EXPECT_CALL(*subrequestMock1.get(), isPending()).Times(2).WillOnce(Return(true)).WillOnce(Return(false));
|
||||||
|
EXPECT_CALL(*subrequestMock1.get(), wait(referenceTimeout))
|
||||||
|
.Times(1)
|
||||||
|
.WillOnce(Return(RequestStatus::kCompletedWithError));
|
||||||
|
EXPECT_EQ(RequestStatus::kCompletedWithError, worker.wait(referenceTimeout));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GNA_Request_WorkerImplTest, isFree) {
|
TEST_F(GNA_Request_WorkerImplTest, isFree) {
|
||||||
|
@ -13,7 +13,8 @@ namespace request {
|
|||||||
class MockSubrequest : public Subrequest {
|
class MockSubrequest : public Subrequest {
|
||||||
public:
|
public:
|
||||||
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
|
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
|
||||||
MOCK_METHOD(void, enqueue, (), (override));
|
MOCK_METHOD(bool, enqueue, (), (override));
|
||||||
|
MOCK_METHOD(void, cleanup, (), (override));
|
||||||
MOCK_METHOD(bool, isPending, (), (const, override));
|
MOCK_METHOD(bool, isPending, (), (const, override));
|
||||||
MOCK_METHOD(bool, isAborted, (), (const, override));
|
MOCK_METHOD(bool, isAborted, (), (const, override));
|
||||||
MOCK_METHOD(bool, isCompleted, (), (const, override));
|
MOCK_METHOD(bool, isCompleted, (), (const, override));
|
||||||
|
@ -14,7 +14,7 @@ class MockWorker : public Worker {
|
|||||||
public:
|
public:
|
||||||
MOCK_METHOD(Gna2Model*, model, (), (override));
|
MOCK_METHOD(Gna2Model*, model, (), (override));
|
||||||
MOCK_METHOD(const Gna2Model*, model, (), (const, override));
|
MOCK_METHOD(const Gna2Model*, model, (), (const, override));
|
||||||
MOCK_METHOD(void, enqueueRequest, (), (override));
|
MOCK_METHOD(bool, enqueueRequest, (), (override));
|
||||||
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
|
MOCK_METHOD(RequestStatus, wait, (int64_t), (override));
|
||||||
MOCK_METHOD(bool, isFree, (), (const, override));
|
MOCK_METHOD(bool, isFree, (), (const, override));
|
||||||
MOCK_METHOD(uint32_t, representingIndex, (), (const, override));
|
MOCK_METHOD(uint32_t, representingIndex, (), (const, override));
|
||||||
|
@ -36,7 +36,6 @@ public:
|
|||||||
if (!legacy_str.empty()) {
|
if (!legacy_str.empty()) {
|
||||||
name += "," + legacy_str;
|
name += "," + legacy_str;
|
||||||
}
|
}
|
||||||
std::cout << "name: " << name << std::endl;
|
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
|
Loading…
Reference in New Issue
Block a user