Merged internal Infer Request implementation (#5125)
This commit is contained in:
parent
ef70e5187c
commit
46987def54
@ -8,7 +8,7 @@
|
||||
`InferRequest` Class
|
||||
------------------------
|
||||
|
||||
Inference Engine Plugin API provides the helper InferenceEngine::InferRequestInternal class recommended
|
||||
Inference Engine Plugin API provides the helper InferenceEngine::IInferRequestInternal class recommended
|
||||
to use as a base class for a synchronous inference request implementation. Based of that, a declaration
|
||||
of a synchronous request class can look as follows:
|
||||
|
||||
@ -46,7 +46,7 @@ Decrements a number of created inference requests:
|
||||
|
||||
### `InferImpl()`
|
||||
|
||||
**Implementation details:** Base InferRequestInternal class implements the public InferenceEngine::InferRequestInternal::Infer method as following:
|
||||
**Implementation details:** Base IInferRequestInternal class implements the public InferenceEngine::IInferRequestInternal::Infer method as following:
|
||||
- Checks blobs set by users
|
||||
- Calls the `InferImpl` method defined in a derived class to call actual pipeline stages synchronously
|
||||
|
||||
@ -59,7 +59,7 @@ Below is the code of the the `inferPreprocess` method to demonstrate Inference E
|
||||
@snippet src/template_infer_request.cpp infer_request:infer_preprocess
|
||||
|
||||
**Details:**
|
||||
* `InferImpl` must call the InferenceEngine::InferRequestInternal::execDataPreprocessing function, which executes common Inference Engine preprocessing step (for example, applies resize or color conversion operations) if it is set by the user. The output dimensions, layout and precision matches the input information set via InferenceEngine::CNNNetwork::getInputsInfo.
|
||||
* `InferImpl` must call the InferenceEngine::IInferRequestInternal::execDataPreprocessing function, which executes common Inference Engine preprocessing step (for example, applies resize or color conversion operations) if it is set by the user. The output dimensions, layout and precision matches the input information set via InferenceEngine::CNNNetwork::getInputsInfo.
|
||||
* If `inputBlob` passed by user differs in terms of precisions from precision expected by plugin, `blobCopy` is performed which does actual precision conversion.
|
||||
|
||||
#### 2. `startPipeline`
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class AcceleratorSyncRequest : public InferRequestInternal {
|
||||
class AcceleratorSyncRequest : public IInferRequestInternal {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<AcceleratorSyncRequest>;
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
int main() {
|
||||
InferenceEngine::Core core;
|
||||
InferenceEngine::IInferRequest::CompletionCallback callback;
|
||||
InferenceEngine::IInferRequest::CompletionCallback callback = nullptr;
|
||||
int numRequests = 42;
|
||||
int i = 1;
|
||||
auto network = core.ReadNetwork("sample.xml");
|
||||
|
@ -18,7 +18,7 @@ public:
|
||||
const InferenceEngine::ITaskExecutor::Ptr& waitExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
|
||||
|
||||
~TemplateAsyncInferRequest() override;
|
||||
~TemplateAsyncInferRequest();
|
||||
|
||||
private:
|
||||
TemplateInferRequest::Ptr _inferRequest;
|
||||
|
@ -131,21 +131,18 @@ void TemplatePlugin::ExecutableNetwork::InitExecutor() {
|
||||
|
||||
|
||||
// ! [executable_network:create_infer_request_impl]
|
||||
InferenceEngine::InferRequestInternal::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::IInferRequestInternal::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) {
|
||||
return std::make_shared<TemplateInferRequest>(networkInputs, networkOutputs, std::static_pointer_cast<ExecutableNetwork>(shared_from_this()));
|
||||
}
|
||||
// ! [executable_network:create_infer_request_impl]
|
||||
|
||||
// ! [executable_network:create_infer_request]
|
||||
InferenceEngine::IInferRequest::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequest() {
|
||||
InferenceEngine::IInferRequestInternal::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequest() {
|
||||
InferenceEngine::IInferRequest::Ptr asyncRequest;
|
||||
auto internalRequest = CreateInferRequestImpl(_networkInputs, _networkOutputs);
|
||||
auto asyncThreadSafeImpl = std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
|
||||
return std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
|
||||
_taskExecutor, _plugin->_waitExecutor, _callbackExecutor);
|
||||
asyncRequest.reset(new InferenceEngine::InferRequestBase(asyncThreadSafeImpl));
|
||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
return asyncRequest;
|
||||
}
|
||||
// ! [executable_network:create_infer_request]
|
||||
|
||||
|
@ -36,9 +36,9 @@ public:
|
||||
// Methods from a base class ExecutableNetworkThreadSafeDefault
|
||||
|
||||
void ExportImpl(std::ostream& model) override;
|
||||
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override;
|
||||
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::Parameter GetMetric(const std::string &name) const override;
|
||||
InferenceEngine::Parameter GetConfig(const std::string &name) const override;
|
||||
|
||||
|
@ -33,7 +33,7 @@ using Time = std::chrono::high_resolution_clock;
|
||||
TemplateInferRequest::TemplateInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
|
||||
const InferenceEngine::OutputsDataMap& networkOutputs,
|
||||
const std::shared_ptr<TemplatePlugin::ExecutableNetwork>& executableNetwork) :
|
||||
InferRequestInternal(networkInputs, networkOutputs),
|
||||
IInferRequestInternal(networkInputs, networkOutputs),
|
||||
_executableNetwork(executableNetwork) {
|
||||
// TODO: allocate infer request device and host buffers if needed, fill actual list of profiling tasks
|
||||
|
||||
@ -178,9 +178,9 @@ static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
|
||||
void TemplateInferRequest::inferPreprocess() {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, _profilingTask[Preprocess]);
|
||||
auto start = Time::now();
|
||||
// NOTE: After InferRequestInternal::execDataPreprocessing call
|
||||
// NOTE: After IInferRequestInternal::execDataPreprocessing call
|
||||
// input can points to other memory region than it was allocated in constructor.
|
||||
InferRequestInternal::execDataPreprocessing(_deviceInputs);
|
||||
IInferRequestInternal::execDataPreprocessing(_deviceInputs);
|
||||
for (auto&& networkInput : _deviceInputs) {
|
||||
auto index = _executableNetwork->_inputIndex[networkInput.first];
|
||||
const auto& parameter = _parameters[index];
|
||||
|
@ -11,7 +11,6 @@
|
||||
#include <unordered_map>
|
||||
|
||||
#include <ie_common.h>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
|
||||
#include <threading/ie_itask_executor.hpp>
|
||||
#include <openvino/itt.hpp>
|
||||
@ -27,14 +26,14 @@ namespace TemplatePlugin {
|
||||
class ExecutableNetwork;
|
||||
|
||||
// ! [infer_request:header]
|
||||
class TemplateInferRequest : public InferenceEngine::InferRequestInternal {
|
||||
class TemplateInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
public:
|
||||
typedef std::shared_ptr<TemplateInferRequest> Ptr;
|
||||
|
||||
TemplateInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
|
||||
const InferenceEngine::OutputsDataMap& networkOutputs,
|
||||
const std::shared_ptr<ExecutableNetwork>& executableNetwork);
|
||||
~TemplateInferRequest() override;
|
||||
~TemplateInferRequest();
|
||||
|
||||
void InferImpl() override;
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include <ie_metric_helpers.hpp>
|
||||
#include <ie_plugin_config.hpp>
|
||||
#include <ie_algorithm.hpp>
|
||||
|
||||
#include <hetero/hetero_plugin_config.hpp>
|
||||
#include <threading/ie_executor_manager.hpp>
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
/**
|
||||
* @brief A header file that provides ExecutableNetwork class
|
||||
*
|
||||
*
|
||||
* @file ie_executable_network.hpp
|
||||
*/
|
||||
|
||||
|
@ -18,44 +18,13 @@
|
||||
#include "ie_iinfer_request.hpp"
|
||||
#include "details/ie_so_loader.h"
|
||||
#include "ie_blob.h"
|
||||
#include "ie_iinfer_request.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
namespace details {
|
||||
|
||||
class ICompletionCallbackWrapper {
|
||||
public:
|
||||
virtual ~ICompletionCallbackWrapper() = default;
|
||||
|
||||
virtual void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept = 0;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class CompletionCallbackWrapper : public ICompletionCallbackWrapper {
|
||||
T lambda;
|
||||
|
||||
public:
|
||||
explicit CompletionCallbackWrapper(const T& lambda): lambda(lambda) {}
|
||||
|
||||
void call(InferenceEngine::IInferRequest::Ptr /*request*/, InferenceEngine::StatusCode /*code*/) const
|
||||
noexcept override {
|
||||
lambda();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class CompletionCallbackWrapper<IInferRequest::CompletionCallback> : public ICompletionCallbackWrapper {
|
||||
IInferRequest::CompletionCallback callBack;
|
||||
|
||||
public:
|
||||
explicit CompletionCallbackWrapper(const IInferRequest::CompletionCallback& callBack): callBack(callBack) {}
|
||||
|
||||
void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept override {
|
||||
callBack(request, code);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
class SharedObjectLoader;
|
||||
}
|
||||
class IInferRequestInternal;
|
||||
|
||||
/**
|
||||
* @copybrief IInferRequest
|
||||
@ -63,42 +32,41 @@ public:
|
||||
* Wraps IInferRequest
|
||||
* It can throw exceptions safely for the application, where it is properly handled.
|
||||
*/
|
||||
class InferRequest {
|
||||
IInferRequest::Ptr actual;
|
||||
InferenceEngine::details::SharedObjectLoader::Ptr plg;
|
||||
std::shared_ptr<details::ICompletionCallbackWrapper> callback;
|
||||
class INFERENCE_ENGINE_API_CLASS(InferRequest) {
|
||||
std::shared_ptr<IInferRequestInternal> _impl;
|
||||
std::shared_ptr<details::SharedObjectLoader> _so;
|
||||
|
||||
static void callWrapper(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) {
|
||||
details::ICompletionCallbackWrapper* pWrapper = nullptr;
|
||||
ResponseDesc dsc;
|
||||
request->GetUserData(reinterpret_cast<void**>(&pWrapper), &dsc);
|
||||
pWrapper->call(request, code);
|
||||
}
|
||||
explicit InferRequest(const std::shared_ptr<IInferRequestInternal>& impl,
|
||||
const std::shared_ptr<details::SharedObjectLoader>& so);
|
||||
|
||||
friend class ExecutableNetwork;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @enum WaitMode
|
||||
* @brief Enumeration to hold wait mode for IInferRequest
|
||||
*/
|
||||
enum WaitMode : int64_t {
|
||||
/** Wait until inference result becomes available */
|
||||
RESULT_READY = -1,
|
||||
/** IInferRequest doesn't block or interrupt current thread and immediately returns inference status */
|
||||
STATUS_ONLY = 0,
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A smart pointer to the InferRequest object
|
||||
*/
|
||||
using Ptr = std::shared_ptr<InferRequest>;
|
||||
|
||||
/**
|
||||
* @brief Default constructor
|
||||
*/
|
||||
InferRequest() = default;
|
||||
|
||||
/**
|
||||
* constructs InferRequest from the initialized shared_pointer
|
||||
* @param request Initialized shared pointer to IInferRequest interface
|
||||
* @param splg Plugin to use. This is required to ensure that InferRequest can work properly even if plugin object is destroyed.
|
||||
*/
|
||||
explicit InferRequest(IInferRequest::Ptr request,
|
||||
InferenceEngine::details::SharedObjectLoader::Ptr splg = {}):
|
||||
actual(request), plg(splg) {
|
||||
// plg can be null, but not the actual
|
||||
if (actual == nullptr) IE_THROW() << "InferRequest was not initialized.";
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
*/
|
||||
~InferRequest() {
|
||||
actual = nullptr;
|
||||
}
|
||||
~InferRequest();
|
||||
|
||||
/**
|
||||
* @brief Sets input/output data to infer
|
||||
@ -108,27 +76,16 @@ public:
|
||||
* @param data Reference to input or output blob. The type of a blob must match the network input precision and
|
||||
* size.
|
||||
*/
|
||||
void SetBlob(const std::string& name, const Blob::Ptr& data) {
|
||||
CALL_STATUS_FNC(SetBlob, name.c_str(), data);
|
||||
}
|
||||
void SetBlob(const std::string& name, const Blob::Ptr& data);
|
||||
|
||||
/**
|
||||
* @copybrief IInferRequest::GetBlob
|
||||
* @brief Gets input/output data for inference
|
||||
*
|
||||
* Wraps IInferRequest::GetBlob
|
||||
* @note Memory allocation does not happen
|
||||
* @param name A name of Blob to get
|
||||
* @return A shared pointer to a Blob with a name @p name. If a blob is not found, an exception is thrown.
|
||||
*/
|
||||
Blob::Ptr GetBlob(const std::string& name) {
|
||||
Blob::Ptr data;
|
||||
CALL_STATUS_FNC(GetBlob, name.c_str(), data);
|
||||
std::string error = "Internal error: blob with name `" + name + "` is not allocated!";
|
||||
auto blobPtr = data.get();
|
||||
const bool remoteBlobPassed = blobPtr->is<RemoteBlob>();
|
||||
if (blobPtr == nullptr) IE_THROW() << error;
|
||||
if (!remoteBlobPassed && blobPtr->buffer() == nullptr) IE_THROW() << error;
|
||||
return data;
|
||||
}
|
||||
Blob::Ptr GetBlob(const std::string& name);
|
||||
|
||||
/**
|
||||
* @brief Sets blob with a pre-process information
|
||||
@ -137,51 +94,37 @@ public:
|
||||
* @param data A reference to input. The type of Blob must correspond to the network input precision and size.
|
||||
* @param info Preprocess info for blob.
|
||||
*/
|
||||
void SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo& info) {
|
||||
CALL_STATUS_FNC(SetBlob, name.c_str(), data, info);
|
||||
}
|
||||
void SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo& info);
|
||||
|
||||
/**
|
||||
* @brief Gets pre-process for input data
|
||||
* @param name Name of input blob.
|
||||
* @return pointer to pre-process info of blob with name
|
||||
*/
|
||||
const PreProcessInfo& GetPreProcess(const std::string& name) const {
|
||||
const PreProcessInfo* info = nullptr;
|
||||
CALL_STATUS_FNC(GetPreProcess, name.c_str(), &info);
|
||||
return *info;
|
||||
}
|
||||
const PreProcessInfo& GetPreProcess(const std::string& name) const;
|
||||
|
||||
/**
|
||||
* @copybrief IInferRequest::Infer
|
||||
* @brief Infers specified input(s) in synchronous mode
|
||||
*
|
||||
* @note blocks all methods of InferRequest while request is ongoing (running or waiting in queue)
|
||||
*
|
||||
* Wraps IInferRequest::Infer
|
||||
*/
|
||||
void Infer() {
|
||||
CALL_STATUS_FNC_NO_ARGS(Infer);
|
||||
}
|
||||
void Infer();
|
||||
|
||||
/**
|
||||
* @copybrief IInferRequest::Cancel
|
||||
*
|
||||
* Wraps IInferRequest::Cancel
|
||||
* @brief Cancel inference request
|
||||
* @param name Name of input blob.
|
||||
* @return pointer to pre-process info of blob with name
|
||||
*/
|
||||
void Cancel() {
|
||||
CALL_STATUS_FNC_NO_ARGS(Cancel);
|
||||
}
|
||||
void Cancel();
|
||||
|
||||
/**
|
||||
* @copybrief IInferRequest::GetPerformanceCounts
|
||||
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer
|
||||
*
|
||||
* Wraps IInferRequest::GetPerformanceCounts
|
||||
* @note not all plugins provide meaningful data
|
||||
* @return Map of layer names to profiling information for that layer
|
||||
*/
|
||||
std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const {
|
||||
std::map<std::string, InferenceEngineProfileInfo> perfMap;
|
||||
CALL_STATUS_FNC(GetPerformanceCounts, perfMap);
|
||||
return perfMap;
|
||||
}
|
||||
std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const;
|
||||
|
||||
/**
|
||||
* @brief Sets input data to infer
|
||||
@ -190,11 +133,7 @@ public:
|
||||
* @param inputs A reference to a map of input blobs accessed by input names.
|
||||
* The type of Blob must correspond to the network input precision and size.
|
||||
*/
|
||||
void SetInput(const BlobMap& inputs) {
|
||||
for (auto&& input : inputs) {
|
||||
CALL_STATUS_FNC(SetBlob, input.first.c_str(), input.second);
|
||||
}
|
||||
}
|
||||
void SetInput(const BlobMap& inputs);
|
||||
|
||||
/**
|
||||
* @brief Sets data that will contain result of the inference
|
||||
@ -203,34 +142,27 @@ public:
|
||||
* @param results - a reference to a map of result blobs accessed by output names.
|
||||
* The type of Blob must correspond to the network output precision and size.
|
||||
*/
|
||||
void SetOutput(const BlobMap& results) {
|
||||
for (auto&& result : results) {
|
||||
CALL_STATUS_FNC(SetBlob, result.first.c_str(), result.second);
|
||||
}
|
||||
}
|
||||
void SetOutput(const BlobMap& results);
|
||||
|
||||
/**
|
||||
* @brief Sets new batch size when dynamic batching is enabled in executable network that created this request.
|
||||
*
|
||||
* @param batch new batch size to be used by all the following inference calls for this request.
|
||||
*/
|
||||
void SetBatch(const int batch) {
|
||||
CALL_STATUS_FNC(SetBatch, batch);
|
||||
}
|
||||
void SetBatch(const int batch);
|
||||
|
||||
/**
|
||||
* @brief Start inference of specified input(s) in asynchronous mode
|
||||
*
|
||||
* @note It returns immediately. Inference starts also immediately.
|
||||
*/
|
||||
void StartAsync() {
|
||||
CALL_STATUS_FNC_NO_ARGS(StartAsync);
|
||||
}
|
||||
void StartAsync();
|
||||
|
||||
/**
|
||||
* @copybrief IInferRequest::Wait
|
||||
* @brief Waits for the result to become available. Blocks until specified millis_timeout has elapsed or the result
|
||||
* becomes available, whichever comes first.
|
||||
*
|
||||
*
|
||||
* Wraps IInferRequest::Wait
|
||||
* @param millis_timeout Maximum duration in milliseconds to block for
|
||||
* @note There are special cases when millis_timeout is equal some value of the WaitMode enum:
|
||||
* * STATUS_ONLY - immediately returns inference status (IInferRequest::RequestStatus). It does not block or
|
||||
@ -238,105 +170,69 @@ public:
|
||||
* * RESULT_READY - waits until inference result becomes available
|
||||
* @return A status code of operation
|
||||
*/
|
||||
StatusCode Wait(int64_t millis_timeout) {
|
||||
ResponseDesc resp;
|
||||
if (actual == nullptr) IE_THROW() << "InferRequest was not initialized.";
|
||||
auto res = actual->Wait(millis_timeout, &resp);
|
||||
if (res != OK && res != RESULT_NOT_READY &&
|
||||
res != INFER_NOT_STARTED && res != INFER_CANCELLED) {
|
||||
IE_EXCEPTION_SWITCH(res, ExceptionType,
|
||||
InferenceEngine::details::ThrowNow<ExceptionType>{}
|
||||
<<= std::stringstream{} << IE_LOCATION << resp.msg)
|
||||
}
|
||||
return res;
|
||||
}
|
||||
StatusCode Wait(int64_t millis_timeout = RESULT_READY);
|
||||
|
||||
private:
|
||||
void SetCompletionCallbackImpl(std::function<void()>);
|
||||
void SetCompletionCallbackImpl(std::function<void(InferRequest, StatusCode)>);
|
||||
void SetCompletionCallbackImpl(IInferRequest::CompletionCallback);
|
||||
template<typename T>
|
||||
struct SetCallback {
|
||||
void operator()(std::function<void()> f) {_this.SetCompletionCallbackImpl(std::move(f));}
|
||||
InferRequest& _this;
|
||||
};
|
||||
|
||||
public:
|
||||
/**
|
||||
* @copybrief IInferRequest::SetCompletionCallback
|
||||
* @brief Sets a callback function that will be called on success or failure of asynchronous request
|
||||
*
|
||||
* Wraps IInferRequest::SetCompletionCallback
|
||||
*
|
||||
* @param callbackToSet Lambda callback object which will be called on processing finish.
|
||||
* @param callbackToSet callback object which will be called on when inference finish.
|
||||
*/
|
||||
template <class T>
|
||||
void SetCompletionCallback(const T& callbackToSet) {
|
||||
callback.reset(new details::CompletionCallbackWrapper<T>(callbackToSet));
|
||||
CALL_STATUS_FNC(SetUserData, callback.get());
|
||||
actual->SetCompletionCallback(callWrapper);
|
||||
template<typename F>
|
||||
void SetCompletionCallback(F callbackToSet) {
|
||||
return SetCallback<F>{*this}(std::move(callbackToSet));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @copybrief IExecutableNetwork::QueryState
|
||||
* @brief Gets state control interface for given infer request.
|
||||
*
|
||||
* Wraps IExecutableNetwork::QueryState
|
||||
* State control essential for recurrent networks
|
||||
* @return A vector of Memory State objects
|
||||
*/
|
||||
std::vector<VariableState> QueryState() {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
if (actual == nullptr) IE_THROW() << "ExecutableNetwork was not initialized.";
|
||||
IVariableState::Ptr pState = nullptr;
|
||||
auto res = OK;
|
||||
std::vector<VariableState> controller;
|
||||
for (size_t idx = 0; res == OK; ++idx) {
|
||||
ResponseDesc resp;
|
||||
res = actual->QueryState(pState, idx, &resp);
|
||||
if (res != OK && res != OUT_OF_BOUNDS) {
|
||||
IE_THROW() << resp.msg;
|
||||
}
|
||||
if (res != OUT_OF_BOUNDS) {
|
||||
controller.push_back(VariableState(pState, plg));
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
return controller;
|
||||
}
|
||||
std::vector<VariableState> QueryState();
|
||||
|
||||
/**
|
||||
* @brief IInferRequest pointer to be used directly in CreateInferRequest functions
|
||||
* @return A shared pointer to underlying IInferRequest interface
|
||||
* @return A shared pointer to IInferRequest interface
|
||||
*/
|
||||
operator IInferRequest::Ptr&() {
|
||||
if (actual == nullptr) IE_THROW() << "InferRequest was not initialized.";
|
||||
return actual;
|
||||
}
|
||||
INFERENCE_ENGINE_DEPRECATED("Will be removed")
|
||||
operator std::shared_ptr<IInferRequest> ();
|
||||
|
||||
/**
|
||||
* @brief Checks if current InferRequest object is not initialized
|
||||
* @return true if current InferRequest object is not initialized, false - otherwise
|
||||
*/
|
||||
bool operator!() const noexcept {
|
||||
return !actual;
|
||||
}
|
||||
bool operator!() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Checks if current InferRequest object is initialized
|
||||
* @return true if current InferRequest object is initialized, false - otherwise
|
||||
*/
|
||||
explicit operator bool() const noexcept {
|
||||
return !!actual;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief A smart pointer to the InferRequest object
|
||||
*/
|
||||
using Ptr = std::shared_ptr<InferRequest>;
|
||||
explicit operator bool() const noexcept;
|
||||
};
|
||||
|
||||
namespace details {
|
||||
|
||||
template <>
|
||||
class CompletionCallbackWrapper<std::function<void(InferRequest, StatusCode)>> : public ICompletionCallbackWrapper {
|
||||
std::function<void(InferRequest, StatusCode)> lambda;
|
||||
|
||||
public:
|
||||
explicit CompletionCallbackWrapper(const std::function<void(InferRequest, InferenceEngine::StatusCode)>& lambda)
|
||||
: lambda(lambda) {}
|
||||
|
||||
void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept override {
|
||||
lambda(InferRequest(request), code);
|
||||
template<>
|
||||
struct InferRequest::SetCallback<std::function<void(InferRequest, StatusCode)>> {
|
||||
void operator()(std::function<void(InferRequest, StatusCode)> f) {
|
||||
_this.SetCompletionCallbackImpl(std::move(f));
|
||||
}
|
||||
InferRequest& _this;
|
||||
};
|
||||
template<>
|
||||
struct InferRequest::SetCallback<IInferRequest::CompletionCallback> {
|
||||
void operator()(IInferRequest::CompletionCallback f) {
|
||||
_this.SetCompletionCallbackImpl(std::move(f));
|
||||
}
|
||||
InferRequest& _this;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace InferenceEngine
|
||||
|
@ -25,7 +25,8 @@ namespace InferenceEngine {
|
||||
* @brief This is an interface of asynchronous infer request
|
||||
*
|
||||
*/
|
||||
class IInferRequest : public std::enable_shared_from_this<IInferRequest> {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
class INFERENCE_ENGINE_DEPRECATED("Do not use IInferRequest API") IInferRequest : public std::enable_shared_from_this<IInferRequest> {
|
||||
public:
|
||||
/**
|
||||
* @enum WaitMode
|
||||
@ -201,5 +202,6 @@ public:
|
||||
protected:
|
||||
~IInferRequest() = default;
|
||||
};
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
} // namespace InferenceEngine
|
@ -5,7 +5,7 @@
|
||||
#include "cldnn_async_infer_request.h"
|
||||
#include <memory>
|
||||
|
||||
CLDNNPlugin::CLDNNAsyncInferRequest::CLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest,
|
||||
CLDNNPlugin::CLDNNAsyncInferRequest::CLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr &inferRequest,
|
||||
const InferenceEngine::ITaskExecutor::Ptr &taskExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor)
|
||||
: InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor)
|
||||
|
@ -13,13 +13,13 @@ namespace CLDNNPlugin {
|
||||
|
||||
class CLDNNAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
|
||||
public:
|
||||
CLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest,
|
||||
CLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr &inferRequest,
|
||||
const InferenceEngine::ITaskExecutor::Ptr &taskExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor);
|
||||
|
||||
void Infer_ThreadUnsafe() override;
|
||||
|
||||
~CLDNNAsyncInferRequest() override;
|
||||
~CLDNNAsyncInferRequest();
|
||||
};
|
||||
|
||||
} // namespace CLDNNPlugin
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ie_ngraph_utils.hpp>
|
||||
#include <ie_algorithm.hpp>
|
||||
|
||||
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
|
||||
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
||||
|
@ -26,6 +26,7 @@
|
||||
|
||||
#include "cldnn_executable_network.h"
|
||||
#include "threading/ie_cpu_streams_executor.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
||||
|
||||
|
||||
using namespace InferenceEngine;
|
||||
@ -62,8 +63,8 @@ CLDNNExecNetwork::CLDNNExecNetwork(InferenceEngine::CNNNetwork &network, RemoteC
|
||||
}
|
||||
}
|
||||
|
||||
InferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequestImpl(InputsDataMap networkInputs,
|
||||
OutputsDataMap networkOutputs) {
|
||||
IInferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequestImpl(InputsDataMap networkInputs,
|
||||
OutputsDataMap networkOutputs) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::CLDNNPlugin, "CLDNNExecNetwork::CreateInferRequestImpl");
|
||||
if (m_graphs.empty()) {
|
||||
IE_THROW(NetworkNotLoaded);
|
||||
@ -91,7 +92,7 @@ InferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequestImpl(InputsDataMap
|
||||
return ptr;
|
||||
}
|
||||
|
||||
IInferRequest::Ptr CLDNNExecNetwork::CreateInferRequest() {
|
||||
IInferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequest() {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::CLDNNPlugin, "CLDNNExecNetwork::CreateInferRequest");
|
||||
return CreateAsyncInferRequestFromSync<CLDNNAsyncInferRequest>();
|
||||
}
|
||||
|
@ -26,9 +26,9 @@ public:
|
||||
CLDNNExecNetwork(InferenceEngine::CNNNetwork &network, InferenceEngine::RemoteContext::Ptr context, Config config);
|
||||
|
||||
InferenceEngine::CNNNetwork GetExecGraphInfo() override;
|
||||
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override;
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override;
|
||||
|
||||
InferenceEngine::Parameter GetMetric(const std::string &name) const override;
|
||||
InferenceEngine::Parameter GetConfig(const std::string &name) const override;
|
||||
|
@ -12,6 +12,8 @@
|
||||
#include "cldnn_remote_context.h"
|
||||
#include "cldnn_executable_network.h"
|
||||
#include "cldnn_itt.h"
|
||||
#include <ie_algorithm.hpp>
|
||||
#include <debug.h>
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
@ -812,7 +814,7 @@ void CLDNNInferRequest::SetBatch(int new_batch) {
|
||||
|
||||
CLDNNInferRequest::CLDNNInferRequest(InputsDataMap networkInputs, OutputsDataMap networkOutputs,
|
||||
const CLDNNExecNetwork::Ptr& execNetwork)
|
||||
: InferRequestInternal(networkInputs, networkOutputs)
|
||||
: IInferRequestInternal(networkInputs, networkOutputs)
|
||||
, m_useProfiling(false)
|
||||
, m_useStreams(false) {
|
||||
IE_ASSERT(nullptr != execNetwork);
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include "cldnn_graph.h"
|
||||
#include <threading/ie_istreams_executor.hpp>
|
||||
|
||||
@ -22,7 +21,7 @@ struct buf_info {
|
||||
|
||||
class CLDNNExecNetwork;
|
||||
|
||||
class CLDNNInferRequest : public InferenceEngine::InferRequestInternal {
|
||||
class CLDNNInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
public:
|
||||
// make sure all blobs and cldnn::memory objects
|
||||
// are in place and valid
|
||||
|
@ -8,16 +8,15 @@
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_executable_network_thread_safe_default.hpp>
|
||||
#include "gna_infer_request.hpp"
|
||||
#include "gna_plugin.hpp"
|
||||
#include <gna/gna_config.hpp>
|
||||
#include <threading/ie_executor_manager.hpp>
|
||||
#include <cpp_interfaces/impl/ie_executable_network_thread_safe_async_only.hpp>
|
||||
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
|
||||
|
||||
namespace GNAPluginNS {
|
||||
|
||||
class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafeAsyncOnly {
|
||||
class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkInternal {
|
||||
std::shared_ptr<GNAPlugin> plg;
|
||||
|
||||
public:
|
||||
@ -53,9 +52,9 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
|
||||
: GNAExecutableNetwork(network, std::make_shared<GNAPlugin>(config)) {
|
||||
}
|
||||
|
||||
InferenceEngine::AsyncInferRequestInternal::Ptr
|
||||
CreateAsyncInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override {
|
||||
InferenceEngine::IInferRequestInternal::Ptr
|
||||
CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override {
|
||||
return std::make_shared<GNAInferRequest>(plg, networkInputs, networkOutputs);
|
||||
}
|
||||
|
||||
|
@ -8,13 +8,12 @@
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "cpp_interfaces/impl/ie_infer_async_request_internal.hpp"
|
||||
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
||||
#include "gna_plugin.hpp"
|
||||
|
||||
namespace GNAPluginNS {
|
||||
|
||||
class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
|
||||
class GNAInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
protected:
|
||||
std::shared_ptr<GNAPlugin> plg;
|
||||
uint32_t inferRequestIdx = -1;
|
||||
@ -23,7 +22,7 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
|
||||
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
|
||||
InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs)
|
||||
: InferenceEngine::AsyncInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
|
||||
: InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
|
||||
// TODO: internal connection API - better to generalize
|
||||
if (networkOutputs.empty()) {
|
||||
THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
|
||||
@ -78,10 +77,19 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
|
||||
inferRequestIdx = plg->QueueInference(_inputs, _outputs);
|
||||
// workaround to unblock callback-based flows
|
||||
if (_callback) {
|
||||
auto infer_request = _publicInterface.lock();
|
||||
IE_ASSERT(infer_request != nullptr);
|
||||
auto res = Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||
_callback(infer_request, res);
|
||||
std::exception_ptr exceptionPtr;
|
||||
if (res != InferenceEngine::StatusCode::OK) {
|
||||
try {
|
||||
IE_EXCEPTION_SWITCH(res, ExceptionType,
|
||||
InferenceEngine::details::ThrowNow<ExceptionType>{}
|
||||
<<= std::stringstream{} << IE_LOCATION
|
||||
<< InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
|
||||
} catch (...) {
|
||||
exceptionPtr = std::current_exception();
|
||||
}
|
||||
}
|
||||
_callback(exceptionPtr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,9 +9,9 @@
|
||||
using namespace HeteroPlugin;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
HeteroAsyncInferRequest::HeteroAsyncInferRequest(const InferRequestInternal::Ptr& request,
|
||||
const ITaskExecutor::Ptr& taskExecutor,
|
||||
const ITaskExecutor::Ptr& callbackExecutor) :
|
||||
HeteroAsyncInferRequest::HeteroAsyncInferRequest(const IInferRequestInternal::Ptr& request,
|
||||
const ITaskExecutor::Ptr& taskExecutor,
|
||||
const ITaskExecutor::Ptr& callbackExecutor) :
|
||||
AsyncInferRequestThreadSafeDefault(request, taskExecutor, callbackExecutor),
|
||||
_heteroInferRequest(std::static_pointer_cast<HeteroInferRequest>(request)),
|
||||
_statusCodes{_heteroInferRequest->_inferRequests.size(), StatusCode::OK} {
|
||||
|
@ -19,10 +19,10 @@ namespace HeteroPlugin {
|
||||
class HeteroAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<HeteroAsyncInferRequest>;
|
||||
HeteroAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr& request,
|
||||
HeteroAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr& request,
|
||||
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
|
||||
~HeteroAsyncInferRequest() override;
|
||||
~HeteroAsyncInferRequest();
|
||||
void StartAsync_ThreadUnsafe() override;
|
||||
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
|
||||
|
||||
|
@ -24,9 +24,11 @@
|
||||
#include "transformations/serialize.hpp"
|
||||
#include "ie_ngraph_utils.hpp"
|
||||
#include "ie_plugin_config.hpp"
|
||||
#include "ie_algorithm.hpp"
|
||||
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
|
||||
#include "hetero/hetero_plugin_config.hpp"
|
||||
#include "hetero_plugin.hpp"
|
||||
#include <ie_algorithm.hpp>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
@ -638,7 +640,7 @@ void HeteroExecutableNetwork::ExportImpl(std::ostream& heteroModel) {
|
||||
}
|
||||
}
|
||||
|
||||
InferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequestImpl(
|
||||
IInferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequestImpl(
|
||||
InputsDataMap networkInputs,
|
||||
OutputsDataMap networkOutputs) {
|
||||
HeteroInferRequest::SubRequestsList inferRequests;
|
||||
@ -655,7 +657,7 @@ InferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequestImpl(
|
||||
_blobNameMap);
|
||||
}
|
||||
|
||||
IInferRequest::Ptr HeteroExecutableNetwork::CreateInferRequest() {
|
||||
IInferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequest() {
|
||||
return CreateAsyncInferRequestFromSync<HeteroAsyncInferRequest>();
|
||||
}
|
||||
|
||||
|
@ -49,10 +49,10 @@ public:
|
||||
|
||||
~HeteroExecutableNetwork() override = default;
|
||||
|
||||
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override;
|
||||
|
||||
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
|
||||
|
||||
InferenceEngine::Parameter GetConfig(const std::string &name) const override;
|
||||
|
||||
|
@ -20,7 +20,7 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
|
||||
InferenceEngine::OutputsDataMap networkOutputs,
|
||||
const SubRequestsList& inferRequests,
|
||||
const std::unordered_map<std::string, std::string>& subgraphInputToOutputBlobNames) :
|
||||
InferRequestInternal(networkInputs, networkOutputs),
|
||||
IInferRequestInternal(networkInputs, networkOutputs),
|
||||
_inferRequests(inferRequests) {
|
||||
if (_networkOutputs.empty() || _networkInputs.empty()) {
|
||||
IE_THROW() << "Internal error: no information about network's output/input";
|
||||
@ -65,7 +65,7 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
|
||||
}
|
||||
|
||||
void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) {
|
||||
InferenceEngine::InferRequestInternal::SetBlob(name, data);
|
||||
InferenceEngine::IInferRequestInternal::SetBlob(name, data);
|
||||
assert(!_inferRequests.empty());
|
||||
for (auto &&desc : _inferRequests) {
|
||||
auto &r = desc._request;
|
||||
|
@ -15,14 +15,15 @@
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <ie_common.h>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
|
||||
#include <cpp/ie_infer_request.hpp>
|
||||
#include <cpp/ie_executable_network.hpp>
|
||||
#include <openvino/itt.hpp>
|
||||
|
||||
namespace HeteroPlugin {
|
||||
|
||||
class HeteroInferRequest : public InferenceEngine::InferRequestInternal {
|
||||
class HeteroInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
public:
|
||||
typedef std::shared_ptr<HeteroInferRequest> Ptr;
|
||||
|
||||
|
@ -9,6 +9,7 @@ file (GLOB LIBRARY_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpp/*.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threading/*.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpp/*.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cpp_interfaces/interface/*.cpp
|
||||
)
|
||||
|
||||
# TODO: WA for OneHot pass usage in reshape
|
||||
|
@ -54,7 +54,7 @@ InferRequest ExecutableNetwork::CreateInferRequest() {
|
||||
}
|
||||
|
||||
InferRequest::Ptr ExecutableNetwork::CreateInferRequestPtr() {
|
||||
CALL_STATEMENT(return std::make_shared<InferRequest>(_impl->CreateInferRequest(), _so));
|
||||
CALL_STATEMENT(return std::make_shared<InferRequest>(InferRequest{_impl->CreateInferRequest(), _so}));
|
||||
}
|
||||
|
||||
void ExecutableNetwork::Export(const std::string& modelFileName) {
|
||||
|
209
inference-engine/src/inference_engine/cpp/ie_infer_request.cpp
Normal file
209
inference-engine/src/inference_engine/cpp/ie_infer_request.cpp
Normal file
@ -0,0 +1,209 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "cpp/ie_infer_request.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
||||
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
|
||||
#include "ie_remote_context.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
#define CATCH_IE_EXCEPTION(ExceptionType) catch (const InferenceEngine::ExceptionType& e) {throw e;}
|
||||
|
||||
#define CATCH_IE_EXCEPTIONS \
|
||||
CATCH_IE_EXCEPTION(GeneralError) \
|
||||
CATCH_IE_EXCEPTION(NotImplemented) \
|
||||
CATCH_IE_EXCEPTION(NetworkNotLoaded) \
|
||||
CATCH_IE_EXCEPTION(ParameterMismatch) \
|
||||
CATCH_IE_EXCEPTION(NotFound) \
|
||||
CATCH_IE_EXCEPTION(OutOfBounds) \
|
||||
CATCH_IE_EXCEPTION(Unexpected) \
|
||||
CATCH_IE_EXCEPTION(RequestBusy) \
|
||||
CATCH_IE_EXCEPTION(ResultNotReady) \
|
||||
CATCH_IE_EXCEPTION(NotAllocated) \
|
||||
CATCH_IE_EXCEPTION(InferNotStarted) \
|
||||
CATCH_IE_EXCEPTION(NetworkNotRead) \
|
||||
CATCH_IE_EXCEPTION(InferCancelled)
|
||||
|
||||
#define CALL_STATEMENT(...) \
|
||||
if (_impl == nullptr) IE_THROW() << "Inference Requst is not initialized"; \
|
||||
try { \
|
||||
__VA_ARGS__ \
|
||||
} CATCH_IE_EXCEPTIONS catch (const std::exception& ex) { \
|
||||
IE_THROW() << ex.what(); \
|
||||
} catch (...) { \
|
||||
IE_THROW(Unexpected); \
|
||||
}
|
||||
|
||||
InferRequest::InferRequest(const std::shared_ptr<IInferRequestInternal>& impl,
|
||||
const std::shared_ptr<details::SharedObjectLoader>& so) :
|
||||
_impl{impl},
|
||||
_so{so} {
|
||||
if (_impl == nullptr) IE_THROW() << "Inference Requst is not initialized";
|
||||
}
|
||||
|
||||
InferRequest::~InferRequest() {
|
||||
_impl = {};
|
||||
}
|
||||
|
||||
void InferRequest::SetBlob(const std::string& name, const Blob::Ptr& data) {
|
||||
CALL_STATEMENT(_impl->SetBlob(name, data);)
|
||||
}
|
||||
|
||||
Blob::Ptr InferRequest::GetBlob(const std::string& name) {
|
||||
Blob::Ptr blobPtr;
|
||||
CALL_STATEMENT(blobPtr = _impl->GetBlob(name);)
|
||||
std::string error = "Internal error: blob with name `" + name + "` is not allocated!";
|
||||
const bool remoteBlobPassed = blobPtr->is<RemoteBlob>();
|
||||
if (blobPtr == nullptr) IE_THROW() << error;
|
||||
if (!remoteBlobPassed && blobPtr->buffer() == nullptr) IE_THROW() << error;
|
||||
return blobPtr;
|
||||
}
|
||||
|
||||
void InferRequest::SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo& info) {
|
||||
CALL_STATEMENT(_impl->SetBlob(name, data, info);)
|
||||
}
|
||||
|
||||
const PreProcessInfo& InferRequest::GetPreProcess(const std::string& name) const {
|
||||
CALL_STATEMENT(return _impl->GetPreProcess(name);)
|
||||
}
|
||||
|
||||
void InferRequest::Infer() {
|
||||
CALL_STATEMENT(_impl->Infer();)
|
||||
}
|
||||
|
||||
void InferRequest::Cancel() {
|
||||
CALL_STATEMENT(_impl->Cancel();)
|
||||
}
|
||||
|
||||
std::map<std::string, InferenceEngineProfileInfo> InferRequest::GetPerformanceCounts() const {
|
||||
CALL_STATEMENT(return _impl->GetPerformanceCounts();)
|
||||
}
|
||||
|
||||
void InferRequest::SetInput(const BlobMap& inputs) {
|
||||
CALL_STATEMENT(
|
||||
for (auto&& input : inputs) {
|
||||
_impl->SetBlob(input.first, input.second);
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
void InferRequest::SetOutput(const BlobMap& results) {
|
||||
CALL_STATEMENT(
|
||||
for (auto&& result : results) {
|
||||
_impl->SetBlob(result.first, result.second);
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
void InferRequest::SetBatch(const int batch) {
|
||||
CALL_STATEMENT(_impl->SetBatch(batch);)
|
||||
}
|
||||
|
||||
void InferRequest::StartAsync() {
|
||||
CALL_STATEMENT(_impl->StartAsync();)
|
||||
}
|
||||
|
||||
|
||||
StatusCode InferRequest::Wait(int64_t millis_timeout) {
|
||||
CALL_STATEMENT(return _impl->Wait(millis_timeout);)
|
||||
}
|
||||
|
||||
void InferRequest::SetCompletionCallbackImpl(std::function<void()> callback) {
|
||||
CALL_STATEMENT(
|
||||
_impl->SetCallback([callback] (std::exception_ptr) {
|
||||
callback();
|
||||
});
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
#define CATCH_IE_EXCEPTION_RETURN(StatusCode, ExceptionType) catch (const ExceptionType&) {return StatusCode;}
|
||||
|
||||
#define CATCH_IE_EXCEPTIONS_RETURN \
|
||||
CATCH_IE_EXCEPTION_RETURN(GENERAL_ERROR, GeneralError) \
|
||||
CATCH_IE_EXCEPTION_RETURN(NOT_IMPLEMENTED, NotImplemented) \
|
||||
CATCH_IE_EXCEPTION_RETURN(NETWORK_NOT_LOADED, NetworkNotLoaded) \
|
||||
CATCH_IE_EXCEPTION_RETURN(PARAMETER_MISMATCH, ParameterMismatch) \
|
||||
CATCH_IE_EXCEPTION_RETURN(NOT_FOUND, NotFound) \
|
||||
CATCH_IE_EXCEPTION_RETURN(OUT_OF_BOUNDS, OutOfBounds) \
|
||||
CATCH_IE_EXCEPTION_RETURN(UNEXPECTED, Unexpected) \
|
||||
CATCH_IE_EXCEPTION_RETURN(REQUEST_BUSY, RequestBusy) \
|
||||
CATCH_IE_EXCEPTION_RETURN(RESULT_NOT_READY, ResultNotReady) \
|
||||
CATCH_IE_EXCEPTION_RETURN(NOT_ALLOCATED, NotAllocated) \
|
||||
CATCH_IE_EXCEPTION_RETURN(INFER_NOT_STARTED, InferNotStarted) \
|
||||
CATCH_IE_EXCEPTION_RETURN(NETWORK_NOT_READ, NetworkNotRead) \
|
||||
CATCH_IE_EXCEPTION_RETURN(INFER_CANCELLED, InferCancelled)
|
||||
|
||||
|
||||
void InferRequest::SetCompletionCallbackImpl(std::function<void(InferRequest, StatusCode)> callback) {
|
||||
CALL_STATEMENT(
|
||||
auto weakThis = InferRequest{std::shared_ptr<IInferRequestInternal>{_impl.get(), [](IInferRequestInternal*){}}, _so};
|
||||
_impl->SetCallback([callback, weakThis] (std::exception_ptr exceptionPtr) {
|
||||
StatusCode statusCode = StatusCode::OK;
|
||||
if (exceptionPtr != nullptr) {
|
||||
statusCode = [&] {
|
||||
try {
|
||||
std::rethrow_exception(exceptionPtr);
|
||||
} CATCH_IE_EXCEPTIONS_RETURN catch (const std::exception& ex) {
|
||||
return GENERAL_ERROR;
|
||||
} catch (...) {
|
||||
return UNEXPECTED;
|
||||
}
|
||||
} ();
|
||||
}
|
||||
callback(weakThis, statusCode);
|
||||
});
|
||||
)
|
||||
}
|
||||
|
||||
void InferRequest::SetCompletionCallbackImpl(IInferRequest::CompletionCallback callback) {
|
||||
CALL_STATEMENT(
|
||||
IInferRequest::Ptr weakThis = InferRequest{std::shared_ptr<IInferRequestInternal>{_impl.get(), [](IInferRequestInternal*){}}, _so};
|
||||
_impl->SetCallback([callback, weakThis] (std::exception_ptr exceptionPtr) {
|
||||
StatusCode statusCode = StatusCode::OK;
|
||||
if (exceptionPtr != nullptr) {
|
||||
statusCode = [&] {
|
||||
try {
|
||||
std::rethrow_exception(exceptionPtr);
|
||||
} CATCH_IE_EXCEPTIONS_RETURN catch (const std::exception& ex) {
|
||||
return GENERAL_ERROR;
|
||||
} catch (...) {
|
||||
return UNEXPECTED;
|
||||
}
|
||||
} ();
|
||||
}
|
||||
callback(weakThis, statusCode);
|
||||
});
|
||||
)
|
||||
}
|
||||
|
||||
std::vector<VariableState> InferRequest::QueryState() {
|
||||
std::vector<VariableState> controller;
|
||||
CALL_STATEMENT(
|
||||
for (auto&& state : _impl->QueryState()) {
|
||||
controller.emplace_back(std::make_shared<VariableStateBase>(state), _so);
|
||||
}
|
||||
)
|
||||
return controller;
|
||||
}
|
||||
|
||||
InferRequest::operator IInferRequest::Ptr () {
|
||||
CALL_STATEMENT(
|
||||
return std::make_shared<InferRequestBase>(_impl);
|
||||
)
|
||||
}
|
||||
|
||||
bool InferRequest::operator!() const noexcept {
|
||||
return !_impl;
|
||||
}
|
||||
|
||||
InferRequest::operator bool() const noexcept {
|
||||
return !!_impl;
|
||||
}
|
||||
} // namespace InferenceEngine
|
@ -0,0 +1,325 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <ie_blob.h>
|
||||
#include <ie_common.h>
|
||||
#include <ie_preprocess.hpp>
|
||||
#include <ie_compound_blob.h>
|
||||
#include <ie_algorithm.hpp>
|
||||
#include <debug.h>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iplugin_internal.hpp>
|
||||
#include <cpp_interfaces/plugin_itt.hpp>
|
||||
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
IInferRequestInternal::~IInferRequestInternal() {}
|
||||
|
||||
IInferRequestInternal::IInferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs) {
|
||||
// // We should copy maps since they can be overriden in SetBlob with preprocess
|
||||
copyInputOutputInfo(networkInputs, networkOutputs, _networkInputs, _networkOutputs);
|
||||
}
|
||||
|
||||
void IInferRequestInternal::Infer() {
|
||||
checkBlobs();
|
||||
InferImpl();
|
||||
}
|
||||
|
||||
void IInferRequestInternal::InferImpl() {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
void IInferRequestInternal::Cancel() {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
std::map<std::string, InferenceEngineProfileInfo> IInferRequestInternal::GetPerformanceCounts() const {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
void IInferRequestInternal::SetBlob(const std::string& name, const Blob::Ptr& userBlob) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "SetBlob");
|
||||
if (name.empty()) {
|
||||
IE_THROW(NotFound) << "Failed to set blob with empty name";
|
||||
}
|
||||
if (!userBlob) IE_THROW(NotAllocated) << "Failed to set empty blob with name: \'" << name << "\'";
|
||||
const bool compoundBlobPassed = userBlob->is<CompoundBlob>();
|
||||
const bool remoteBlobPassed = userBlob->is<RemoteBlob>();
|
||||
if (!compoundBlobPassed && !remoteBlobPassed && userBlob->buffer() == nullptr)
|
||||
IE_THROW(NotAllocated) << "Input data was not allocated. Input name: \'" << name << "\'";
|
||||
if (userBlob->size() == 0) {
|
||||
IE_THROW() << "Input data is empty. Input name: \'" << name << "\'";
|
||||
}
|
||||
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
size_t dataSize = userBlob->size();
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
if (foundInput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
|
||||
IE_THROW(ParameterMismatch) << "Failed to set Blob with precision not corresponding to user input precision";
|
||||
}
|
||||
|
||||
auto& devBlob = _deviceInputs[name];
|
||||
const bool preProcRequired = preProcessingRequired(foundInput, userBlob, devBlob);
|
||||
if (compoundBlobPassed && !preProcRequired) {
|
||||
IE_THROW(NotImplemented) << "cannot set compound blob: supported only for input pre-processing";
|
||||
}
|
||||
|
||||
if (preProcRequired) {
|
||||
addInputPreProcessingFor(name, userBlob, devBlob ? devBlob : _inputs[name]);
|
||||
} else {
|
||||
size_t inputSize = foundInput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
|
||||
? InferenceEngine::details::product(foundInput->getTensorDesc().getDims())
|
||||
: 1;
|
||||
if (dataSize != inputSize) {
|
||||
IE_THROW() << "Input blob size is not equal network input size (" << dataSize << "!=" << inputSize << ").";
|
||||
}
|
||||
_inputs[name] = userBlob;
|
||||
devBlob = userBlob;
|
||||
}
|
||||
} else {
|
||||
if (compoundBlobPassed) {
|
||||
IE_THROW(NotImplemented) << "cannot set compound blob: supported only for input pre-processing";
|
||||
}
|
||||
size_t outputSize = foundOutput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
|
||||
? details::product(foundOutput->getTensorDesc().getDims()) :
|
||||
1;
|
||||
if (dataSize != outputSize) {
|
||||
IE_THROW() << "Output blob size is not equal network output size (" << dataSize << "!=" << outputSize << ").";
|
||||
}
|
||||
if (foundOutput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
|
||||
IE_THROW(ParameterMismatch) << "Failed to set Blob with precision not corresponding to user output precision";
|
||||
}
|
||||
_outputs[name] = userBlob;
|
||||
}
|
||||
}
|
||||
|
||||
Blob::Ptr IInferRequestInternal::GetBlob(const std::string& name) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "GetBlob");
|
||||
Blob::Ptr data;
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
const SizeVector oneVector = { 1 };
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
// ROI blob is returned only if it was set previously. Otherwise default blob is returned.
|
||||
auto it = _preProcData.find(name);
|
||||
if (it != _preProcData.end()) {
|
||||
data = it->second->getRoiBlob();
|
||||
} else {
|
||||
data = _inputs[name];
|
||||
checkBlob(data, name, true,
|
||||
foundInput->getTensorDesc().getLayout() != SCALAR
|
||||
? foundInput->getTensorDesc().getDims()
|
||||
: oneVector);
|
||||
|
||||
auto& devBlob = _deviceInputs[name];
|
||||
if (preProcessingRequired(foundInput, data, devBlob)) {
|
||||
// if no devBlob, performs inplace
|
||||
addInputPreProcessingFor(name, data, devBlob ? devBlob : _inputs[name]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
data = _outputs[name];
|
||||
checkBlob(data, name, false,
|
||||
foundOutput->getTensorDesc().getLayout() != SCALAR
|
||||
? foundOutput->getTensorDesc().getDims()
|
||||
: oneVector);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
void IInferRequestInternal::SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) {
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
copyPreProcess(info, foundInput->getPreProcess());
|
||||
} else {
|
||||
IE_THROW() << "Pre-process can't be set to output blob";
|
||||
}
|
||||
|
||||
SetBlob(name, data);
|
||||
}
|
||||
|
||||
const PreProcessInfo& IInferRequestInternal::GetPreProcess(const std::string& name) const {
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
return foundInput->getPreProcess();
|
||||
} else {
|
||||
IE_THROW() << "Output blob can't have pre-processing";
|
||||
}
|
||||
}
|
||||
|
||||
void IInferRequestInternal::SetBatch(int batch) {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<IVariableStateInternal>> IInferRequestInternal::QueryState() {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
void IInferRequestInternal::StartAsync() {
|
||||
checkBlobs();
|
||||
StartAsyncImpl();
|
||||
}
|
||||
|
||||
void IInferRequestInternal::StartAsyncImpl() {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
StatusCode IInferRequestInternal::Wait(int64_t millis_timeout) {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
void IInferRequestInternal::SetCallback(Callback callback) {
|
||||
_callback = std::move(callback);
|
||||
}
|
||||
|
||||
void IInferRequestInternal::execDataPreprocessing(InferenceEngine::BlobMap& preprocessedBlobs, bool serial) {
|
||||
for (auto& input : preprocessedBlobs) {
|
||||
// If there is a pre-process entry for an input then it must be pre-processed
|
||||
// using preconfigured resize algorithm.
|
||||
auto it = _preProcData.find(input.first);
|
||||
if (it != _preProcData.end()) {
|
||||
it->second->execute(input.second, _networkInputs[input.first]->getPreProcess(), serial, m_curBatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool IInferRequestInternal::findInputAndOutputBlobByName(const std::string& name, InputInfo::Ptr& foundInput, DataPtr& foundOutput) const {
|
||||
foundInput = nullptr;
|
||||
foundOutput = nullptr;
|
||||
if (_networkOutputs.empty()) {
|
||||
IE_THROW() << "Internal error: network outputs is not set";
|
||||
}
|
||||
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
|
||||
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
|
||||
[&](const std::pair<std::string, DataPtr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
if (foundOutputPair == std::end(_networkOutputs) && (foundInputPair == std::end(_networkInputs))) {
|
||||
IE_THROW(NotFound) << "Failed to find input or output with name: \'" << name << "\'";
|
||||
}
|
||||
if (foundInputPair != std::end(_networkInputs)) {
|
||||
foundInput = foundInputPair->second;
|
||||
return true;
|
||||
} else {
|
||||
foundOutput = foundOutputPair->second;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void IInferRequestInternal::checkBlob(const Blob::Ptr& blob, const std::string& name, bool isInput, const SizeVector& refDims) const {
|
||||
std::string bType = isInput ? "Input" : "Output";
|
||||
std::string sType = isInput ? "input" : "output";
|
||||
std::string strNotAllocated(bType + " data was not allocated.");
|
||||
std::string strNotMatched("The " + sType + " blob size is not equal to the network " + sType + " size");
|
||||
|
||||
if (!blob) {
|
||||
IE_THROW(NotAllocated) << strNotAllocated;
|
||||
}
|
||||
size_t refSize;
|
||||
if (refDims.empty()) {
|
||||
SizeVector dims;
|
||||
if (isInput) {
|
||||
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
|
||||
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
if (foundInputPair == std::end(_networkInputs)) {
|
||||
IE_THROW(NotFound) << "Failed to find input with name: \'" << name << "\'";
|
||||
}
|
||||
dims = foundInputPair->second->getTensorDesc().getDims();
|
||||
refSize = foundInputPair->second->getTensorDesc().getLayout() != SCALAR
|
||||
? details::product(dims)
|
||||
: 1;
|
||||
} else {
|
||||
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
|
||||
[&](const std::pair<std::string, DataPtr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
if (foundOutputPair == std::end(_networkOutputs)) {
|
||||
IE_THROW(NotFound) << "Failed to find output with name: \'" << name << "\'";
|
||||
}
|
||||
dims = foundOutputPair->second->getTensorDesc().getDims();
|
||||
refSize = foundOutputPair->second->getTensorDesc().getLayout() != SCALAR
|
||||
? details::product(dims)
|
||||
: 1;
|
||||
}
|
||||
} else {
|
||||
refSize = details::product(refDims);
|
||||
}
|
||||
|
||||
if (refSize != blob->size()) {
|
||||
IE_THROW() << strNotMatched + ": got " << blob->size() << " expecting " << refSize;
|
||||
}
|
||||
const bool remoteBlobPassed = blob->is<RemoteBlob>();
|
||||
if (!remoteBlobPassed && blob->buffer() == nullptr) IE_THROW() << strNotAllocated;
|
||||
}
|
||||
|
||||
void IInferRequestInternal::checkBlobs() {
|
||||
for (auto const& input : _inputs) {
|
||||
checkBlob(input.second, input.first, true);
|
||||
}
|
||||
for (auto const& output : _outputs) {
|
||||
checkBlob(output.second, output.first, false);
|
||||
}
|
||||
}
|
||||
|
||||
void IInferRequestInternal::setPointerToExecutableNetworkInternal(const std::shared_ptr<IExecutableNetworkInternal>& exeNetwork) {
|
||||
_exeNetwork = exeNetwork;
|
||||
}
|
||||
|
||||
bool IInferRequestInternal::preProcessingRequired(const InputInfo::Ptr& info, const Blob::Ptr& userBlob, const Blob::Ptr& deviceBlob) {
|
||||
// pre-processing is required if:
|
||||
// 1. resize algorithm is specified (resize required)
|
||||
// 2. color format specified:
|
||||
// 2.a. color format is not equal to network's expected (color conversion required)
|
||||
// 2.b. network's layout != blob's layout (reorder required)
|
||||
// 3. precision conversion is required
|
||||
|
||||
const auto& preProcessInfo = info->getPreProcess();
|
||||
const auto inputColorFormat = preProcessInfo.getColorFormat();
|
||||
// FIXME: support other network's input formats once the API is ready. Assuming input is in
|
||||
// the BGR format by default
|
||||
const auto networkColorFormat = ColorFormat::BGR;
|
||||
const bool colorFormatSpecified = inputColorFormat != ColorFormat::RAW;
|
||||
|
||||
auto blob_layout = [](const Blob::Ptr& b) { return b->getTensorDesc().getLayout(); };
|
||||
auto blob_prec = [](const Blob::Ptr& b) { return b->getTensorDesc().getPrecision();};
|
||||
|
||||
auto dst_layout = deviceBlob ? blob_layout(deviceBlob) : info->getLayout();
|
||||
auto dst_prec = deviceBlob ? blob_prec(deviceBlob) : info->getPrecision();
|
||||
|
||||
//FIXME: remove the first part to allow any needed conversion?
|
||||
const bool need_layout_conv = (colorFormatSpecified || deviceBlob) &&
|
||||
(blob_layout(userBlob) != dst_layout);
|
||||
|
||||
return preProcessInfo.getResizeAlgorithm() != ResizeAlgorithm::NO_RESIZE ||
|
||||
(colorFormatSpecified && inputColorFormat != networkColorFormat) ||
|
||||
need_layout_conv ||
|
||||
(blob_prec(userBlob) != dst_prec);
|
||||
}
|
||||
|
||||
void IInferRequestInternal::addInputPreProcessingFor(const std::string& name, Blob::Ptr const& from, const Blob::Ptr& to) {
|
||||
auto ppDataIt = _preProcData.find(name);
|
||||
if (ppDataIt == _preProcData.end()) {
|
||||
ppDataIt = (_preProcData.emplace(name, CreatePreprocDataHelper())).first;
|
||||
}
|
||||
|
||||
auto& preproc_ptr = ppDataIt->second;
|
||||
preproc_ptr->isApplicable(from, to);
|
||||
// Stores the given blob as ROI blob. It will be used to fill in network input
|
||||
// during pre-processing
|
||||
preproc_ptr->setRoiBlob(from);
|
||||
}
|
||||
} // namespace InferenceEngine
|
@ -57,42 +57,38 @@ namespace details {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
StatusCode InferenceEngineException::getStatus() const {
|
||||
return ExceptionToStatus(dynamic_cast<const Exception&>(*this));
|
||||
}
|
||||
} // namespace details
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exception) {
|
||||
if (dynamic_cast<const GeneralError*>(&exception) != nullptr) {
|
||||
if (dynamic_cast<const GeneralError*>(this) != nullptr) {
|
||||
return GENERAL_ERROR;
|
||||
} else if (dynamic_cast<const NotImplemented*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const NotImplemented*>(this) != nullptr) {
|
||||
return NOT_IMPLEMENTED;
|
||||
} else if (dynamic_cast<const NetworkNotLoaded*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const NetworkNotLoaded*>(this) != nullptr) {
|
||||
return NETWORK_NOT_LOADED;
|
||||
} else if (dynamic_cast<const ParameterMismatch*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const ParameterMismatch*>(this) != nullptr) {
|
||||
return PARAMETER_MISMATCH;
|
||||
} else if (dynamic_cast<const NotFound*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const NotFound*>(this) != nullptr) {
|
||||
return NOT_FOUND;
|
||||
} else if (dynamic_cast<const OutOfBounds*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const OutOfBounds*>(this) != nullptr) {
|
||||
return OUT_OF_BOUNDS;
|
||||
} else if (dynamic_cast<const Unexpected*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const Unexpected*>(this) != nullptr) {
|
||||
return UNEXPECTED;
|
||||
} else if (dynamic_cast<const RequestBusy*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const RequestBusy*>(this) != nullptr) {
|
||||
return REQUEST_BUSY;
|
||||
} else if (dynamic_cast<const ResultNotReady*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const ResultNotReady*>(this) != nullptr) {
|
||||
return RESULT_NOT_READY;
|
||||
} else if (dynamic_cast<const NotAllocated*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const NotAllocated*>(this) != nullptr) {
|
||||
return NOT_ALLOCATED;
|
||||
} else if (dynamic_cast<const InferNotStarted*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const InferNotStarted*>(this) != nullptr) {
|
||||
return INFER_NOT_STARTED;
|
||||
} else if (dynamic_cast<const NetworkNotRead*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const NetworkNotRead*>(this) != nullptr) {
|
||||
return NETWORK_NOT_READ;
|
||||
} else if (dynamic_cast<const InferCancelled*>(&exception) != nullptr) {
|
||||
} else if (dynamic_cast<const InferCancelled*>(this) != nullptr) {
|
||||
return INFER_CANCELLED;
|
||||
} else {
|
||||
assert(!"Unreachable"); return OK;
|
||||
}
|
||||
}
|
||||
} // namespace details
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
//
|
||||
// ie_parameter.hpp
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include "mkldnn_async_infer_request.h"
|
||||
#include <memory>
|
||||
|
||||
MKLDNNPlugin::MKLDNNAsyncInferRequest::MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr& inferRequest,
|
||||
MKLDNNPlugin::MKLDNNAsyncInferRequest::MKLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr& inferRequest,
|
||||
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor)
|
||||
: InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor) {
|
||||
|
@ -13,10 +13,10 @@ namespace MKLDNNPlugin {
|
||||
|
||||
class MKLDNNAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
|
||||
public:
|
||||
MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest,
|
||||
MKLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr &inferRequest,
|
||||
const InferenceEngine::ITaskExecutor::Ptr &taskExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor);
|
||||
~MKLDNNAsyncInferRequest() override;
|
||||
~MKLDNNAsyncInferRequest();
|
||||
};
|
||||
|
||||
} // namespace MKLDNNPlugin
|
||||
|
@ -29,7 +29,7 @@ using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
using namespace InferenceEngine::details;
|
||||
|
||||
InferenceEngine::InferRequestInternal::Ptr
|
||||
InferenceEngine::IInferRequestInternal::Ptr
|
||||
MKLDNNExecNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) {
|
||||
return std::make_shared<MKLDNNInferRequest>(networkInputs, networkOutputs, std::static_pointer_cast<MKLDNNExecNetwork>(shared_from_this()));
|
||||
@ -323,7 +323,7 @@ void MKLDNNExecNetwork::setProperty(const std::map<std::string, std::string> &pr
|
||||
}
|
||||
}
|
||||
|
||||
InferenceEngine::IInferRequest::Ptr MKLDNNExecNetwork::CreateInferRequest() {
|
||||
InferenceEngine::IInferRequestInternal::Ptr MKLDNNExecNetwork::CreateInferRequest() {
|
||||
return CreateAsyncInferRequestFromSync<MKLDNNAsyncInferRequest>();
|
||||
}
|
||||
|
||||
|
@ -23,11 +23,11 @@ class MKLDNNExecNetwork: public InferenceEngine::ExecutableNetworkThreadSafeDefa
|
||||
public:
|
||||
typedef std::shared_ptr<MKLDNNExecNetwork> Ptr;
|
||||
|
||||
InferenceEngine::InferRequestInternal::Ptr
|
||||
InferenceEngine::IInferRequestInternal::Ptr
|
||||
CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override;
|
||||
|
||||
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
|
||||
|
||||
MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network, const Config &cfg,
|
||||
const MKLDNNExtensionManager::Ptr &extMgr, NumaNodesWeights &weightsSharing);
|
||||
|
@ -19,11 +19,13 @@
|
||||
#include "nodes/mkldnn_memory_node.hpp"
|
||||
#include "nodes/common/cpu_memcpy.h"
|
||||
#include "mkldnn_async_infer_request.h"
|
||||
#include <debug.h>
|
||||
|
||||
|
||||
MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs,
|
||||
MKLDNNExecNetwork::Ptr execNetwork_)
|
||||
: InferRequestInternal(networkInputs, networkOutputs)
|
||||
: IInferRequestInternal(networkInputs, networkOutputs)
|
||||
, execNetwork(execNetwork_) {
|
||||
auto id = (execNetwork->_numRequests)++;
|
||||
profilingTask = openvino::itt::handle("MKLDNN_INFER_" + execNetwork->_name + "_" + std::to_string(id));
|
||||
|
@ -8,21 +8,21 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
|
||||
namespace MKLDNNPlugin {
|
||||
|
||||
class MKLDNNExecNetwork;
|
||||
class MKLDNNAsyncInferRequest;
|
||||
|
||||
class MKLDNNInferRequest : public InferenceEngine::InferRequestInternal {
|
||||
class MKLDNNInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
public:
|
||||
typedef std::shared_ptr<MKLDNNInferRequest> Ptr;
|
||||
explicit MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs,
|
||||
std::shared_ptr<MKLDNNExecNetwork> execNetwork);
|
||||
|
||||
~MKLDNNInferRequest() override;
|
||||
~MKLDNNInferRequest();
|
||||
|
||||
void InferImpl() override;
|
||||
|
||||
@ -34,7 +34,7 @@ public:
|
||||
|
||||
void SetBatch(int batch = -1) override;
|
||||
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override;
|
||||
|
||||
/**
|
||||
* @brief Sets the pointer to asynchronous inference request that holds this request
|
||||
@ -59,7 +59,7 @@ private:
|
||||
MKLDNNGraph* graph = nullptr;
|
||||
std::map<std::string, void*> externalPtr;
|
||||
openvino::itt::handle_t profilingTask;
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> memoryStates;
|
||||
MKLDNNAsyncInferRequest* _asyncRequest = nullptr;
|
||||
};
|
||||
} // namespace MKLDNNPlugin
|
||||
|
@ -27,7 +27,7 @@ public:
|
||||
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
|
||||
void Infer_ThreadUnsafe() override;
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||
~MultiDeviceAsyncInferRequest() override;
|
||||
~MultiDeviceAsyncInferRequest();
|
||||
|
||||
protected:
|
||||
MultiDeviceExecutableNetwork::Ptr _multiDeviceExecutableNetwork;
|
||||
|
@ -174,7 +174,7 @@ RemoteContext::Ptr MultiDeviceExecutableNetwork::GetContext() const {
|
||||
<< " Current list of devices allowed via the DEVICE_PRIORITIES config: " << devices_names;
|
||||
}
|
||||
|
||||
InferenceEngine::InferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::IInferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) {
|
||||
auto num = _numRequestsCreated++;
|
||||
size_t sum = 0;
|
||||
@ -192,17 +192,13 @@ InferenceEngine::InferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateI
|
||||
return std::make_shared<MultiDeviceInferRequest>(networkInputs, networkOutputs, request_to_share_blobs_with);
|
||||
}
|
||||
|
||||
IInferRequest::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
IInferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
|
||||
auto syncRequestImpl = CreateInferRequestImpl(_networkInputs, _networkOutputs);
|
||||
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
||||
auto asyncTreadSafeImpl = std::make_shared<MultiDeviceAsyncInferRequest>(std::static_pointer_cast<MultiDeviceInferRequest>(syncRequestImpl),
|
||||
_needPerfCounters,
|
||||
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
|
||||
_callbackExecutor);
|
||||
asyncRequest.reset(new InferRequestBase(asyncTreadSafeImpl));
|
||||
asyncTreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
return asyncRequest;
|
||||
return std::make_shared<MultiDeviceAsyncInferRequest>(std::static_pointer_cast<MultiDeviceInferRequest>(syncRequestImpl),
|
||||
_needPerfCounters,
|
||||
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
|
||||
_callbackExecutor);
|
||||
}
|
||||
|
||||
void MultiDeviceExecutableNetwork::SetConfig(const std::map<std::string, InferenceEngine::Parameter> &config) {
|
||||
|
@ -114,9 +114,9 @@ public:
|
||||
InferenceEngine::Parameter GetConfig(const std::string &name) const override;
|
||||
InferenceEngine::Parameter GetMetric(const std::string &name) const override;
|
||||
void run(InferenceEngine::Task inferTask) override;
|
||||
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override;
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
|
||||
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs) override;
|
||||
InferenceEngine::RemoteContext::Ptr GetContext() const override;
|
||||
~MultiDeviceExecutableNetwork() override;
|
||||
|
||||
|
@ -5,6 +5,10 @@
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "multi_device_infer_request.hpp"
|
||||
#include <ie_input_info.hpp>
|
||||
#include <ie_icnn_network.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
#include <blob_factory.hpp>
|
||||
|
||||
namespace MultiDevicePlugin {
|
||||
using namespace InferenceEngine;
|
||||
@ -12,7 +16,7 @@ namespace MultiDevicePlugin {
|
||||
MultiDeviceInferRequest::MultiDeviceInferRequest(const InputsDataMap& networkInputs,
|
||||
const OutputsDataMap& networkOutputs,
|
||||
InferRequest request_to_share_blobs_with)
|
||||
: InferRequestInternal(networkInputs, networkOutputs) {
|
||||
: IInferRequestInternal(networkInputs, networkOutputs) {
|
||||
if (request_to_share_blobs_with) {
|
||||
// borrow device-friendly blobs from the request
|
||||
for (const auto &it : _networkInputs)
|
||||
|
@ -14,12 +14,13 @@
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
#include <cpp/ie_executable_network.hpp>
|
||||
#include <cpp/ie_infer_request.hpp>
|
||||
|
||||
namespace MultiDevicePlugin {
|
||||
|
||||
class MultiDeviceInferRequest : public InferenceEngine::InferRequestInternal {
|
||||
class MultiDeviceInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<MultiDeviceInferRequest>;
|
||||
explicit MultiDeviceInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <multi-device/multi_device_config.hpp>
|
||||
#include <threading/ie_executor_manager.hpp>
|
||||
#include "multi_device_plugin.hpp"
|
||||
#include <ie_algorithm.hpp>
|
||||
|
||||
// ------------------------------MultiDeviceInferencePlugin----------------------------
|
||||
namespace MultiDevicePlugin {
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "cpp_interfaces/exception2status.hpp"
|
||||
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
@ -53,7 +54,7 @@ public:
|
||||
}
|
||||
|
||||
StatusCode CreateInferRequest(IInferRequest::Ptr& req, ResponseDesc* resp) noexcept override {
|
||||
TO_STATUS(req = _impl->CreateInferRequest());
|
||||
TO_STATUS(req = std::make_shared<InferRequestBase>(_impl->CreateInferRequest()));
|
||||
}
|
||||
|
||||
StatusCode Export(const std::string& modelFileName, ResponseDesc* resp) noexcept override {
|
||||
|
@ -11,25 +11,92 @@
|
||||
#include "cpp_interfaces/exception2status.hpp"
|
||||
#include "cpp_interfaces/plugin_itt.hpp"
|
||||
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
#include "ie_iinfer_request.hpp"
|
||||
#include "ie_preprocess.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
#define CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(StatusCode, ExceptionType) catch (const ExceptionType& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(StatusCode) << ex.what(); \
|
||||
}
|
||||
|
||||
#define CATCH_IE_EXCEPTIONS_TO_STATUS_NO_RESP \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(GENERAL_ERROR, GeneralError) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NOT_IMPLEMENTED, NotImplemented) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NETWORK_NOT_LOADED, NetworkNotLoaded) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(PARAMETER_MISMATCH, ParameterMismatch) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NOT_FOUND, NotFound) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(OUT_OF_BOUNDS, OutOfBounds) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(UNEXPECTED, Unexpected) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(REQUEST_BUSY, RequestBusy) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(RESULT_NOT_READY, ResultNotReady) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NOT_ALLOCATED, NotAllocated) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(INFER_NOT_STARTED, InferNotStarted) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NETWORK_NOT_READ, NetworkNotRead) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(INFER_CANCELLED, InferCancelled)
|
||||
|
||||
/**
|
||||
* @brief Inference request `noexcept` wrapper which accepts IAsyncInferRequestInternal derived instance which can throw exceptions
|
||||
* @ingroup ie_dev_api_async_infer_request_api
|
||||
* @def TO_STATUS_NO_RESP(x)
|
||||
* @brief Converts C++ exceptioned function call into a status code. Does not work with a ResponseDesc object
|
||||
* @ingroup ie_dev_api_error_debug
|
||||
*/
|
||||
#define TO_STATUS_NO_RESP(x) \
|
||||
try { \
|
||||
x; \
|
||||
return OK; \
|
||||
} CATCH_IE_EXCEPTIONS_TO_STATUS_NO_RESP catch (const std::exception& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR) << ex.what(); \
|
||||
} catch (...) { \
|
||||
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
|
||||
}
|
||||
|
||||
#define CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(StatusCode, ExceptionType) catch (const ExceptionType& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(StatusCode, resp) << ex.what(); \
|
||||
}
|
||||
|
||||
#define CATCH_IE_EXCEPTIONS_CALL_RETURN_STATUS \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(GENERAL_ERROR, GeneralError) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NOT_IMPLEMENTED, NotImplemented) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NETWORK_NOT_LOADED, NetworkNotLoaded) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(PARAMETER_MISMATCH, ParameterMismatch) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NOT_FOUND, NotFound) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(OUT_OF_BOUNDS, OutOfBounds) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(UNEXPECTED, Unexpected) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(REQUEST_BUSY, RequestBusy) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(RESULT_NOT_READY, ResultNotReady) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NOT_ALLOCATED, NotAllocated) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(INFER_NOT_STARTED, InferNotStarted) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NETWORK_NOT_READ, NetworkNotRead) \
|
||||
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(INFER_CANCELLED, InferCancelled)
|
||||
|
||||
/**
|
||||
* @def NO_EXCEPT_CALL_RETURN_STATUS(x)
|
||||
* @brief Returns a status code of a called function, handles exeptions and converts to a status code.
|
||||
* @ingroup ie_dev_api_error_debug
|
||||
*/
|
||||
#define NO_EXCEPT_CALL_RETURN_STATUS(x) \
|
||||
try { \
|
||||
return x; \
|
||||
} CATCH_IE_EXCEPTIONS_CALL_RETURN_STATUS catch (const std::exception& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); \
|
||||
} catch (...) { \
|
||||
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Inference request `noexcept` wrapper which accepts IInferRequestInternal derived instance which can throw exceptions
|
||||
* @ingroup ie_dev_api_infer_request_api
|
||||
*/
|
||||
class InferRequestBase : public IInferRequest {
|
||||
std::shared_ptr<IAsyncInferRequestInternal> _impl;
|
||||
std::shared_ptr<IInferRequestInternal> _impl;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor with actual underlying implementation.
|
||||
* @param impl Underlying implementation of type IAsyncInferRequestInternal
|
||||
* @param impl Underlying implementation of type IInferRequestInternal
|
||||
*/
|
||||
explicit InferRequestBase(std::shared_ptr<IAsyncInferRequestInternal> impl): _impl(impl) {}
|
||||
explicit InferRequestBase(std::shared_ptr<IInferRequestInternal> impl): _impl(impl) {}
|
||||
|
||||
StatusCode Infer(ResponseDesc* resp) noexcept override {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Infer");
|
||||
@ -73,15 +140,30 @@ public:
|
||||
}
|
||||
|
||||
StatusCode SetCompletionCallback(CompletionCallback callback) noexcept override {
|
||||
TO_STATUS_NO_RESP(_impl->SetCompletionCallback(callback));
|
||||
TO_STATUS_NO_RESP(_impl->SetCallback([callback, this] (std::exception_ptr exceptionPtr) {
|
||||
StatusCode statusCode = [&] ()-> StatusCode {
|
||||
if (exceptionPtr) {
|
||||
TO_STATUS_NO_RESP(std::rethrow_exception(exceptionPtr));
|
||||
} else {
|
||||
return OK;
|
||||
}
|
||||
} ();
|
||||
callback(std::shared_ptr<InferRequestBase>{this, [](InferRequestBase*){}}, statusCode);
|
||||
}));
|
||||
}
|
||||
|
||||
StatusCode GetUserData(void** data, ResponseDesc* resp) noexcept override {
|
||||
TO_STATUS(_impl->GetUserData(data));
|
||||
StatusCode GetUserData(void** data, ResponseDesc*) noexcept override {
|
||||
if (data != nullptr) {
|
||||
*data = _data;
|
||||
return OK;
|
||||
} else {
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
StatusCode SetUserData(void* data, ResponseDesc* resp) noexcept override {
|
||||
TO_STATUS(_impl->SetUserData(data));
|
||||
StatusCode SetUserData(void* data, ResponseDesc*) noexcept override {
|
||||
_data = data;
|
||||
return OK;
|
||||
}
|
||||
|
||||
StatusCode SetBatch(int batch_size, ResponseDesc* resp) noexcept override {
|
||||
@ -104,6 +186,8 @@ public:
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
void* _data = nullptr;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
@ -13,8 +13,24 @@
|
||||
#include "description_buffer.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
#define CATCH_IE_EXCEPTION_TO_STATUS(StatusCode, ExceptionType) catch (const ExceptionType& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(StatusCode, resp) << ex.what(); \
|
||||
}
|
||||
|
||||
INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exception);
|
||||
#define CATCH_IE_EXCEPTIONS_TO_STATUS \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(GENERAL_ERROR, GeneralError) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(NOT_IMPLEMENTED, NotImplemented) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(NETWORK_NOT_LOADED, NetworkNotLoaded) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(PARAMETER_MISMATCH, ParameterMismatch) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(NOT_FOUND, NotFound) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(OUT_OF_BOUNDS, OutOfBounds) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(UNEXPECTED, Unexpected) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(REQUEST_BUSY, RequestBusy) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(RESULT_NOT_READY, ResultNotReady) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(NOT_ALLOCATED, NotAllocated) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(INFER_NOT_STARTED, InferNotStarted) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(NETWORK_NOT_READ, NetworkNotRead) \
|
||||
CATCH_IE_EXCEPTION_TO_STATUS(INFER_CANCELLED, InferCancelled)
|
||||
|
||||
/**
|
||||
* @def TO_STATUS(x)
|
||||
@ -25,42 +41,7 @@ INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exceptio
|
||||
try { \
|
||||
x; \
|
||||
return OK; \
|
||||
} catch (const ::InferenceEngine::Exception& iex) { \
|
||||
return InferenceEngine::DescriptionBuffer(InferenceEngine::ExceptionToStatus(iex), resp) << iex.what(); \
|
||||
} catch (const std::exception& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); \
|
||||
} catch (...) { \
|
||||
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
|
||||
}
|
||||
|
||||
/**
|
||||
* @def TO_STATUS_NO_RESP(x)
|
||||
* @brief Converts C++ exceptioned function call into a status code. Does not work with a ResponseDesc object
|
||||
* @ingroup ie_dev_api_error_debug
|
||||
*/
|
||||
#define TO_STATUS_NO_RESP(x) \
|
||||
try { \
|
||||
x; \
|
||||
return OK; \
|
||||
} catch (const ::InferenceEngine::Exception& iex) { \
|
||||
return InferenceEngine::DescriptionBuffer(InferenceEngine::ExceptionToStatus(iex)) << iex.what(); \
|
||||
} catch (const std::exception& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR) << ex.what(); \
|
||||
} catch (...) { \
|
||||
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
|
||||
}
|
||||
|
||||
/**
|
||||
* @def NO_EXCEPT_CALL_RETURN_STATUS(x)
|
||||
* @brief Returns a status code of a called function, handles exeptions and converts to a status code.
|
||||
* @ingroup ie_dev_api_error_debug
|
||||
*/
|
||||
#define NO_EXCEPT_CALL_RETURN_STATUS(x) \
|
||||
try { \
|
||||
return x; \
|
||||
} catch (const ::InferenceEngine::Exception& iex) { \
|
||||
return InferenceEngine::DescriptionBuffer(InferenceEngine::ExceptionToStatus(iex), resp) << iex.what(); \
|
||||
} catch (const std::exception& ex) { \
|
||||
} CATCH_IE_EXCEPTIONS_TO_STATUS catch (const std::exception& ex) { \
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); \
|
||||
} catch (...) { \
|
||||
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
|
||||
@ -82,5 +63,4 @@ INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exceptio
|
||||
CATCH_IE_EXCEPTION(InferNotStarted) \
|
||||
CATCH_IE_EXCEPTION(NetworkNotRead) \
|
||||
CATCH_IE_EXCEPTION(InferCancelled)
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
@ -11,8 +11,6 @@
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
#include "cpp_interfaces/impl/ie_infer_async_request_internal.hpp"
|
||||
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
|
||||
@ -118,7 +116,29 @@ public:
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Creates an inference request public implementation.
|
||||
* @return The request public implementation
|
||||
*/
|
||||
IInferRequestInternal::Ptr CreateInferRequest() override {
|
||||
auto asyncRequestImpl = this->CreateInferRequestImpl(_networkInputs, _networkOutputs);
|
||||
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
||||
return asyncRequestImpl;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Creates an asynchronous inference request internal implementation.
|
||||
* @note The method is called by ExecutableNetworkInternal::CreateInferRequest as
|
||||
* plugin-specific implementation.
|
||||
* @param[in] networkInputs The network inputs
|
||||
* @param[in] networkOutputs The network outputs
|
||||
* @return A shared pointer to asynchnous inference request object.
|
||||
*/
|
||||
virtual IInferRequestInternal::Ptr CreateInferRequestImpl(InputsDataMap networkInputs,
|
||||
OutputsDataMap networkOutputs) = 0;
|
||||
|
||||
/**
|
||||
* @brief Exports an internal hardware-dependent model to a stream.
|
||||
* @note The function is called from ExecutableNetworkInternal::Export(std::ostream&),
|
||||
@ -130,7 +150,7 @@ protected:
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
InferenceEngine::InputsDataMap _networkInputs; //!< Holds infromation about network inputs info
|
||||
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
|
||||
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data
|
||||
|
||||
/**
|
||||
|
@ -1,59 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ie_iinfer_request.hpp>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
|
||||
#include "cpp_interfaces/impl/ie_executable_network_internal.hpp"
|
||||
#include "cpp_interfaces/impl/ie_infer_async_request_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
/**
|
||||
* @brief This class describes an executable network thread safe asynchronous only implementation.
|
||||
* @ingroup ie_dev_api_exec_network_api
|
||||
*/
|
||||
class ExecutableNetworkThreadSafeAsyncOnly : public ExecutableNetworkInternal {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to a ExecutableNetworkThreadSafeAsyncOnly object
|
||||
*/
|
||||
typedef std::shared_ptr<ExecutableNetworkThreadSafeAsyncOnly> Ptr;
|
||||
|
||||
/**
|
||||
* @brief Creates an asynchronous inference request public implementation.
|
||||
* @return The asynchronous request public implementation
|
||||
*/
|
||||
IInferRequest::Ptr CreateInferRequest() override {
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
auto asyncRequestImpl = this->CreateAsyncInferRequestImpl(_networkInputs, _networkOutputs);
|
||||
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
||||
|
||||
asyncRequest.reset(new InferRequestBase(asyncRequestImpl));
|
||||
asyncRequestImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
return asyncRequest;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Creates an asynchronous inference request internal implementation.
|
||||
* @note The method is called by ExecutableNetworkThreadSafeAsyncOnly::CreateInferRequest as
|
||||
* plugin-specific implementation.
|
||||
* @param[in] networkInputs The network inputs
|
||||
* @param[in] networkOutputs The network outputs
|
||||
* @return A shared pointer to asynchnous inference request object.
|
||||
*/
|
||||
virtual AsyncInferRequestInternal::Ptr CreateAsyncInferRequestImpl(InputsDataMap networkInputs,
|
||||
OutputsDataMap networkOutputs) = 0;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
@ -12,7 +12,6 @@
|
||||
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
|
||||
#include "cpp_interfaces/impl/ie_executable_network_internal.hpp"
|
||||
#include "cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp"
|
||||
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
|
||||
#include "threading/ie_cpu_streams_executor.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
@ -49,7 +48,7 @@ public:
|
||||
* need for it to be implemented by plugin
|
||||
* @return shared_ptr for the created asynchronous inference request
|
||||
*/
|
||||
IInferRequest::Ptr CreateInferRequest() override {
|
||||
IInferRequestInternal::Ptr CreateInferRequest() override {
|
||||
return CreateAsyncInferRequestFromSync();
|
||||
}
|
||||
|
||||
@ -60,28 +59,12 @@ protected:
|
||||
* @return A shared pointer to an asynchronous inference request
|
||||
*/
|
||||
template <typename AsyncInferRequestType = AsyncInferRequestThreadSafeDefault>
|
||||
IInferRequest::Ptr CreateAsyncInferRequestFromSync() {
|
||||
IInferRequestInternal::Ptr CreateAsyncInferRequestFromSync() {
|
||||
auto syncRequestImpl = this->CreateInferRequestImpl(_networkInputs, _networkOutputs);
|
||||
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
||||
|
||||
auto asyncThreadSafeImpl = std::make_shared<AsyncInferRequestType>(
|
||||
syncRequestImpl, _taskExecutor, _callbackExecutor);
|
||||
IInferRequest::Ptr asyncRequest = std::make_shared<InferRequestBase>(asyncThreadSafeImpl);
|
||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
return asyncRequest;
|
||||
return std::make_shared<AsyncInferRequestType>(syncRequestImpl, _taskExecutor, _callbackExecutor);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creates a synchronous inference request object used to infer the network
|
||||
* @note Used by ExecutableNetworkThreadSafeDefault::CreateInferRequest as a plugin-specific implementation
|
||||
* @param networkInputs An input info map needed to create input blobs
|
||||
* @param networkOutputs An output data map needed to create output blobs
|
||||
* @return Synchronous inference request object
|
||||
*/
|
||||
virtual InferRequestInternal::Ptr CreateInferRequestImpl(InputsDataMap networkInputs,
|
||||
OutputsDataMap networkOutputs) = 0;
|
||||
|
||||
ITaskExecutor::Ptr _taskExecutor = nullptr; //!< Holds a task executor
|
||||
ITaskExecutor::Ptr _callbackExecutor = nullptr; //!< Holds a callback executor
|
||||
};
|
||||
|
@ -1,82 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable : 4250)
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief minimum API to be implemented by plugin, which is used in InferRequestBase forwarding mechanism
|
||||
* @ingroup ie_dev_api_async_infer_request_api
|
||||
*/
|
||||
class AsyncInferRequestInternal : public IAsyncInferRequestInternal, public InferRequestInternal {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to a AsyncInferRequestInternal implementation
|
||||
*/
|
||||
typedef std::shared_ptr<AsyncInferRequestInternal> Ptr;
|
||||
|
||||
/**
|
||||
* @brief Constructs a new instance.
|
||||
* @param[in] networkInputs The network inputs info
|
||||
* @param[in] networkOutputs The network outputs data
|
||||
*/
|
||||
AsyncInferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs)
|
||||
: InferRequestInternal(networkInputs, networkOutputs), _callback(nullptr), _userData(nullptr) {}
|
||||
|
||||
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
|
||||
_callback = callback;
|
||||
}
|
||||
|
||||
void GetUserData(void** data) override {
|
||||
if (data == nullptr) IE_THROW(NotAllocated);
|
||||
*data = _userData;
|
||||
}
|
||||
|
||||
void SetUserData(void* data) override {
|
||||
_userData = data;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set weak pointer to the corresponding public interface: IInferRequest. This allow to pass it to
|
||||
* IInferRequest::CompletionCallback
|
||||
* @param ptr A weak pointer to InferRequestBase
|
||||
*/
|
||||
void SetPointerToPublicInterface(IInferRequest::Ptr ptr) {
|
||||
_publicInterface = ptr;
|
||||
}
|
||||
|
||||
void StartAsync() override {
|
||||
checkBlobs();
|
||||
StartAsyncImpl();
|
||||
};
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief The minimal asynchronous inference function to be implemented by plugins.
|
||||
* It starts inference of specified input(s) in asynchronous mode
|
||||
* @note
|
||||
* * The methos is used in AsyncInferRequestInternal::StartAsync which performs common steps first and
|
||||
* calls plugin dependent implementation of this method after.
|
||||
* * It returns immediately. Inference starts also immediately.
|
||||
*/
|
||||
virtual void StartAsyncImpl() = 0;
|
||||
|
||||
IInferRequest::WeakPtr _publicInterface; //!< A weak pointer to a IInferRequest interface for callback calling
|
||||
InferenceEngine::IInferRequest::CompletionCallback _callback; //!< A callback
|
||||
void* _userData; //!< A callback user data
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
@ -8,10 +8,10 @@
|
||||
#include <threading/ie_itask_executor.hpp>
|
||||
#include <threading/ie_istreams_executor.hpp>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
#include <cpp_interfaces/exception2status.hpp>
|
||||
#include <ie_system_conf.h>
|
||||
#include <ie_iinfer_request.hpp>
|
||||
|
||||
#include <exception>
|
||||
#include <future>
|
||||
@ -40,12 +40,12 @@ namespace InferenceEngine {
|
||||
*
|
||||
* @snippet example_async_infer_request.cpp async_infer_request:define_pipeline
|
||||
*/
|
||||
class AsyncInferRequestThreadSafeDefault : public IAsyncInferRequestInternal {
|
||||
class AsyncInferRequestThreadSafeDefault : public IInferRequestInternal {
|
||||
enum InferState {Idle, Busy, Canceled, Stop};
|
||||
using Futures = std::vector<std::shared_future<void>>;
|
||||
using Promise = std::shared_ptr<std::promise<void>>;
|
||||
enum Stage_e : std::uint8_t { executor, task };
|
||||
InferRequestInternal::Ptr _syncRequest;
|
||||
IInferRequestInternal::Ptr _syncRequest;
|
||||
|
||||
friend struct DisableCallbackGuard;
|
||||
struct DisableCallbackGuard {
|
||||
@ -59,7 +59,7 @@ class AsyncInferRequestThreadSafeDefault : public IAsyncInferRequestInternal {
|
||||
_this->_callback = _callback;
|
||||
}
|
||||
AsyncInferRequestThreadSafeDefault* _this = nullptr;
|
||||
IInferRequest::CompletionCallback _callback = nullptr;
|
||||
Callback _callback;
|
||||
};
|
||||
|
||||
struct ImmediateStreamsExecutor : public InferenceEngine::ITaskExecutor {
|
||||
@ -132,15 +132,15 @@ public:
|
||||
using Ptr = std::shared_ptr<AsyncInferRequestThreadSafeDefault>;
|
||||
|
||||
/**
|
||||
* @brief Wraps a InferRequestInternal::Ptr implementation and constructs a
|
||||
* AsyncInferRequestThreadSafeDefault::_pipeline where `taskExecutor` is used to run InferRequestInternal::Infer
|
||||
* @brief Wraps a IInferRequestInternal::Ptr implementation and constructs a
|
||||
* AsyncInferRequestThreadSafeDefault::_pipeline where `taskExecutor` is used to run IInferRequestInternal::Infer
|
||||
* asynchronously.
|
||||
*
|
||||
* @param[in] request The synchronous request
|
||||
* @param[in] taskExecutor The task executor
|
||||
* @param[in] callbackExecutor The callback executor
|
||||
*/
|
||||
AsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr& request,
|
||||
AsyncInferRequestThreadSafeDefault(const IInferRequestInternal::Ptr& request,
|
||||
const ITaskExecutor::Ptr& taskExecutor,
|
||||
const ITaskExecutor::Ptr& callbackExecutor) :
|
||||
_syncRequest {request},
|
||||
@ -245,32 +245,13 @@ public:
|
||||
_syncRequest->SetBatch(batch);
|
||||
};
|
||||
|
||||
void GetUserData(void** data) override {
|
||||
void SetCallback(Callback callback) override {
|
||||
CheckState();
|
||||
if (data == nullptr) IE_THROW(NotAllocated);
|
||||
*data = _userData;
|
||||
_callback = std::move(callback);
|
||||
}
|
||||
|
||||
void SetUserData(void* data) override {
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override {
|
||||
CheckState();
|
||||
_userData = data;
|
||||
}
|
||||
|
||||
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
|
||||
CheckState();
|
||||
_callback = callback;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Sets the pointer to public interface.
|
||||
* @note Needed to correctly handle ownership between objects
|
||||
* @param[in] ptr A shared pointer to a public IInferRequest interface.
|
||||
*/
|
||||
void SetPointerToPublicInterface(InferenceEngine::IInferRequest::Ptr ptr) {
|
||||
_publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](IInferRequest*) {});
|
||||
}
|
||||
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
||||
return _syncRequest->QueryState();
|
||||
}
|
||||
|
||||
@ -319,13 +300,13 @@ protected:
|
||||
* pipeline tasks
|
||||
*/
|
||||
void StopAndWait() {
|
||||
_callback = nullptr;
|
||||
Futures futures;
|
||||
InferState state = InferState::Idle;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{_mutex};
|
||||
state = _state;
|
||||
if (state != InferState::Stop) {
|
||||
_callback = {};
|
||||
_state = InferState::Stop;
|
||||
futures = std::move(_futures);
|
||||
}
|
||||
@ -385,51 +366,44 @@ private:
|
||||
Task MakeNextStageTask(const Pipeline::iterator itStage, const Pipeline::iterator itEndStage,
|
||||
const ITaskExecutor::Ptr callbackExecutor) {
|
||||
return std::bind([this, itStage, itEndStage](ITaskExecutor::Ptr& callbackExecutor) mutable {
|
||||
StatusCode requestStatus = StatusCode::OK;
|
||||
std::exception_ptr localCurrentException = nullptr;
|
||||
std::exception_ptr currentException = nullptr;
|
||||
auto& thisStage = *itStage;
|
||||
auto itNextStage = itStage + 1;
|
||||
|
||||
try {
|
||||
auto& stageTask = std::get<Stage_e::task>(thisStage);
|
||||
IE_ASSERT(nullptr != stageTask);
|
||||
stageTask();
|
||||
if (itEndStage != itNextStage) {
|
||||
if (itEndStage != itNextStage) {
|
||||
auto& nextStage = *itNextStage;
|
||||
auto& nextStageExecutor = std::get<Stage_e::executor>(nextStage);
|
||||
IE_ASSERT(nullptr != nextStageExecutor);
|
||||
nextStageExecutor->run(MakeNextStageTask(itNextStage, itEndStage, std::move(callbackExecutor)));
|
||||
}
|
||||
} catch (InferenceEngine::Exception& ie_ex) {
|
||||
requestStatus = ExceptionToStatus(ie_ex);
|
||||
localCurrentException = std::current_exception();
|
||||
} catch (...) {
|
||||
requestStatus = StatusCode::GENERAL_ERROR;
|
||||
localCurrentException = std::current_exception();
|
||||
currentException = std::current_exception();
|
||||
}
|
||||
|
||||
if ((itEndStage == itNextStage) || (nullptr != localCurrentException)) {
|
||||
auto lastStageTask = [this, requestStatus, localCurrentException]() mutable {
|
||||
if ((itEndStage == itNextStage) || (nullptr != currentException)) {
|
||||
auto lastStageTask = [this, currentException]() mutable {
|
||||
auto promise = std::move(_promise);
|
||||
IInferRequest::CompletionCallback callback = nullptr;
|
||||
Callback callback;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{_mutex};
|
||||
_state = InferState::Idle;
|
||||
callback = _callback;
|
||||
}
|
||||
if (nullptr != callback) {
|
||||
InferenceEngine::CurrentException() = localCurrentException;
|
||||
if (callback) {
|
||||
try {
|
||||
callback(_publicInterface, requestStatus);
|
||||
auto local_callback = std::move(callback);
|
||||
local_callback(currentException);
|
||||
} catch (...) {
|
||||
localCurrentException = std::current_exception();
|
||||
currentException = std::current_exception();
|
||||
}
|
||||
InferenceEngine::CurrentException() = nullptr;
|
||||
}
|
||||
if (nullptr == localCurrentException) {
|
||||
if (nullptr == currentException) {
|
||||
promise.set_value();
|
||||
} else {
|
||||
promise.set_exception(localCurrentException);
|
||||
promise.set_exception(currentException);
|
||||
}
|
||||
};
|
||||
|
||||
@ -442,9 +416,6 @@ private:
|
||||
}, std::move(callbackExecutor));
|
||||
}
|
||||
|
||||
void* _userData = nullptr;
|
||||
IInferRequest::CompletionCallback _callback = nullptr;
|
||||
IInferRequest::Ptr _publicInterface;
|
||||
std::promise<void> _promise;
|
||||
mutable std::mutex _mutex;
|
||||
Futures _futures;
|
||||
|
@ -1,422 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ie_icnn_network.hpp>
|
||||
#include <ie_input_info.hpp>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "cpp_interfaces/exception2status.hpp"
|
||||
#include "cpp_interfaces/plugin_itt.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
|
||||
#include "debug.h"
|
||||
#include "ie_compound_blob.h"
|
||||
#include "ie_memcpy.h"
|
||||
#include "ie_preprocess_data.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
class IExecutableNetworkInternal;
|
||||
|
||||
/**
|
||||
* @brief An optimal implementation of IInferRequestInternal interface to avoid duplication in all plugins
|
||||
* This base class is recommended to be used as a base class for plugin synchronous inference request implementation.
|
||||
* @ingroup ie_dev_api_infer_request_api
|
||||
*/
|
||||
class InferRequestInternal : virtual public IInferRequestInternal {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to a InferRequestInternal implementation.
|
||||
*/
|
||||
typedef std::shared_ptr<InferRequestInternal> Ptr;
|
||||
|
||||
/**
|
||||
* @brief Constructs a new instance.
|
||||
* @param[in] networkInputs The network inputs info
|
||||
* @param[in] networkOutputs The network outputs data
|
||||
*/
|
||||
InferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs): m_curBatch(-1) {
|
||||
// // We should copy maps since they can be overriden in SetBlob with preprocess
|
||||
copyInputOutputInfo(networkInputs, networkOutputs, _networkInputs, _networkOutputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief The minimal infer function to be implemented by plugins. It infers specified input(s) in synchronous mode
|
||||
* @note
|
||||
* * This method is used in InferRequestInternal::Infer, which calls the common code first and after uses this
|
||||
* plugin dependent implementation.
|
||||
* * Blocks all method of IInferRequest while request is ongoing (running or waiting in queue)
|
||||
*/
|
||||
virtual void InferImpl() = 0;
|
||||
|
||||
/**
|
||||
* @brief Default common implementation for all plugins with checking input and output blobs before inference
|
||||
*/
|
||||
void Infer() override {
|
||||
checkBlobs();
|
||||
InferImpl();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Default common implementation for all plugins
|
||||
*/
|
||||
void Cancel() override {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Given optional implementation of setting blob to avoid need for it to be implemented by plugin
|
||||
* @param name - a name of input or output blob.
|
||||
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input
|
||||
* precision and size.
|
||||
*/
|
||||
void SetBlob(const std::string& name, const Blob::Ptr& userBlob) override {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "SetBlob");
|
||||
if (name.empty()) {
|
||||
IE_THROW(NotFound) << "Failed to set blob with empty name";
|
||||
}
|
||||
if (!userBlob) IE_THROW(NotAllocated) << "Failed to set empty blob with name: \'" << name << "\'";
|
||||
const bool compoundBlobPassed = userBlob->is<CompoundBlob>();
|
||||
const bool remoteBlobPassed = userBlob->is<RemoteBlob>();
|
||||
if (!compoundBlobPassed && !remoteBlobPassed && userBlob->buffer() == nullptr)
|
||||
IE_THROW(NotAllocated) << "Input data was not allocated. Input name: \'" << name << "\'";
|
||||
if (userBlob->size() == 0) {
|
||||
IE_THROW() << "Input data is empty. Input name: \'" << name << "\'";
|
||||
}
|
||||
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
size_t dataSize = userBlob->size();
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
if (foundInput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
|
||||
IE_THROW(ParameterMismatch)
|
||||
<< "Failed to set Blob with precision not corresponding to user input precision";
|
||||
}
|
||||
|
||||
auto& devBlob = _deviceInputs[name];
|
||||
const bool preProcRequired = preProcessingRequired(foundInput, userBlob, devBlob);
|
||||
if (compoundBlobPassed && !preProcRequired) {
|
||||
IE_THROW(NotImplemented)
|
||||
<< "cannot set compound blob: supported only for input pre-processing";
|
||||
}
|
||||
|
||||
if (preProcRequired) {
|
||||
addInputPreProcessingFor(name, userBlob, devBlob ? devBlob : _inputs[name]);
|
||||
} else {
|
||||
size_t inputSize = foundInput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
|
||||
? InferenceEngine::details::product(foundInput->getTensorDesc().getDims())
|
||||
: 1;
|
||||
if (dataSize != inputSize) {
|
||||
IE_THROW() << "Input blob size is not equal network input size (" << dataSize
|
||||
<< "!=" << inputSize << ").";
|
||||
}
|
||||
_inputs[name] = userBlob;
|
||||
devBlob = userBlob;
|
||||
}
|
||||
} else {
|
||||
if (compoundBlobPassed) {
|
||||
IE_THROW(NotImplemented)
|
||||
<< "cannot set compound blob: supported only for input pre-processing";
|
||||
}
|
||||
size_t outputSize = foundOutput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
|
||||
? details::product(foundOutput->getTensorDesc().getDims()) :
|
||||
1;
|
||||
if (dataSize != outputSize) {
|
||||
IE_THROW() << "Output blob size is not equal network output size (" << dataSize
|
||||
<< "!=" << outputSize << ").";
|
||||
}
|
||||
if (foundOutput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
|
||||
IE_THROW(ParameterMismatch)
|
||||
<< "Failed to set Blob with precision not corresponding to user output precision";
|
||||
}
|
||||
_outputs[name] = userBlob;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Given optional implementation of getting blob to avoid need for it to be implemented by plugin
|
||||
* @param name - a name of input or output blob.
|
||||
* @return Returns input or output blob. The type of Blob must correspond to the network input
|
||||
* precision and size.
|
||||
* @note if ROI blob was previously set it is returned (without dimensions checks) instead of default blob.
|
||||
*/
|
||||
Blob::Ptr GetBlob(const std::string& name) override {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "GetBlob");
|
||||
Blob::Ptr data;
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
const SizeVector oneVector = { 1 };
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
// ROI blob is returned only if it was set previously. Otherwise default blob is returned.
|
||||
auto it = _preProcData.find(name);
|
||||
if (it != _preProcData.end()) {
|
||||
data = it->second->getRoiBlob();
|
||||
} else {
|
||||
data = _inputs[name];
|
||||
checkBlob(data, name, true,
|
||||
foundInput->getTensorDesc().getLayout() != SCALAR
|
||||
? foundInput->getTensorDesc().getDims()
|
||||
: oneVector);
|
||||
|
||||
auto& devBlob = _deviceInputs[name];
|
||||
if (preProcessingRequired(foundInput, data, devBlob)) {
|
||||
// if no devBlob, performs inplace
|
||||
addInputPreProcessingFor(name, data, devBlob ? devBlob : _inputs[name]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
data = _outputs[name];
|
||||
checkBlob(data, name, false,
|
||||
foundOutput->getTensorDesc().getLayout() != SCALAR
|
||||
? foundOutput->getTensorDesc().getDims()
|
||||
: oneVector);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Sets pre-process for input data
|
||||
* @param name Name of input blob.
|
||||
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input precision and size.
|
||||
* @param info Preprocess info for blob.
|
||||
*/
|
||||
void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) override {
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
copyPreProcess(info, foundInput->getPreProcess());
|
||||
} else {
|
||||
IE_THROW() << "Pre-process can't be set to output blob";
|
||||
}
|
||||
|
||||
SetBlob(name, data);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gets pre-process for input data
|
||||
* @param name Name of input blob.
|
||||
* @return Returns constant reference to PreProcessInfo structure
|
||||
*/
|
||||
const PreProcessInfo& GetPreProcess(const std::string& name) const override {
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
return foundInput->getPreProcess();
|
||||
} else {
|
||||
IE_THROW() << "Output blob can't have pre-processing";
|
||||
}
|
||||
}
|
||||
|
||||
void SetBatch(int batch) override {
|
||||
(void)batch;
|
||||
IE_THROW() << "Dynamic batch is not supported";
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Sets the pointer to executable network internal.
|
||||
* @note Needed to correctly handle ownership between objects.
|
||||
* @param[in] exeNetwork The executable network
|
||||
*/
|
||||
void setPointerToExecutableNetworkInternal(std::shared_ptr<IExecutableNetworkInternal> exeNetwork) {
|
||||
_exeNetwork = exeNetwork;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks that both inputs and outputs blob are valid. Throws an exception if they are not.
|
||||
*/
|
||||
virtual void checkBlobs() {
|
||||
for (auto const& input : _inputs) {
|
||||
checkBlob(input.second, input.first, true);
|
||||
}
|
||||
for (auto const& output : _outputs) {
|
||||
checkBlob(output.second, output.first, false);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<IVariableStateInternal::Ptr> QueryState() override {
|
||||
// meaning base plugin reports as no state available - plugin owners need to create proper override of this
|
||||
IE_THROW() << "Plugin doesn't override QueryState";
|
||||
return {};
|
||||
}
|
||||
|
||||
protected:
|
||||
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
|
||||
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data
|
||||
InferenceEngine::BlobMap _inputs; //!< A map of user passed blobs for network inputs
|
||||
InferenceEngine::BlobMap _deviceInputs; //!< A map of actual network inputs, in plugin specific format
|
||||
InferenceEngine::BlobMap _outputs; //!< A map of user passed blobs for network outputs
|
||||
std::map<std::string, PreProcessDataPtr> _preProcData; //!< A map of pre-process data per input
|
||||
int m_curBatch; //!< Current batch value used in dynamic batching
|
||||
|
||||
/**
|
||||
* @brief A shared pointer to ExecutableNetworkInternal interface
|
||||
* @note Needed to correctly handle ownership between objects.
|
||||
*/
|
||||
std::shared_ptr<IExecutableNetworkInternal> _exeNetwork;
|
||||
/**
|
||||
* @brief Checks and executes input data pre-processing if needed.
|
||||
* @param inputs Inputs blobs to perform preprocessing on
|
||||
* @param serial Whether to use multiple threads to execute the step
|
||||
*/
|
||||
void execDataPreprocessing(InferenceEngine::BlobMap& preprocessedBlobs, bool serial = false) {
|
||||
for (auto& input : preprocessedBlobs) {
|
||||
// If there is a pre-process entry for an input then it must be pre-processed
|
||||
// using preconfigured resize algorithm.
|
||||
auto it = _preProcData.find(input.first);
|
||||
if (it != _preProcData.end()) {
|
||||
_preProcData[input.first]->execute(input.second, _networkInputs[input.first]->getPreProcess(), serial,
|
||||
m_curBatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Helper function to find input or output blob by name
|
||||
* @param name A name of input or output blob.
|
||||
* @param foundInput A pointer to input information if found.
|
||||
* @param foundOutput A pointer to output DataPtr if found.
|
||||
* @return `True` - if loaded network has input with provided name,
|
||||
* `false` - if loaded network has output with provided name
|
||||
* @throws [parameter_mismatch] exception if input and output has the same name
|
||||
* @throws [not_found] exception if there is no input and output layers with given name
|
||||
*/
|
||||
bool findInputAndOutputBlobByName(const std::string& name, InputInfo::Ptr& foundInput, DataPtr& foundOutput) const {
|
||||
foundInput = nullptr;
|
||||
foundOutput = nullptr;
|
||||
if (_networkOutputs.empty()) {
|
||||
IE_THROW() << "Internal error: network outputs is not set";
|
||||
}
|
||||
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
|
||||
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
|
||||
[&](const std::pair<std::string, DataPtr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
if (foundOutputPair == std::end(_networkOutputs) && (foundInputPair == std::end(_networkInputs))) {
|
||||
IE_THROW(NotFound) << "Failed to find input or output with name: \'" << name << "\'";
|
||||
}
|
||||
if (foundInputPair != std::end(_networkInputs)) {
|
||||
foundInput = foundInputPair->second;
|
||||
return true;
|
||||
} else {
|
||||
foundOutput = foundOutputPair->second;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check that @p blob is valid. Throws an exception if it's not.
|
||||
*
|
||||
* @param[in] blob The blob to check
|
||||
* @param[in] name The name of input or output depending of if the @p blob is input or output
|
||||
* @param[in] isInput Indicates if @p is input
|
||||
* @param[in] refDims The reference dims, empty if not specified
|
||||
*/
|
||||
void checkBlob(const Blob::Ptr& blob, const std::string& name, bool isInput, const SizeVector& refDims = {}) const {
|
||||
std::string bType = isInput ? "Input" : "Output";
|
||||
std::string sType = isInput ? "input" : "output";
|
||||
std::string strNotAllocated(bType + " data was not allocated.");
|
||||
std::string strNotMatched("The " + sType + " blob size is not equal to the network " + sType + " size");
|
||||
|
||||
if (!blob) {
|
||||
IE_THROW() << strNotAllocated;
|
||||
}
|
||||
size_t refSize;
|
||||
if (refDims.empty()) {
|
||||
SizeVector dims;
|
||||
if (isInput) {
|
||||
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
|
||||
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
if (foundInputPair == std::end(_networkInputs)) {
|
||||
IE_THROW(NotFound) << "Failed to find input with name: \'" << name << "\'";
|
||||
}
|
||||
dims = foundInputPair->second->getTensorDesc().getDims();
|
||||
refSize = foundInputPair->second->getTensorDesc().getLayout() != SCALAR
|
||||
? details::product(dims)
|
||||
: 1;
|
||||
} else {
|
||||
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
|
||||
[&](const std::pair<std::string, DataPtr>& pair) {
|
||||
return pair.first == name;
|
||||
});
|
||||
if (foundOutputPair == std::end(_networkOutputs)) {
|
||||
IE_THROW(NotFound) << "Failed to find output with name: \'" << name << "\'";
|
||||
}
|
||||
dims = foundOutputPair->second->getTensorDesc().getDims();
|
||||
refSize = foundOutputPair->second->getTensorDesc().getLayout() != SCALAR
|
||||
? details::product(dims)
|
||||
: 1;
|
||||
}
|
||||
} else {
|
||||
refSize = details::product(refDims);
|
||||
}
|
||||
|
||||
if (refSize != blob->size()) {
|
||||
IE_THROW() << strNotMatched + ": got " << blob->size() << " expecting " << refSize;
|
||||
}
|
||||
const bool remoteBlobPassed = blob->is<RemoteBlob>();
|
||||
if (!remoteBlobPassed && blob->buffer() == nullptr) IE_THROW() << strNotAllocated;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks whether pre-processing step is required for a given input
|
||||
* @param info InputInfo corresponding to input blob
|
||||
* @param userBlob Input Blob object corresponding to input info
|
||||
* @param deviceBlob Blob object in plugin's desired format
|
||||
* @return `True` if pre-processing is required, `false` otherwise
|
||||
*/
|
||||
bool preProcessingRequired(const InputInfo::Ptr& info, const Blob::Ptr& userBlob, const Blob::Ptr& deviceBlob = nullptr) {
|
||||
// pre-processing is required if:
|
||||
// 1. resize algorithm is specified (resize required)
|
||||
// 2. color format specified:
|
||||
// 2.a. color format is not equal to network's expected (color conversion required)
|
||||
// 2.b. network's layout != blob's layout (reorder required)
|
||||
// 3. precision conversion is required
|
||||
|
||||
const auto& preProcessInfo = info->getPreProcess();
|
||||
const auto inputColorFormat = preProcessInfo.getColorFormat();
|
||||
// FIXME: support other network's input formats once the API is ready. Assuming input is in
|
||||
// the BGR format by default
|
||||
const auto networkColorFormat = ColorFormat::BGR;
|
||||
const bool colorFormatSpecified = inputColorFormat != ColorFormat::RAW;
|
||||
|
||||
auto blob_layout = [](const Blob::Ptr& b) { return b->getTensorDesc().getLayout(); };
|
||||
auto blob_prec = [](const Blob::Ptr& b) { return b->getTensorDesc().getPrecision();};
|
||||
|
||||
auto dst_layout = deviceBlob ? blob_layout(deviceBlob) : info->getLayout();
|
||||
auto dst_prec = deviceBlob ? blob_prec(deviceBlob) : info->getPrecision();
|
||||
|
||||
//FIXME: remove the first part to allow any needed conversion?
|
||||
const bool need_layout_conv = (colorFormatSpecified || deviceBlob) &&
|
||||
(blob_layout(userBlob) != dst_layout);
|
||||
|
||||
return preProcessInfo.getResizeAlgorithm() != ResizeAlgorithm::NO_RESIZE ||
|
||||
(colorFormatSpecified && inputColorFormat != networkColorFormat) ||
|
||||
need_layout_conv ||
|
||||
(blob_prec(userBlob) != dst_prec);
|
||||
}
|
||||
|
||||
void addInputPreProcessingFor(const std::string& name, Blob::Ptr const& from, const Blob::Ptr& to) {
|
||||
auto ppDataIt = _preProcData.find(name);
|
||||
if (ppDataIt == _preProcData.end()) {
|
||||
ppDataIt = (_preProcData.emplace(name, CreatePreprocDataHelper())).first;
|
||||
}
|
||||
|
||||
auto& preproc_ptr = ppDataIt->second;
|
||||
preproc_ptr->isApplicable(from, to);
|
||||
// Stores the given blob as ROI blob. It will be used to fill in network input
|
||||
// during pre-processingstd::map<std::string, InferenceEngineProfileInfo>
|
||||
preproc_ptr->setRoiBlob(from);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ie_iexecutable_network.hpp>
|
||||
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
|
||||
#include <ie_iinfer_request.hpp>
|
||||
#include <ie_parameter.hpp>
|
||||
@ -11,9 +12,10 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
class IInferRequestInternal;
|
||||
/**
|
||||
* @interface IExecutableNetworkInternal
|
||||
* @brief An internal API of executable network to be implemented by plugin,
|
||||
@ -52,7 +54,7 @@ public:
|
||||
* Note: the returned request will have allocated input and output blobs (that can be changed later)
|
||||
* @return shared_ptr for the created request
|
||||
*/
|
||||
virtual IInferRequest::Ptr CreateInferRequest() = 0;
|
||||
virtual std::shared_ptr<IInferRequestInternal> CreateInferRequest() = 0;
|
||||
|
||||
/**
|
||||
* @deprecated Use IExecutableNetworkInternal::Export(std::ostream& networkModel)
|
||||
|
@ -1,66 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ie_iinfer_request.hpp>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "ie_iinfer_request_internal.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
/**
|
||||
* @interface IAsyncInferRequestInternal
|
||||
* @ingroup ie_dev_api_async_infer_request_api
|
||||
* @brief An internal API of asynchronous inference request to be implemented by plugin,
|
||||
* which is used in InferRequestBase forwarding mechanism
|
||||
*/
|
||||
class IAsyncInferRequestInternal : virtual public IInferRequestInternal {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to IAsyncInferRequestInternal interface
|
||||
*/
|
||||
typedef std::shared_ptr<IAsyncInferRequestInternal> Ptr;
|
||||
|
||||
/**
|
||||
* @brief Start inference of specified input(s) in asynchronous mode
|
||||
* @note The method returns immediately. Inference starts also immediately.
|
||||
*/
|
||||
virtual void StartAsync() = 0;
|
||||
|
||||
/**
|
||||
* @brief Waits for the result to become available. Blocks until specified millis_timeout has elapsed or the result
|
||||
* becomes available, whichever comes first.
|
||||
* @param millis_timeout - maximum duration in milliseconds to block for
|
||||
* @note There are special cases when millis_timeout is equal some value of WaitMode enum:
|
||||
* * STATUS_ONLY - immediately returns request status (IInferRequest::RequestStatus). It doesn't block or interrupt
|
||||
* current thread.
|
||||
* * RESULT_READY - waits until inference result becomes available
|
||||
* @return A status code
|
||||
*/
|
||||
virtual StatusCode Wait(int64_t millis_timeout) = 0;
|
||||
|
||||
/**
|
||||
* @brief Get arbitrary data for the request
|
||||
* @param data A pointer to a pointer to arbitrary data
|
||||
*/
|
||||
virtual void GetUserData(void** data) = 0;
|
||||
|
||||
/**
|
||||
* @brief Set arbitrary data for the request
|
||||
* @param data A pointer to a pointer to arbitrary data
|
||||
*/
|
||||
virtual void SetUserData(void* data) = 0;
|
||||
|
||||
/**
|
||||
* @brief Set callback function which will be called on success or failure of asynchronous request
|
||||
* @param callback - function to be called with the following description:
|
||||
*/
|
||||
virtual void SetCompletionCallback(IInferRequest::CompletionCallback callback) = 0;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
@ -4,52 +4,67 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
|
||||
#include <ie_blob.h>
|
||||
#include <ie_common.h>
|
||||
#include <ie_preprocess.hpp>
|
||||
#include <ie_preprocess_data.hpp>
|
||||
#include <ie_input_info.hpp>
|
||||
#include <ie_icnn_network.hpp>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
class IExecutableNetworkInternal;
|
||||
class IVariableStateInternal;
|
||||
/**
|
||||
* @interface IInferRequestInternal
|
||||
* @brief An internal API of synchronous inference request to be implemented by plugin,
|
||||
* which is used in InferRequestBase forwarding mechanism
|
||||
* @ingroup ie_dev_api_infer_request_api
|
||||
*/
|
||||
class IInferRequestInternal {
|
||||
class INFERENCE_ENGINE_API_CLASS(IInferRequestInternal) : public std::enable_shared_from_this<IInferRequestInternal> {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to a IInferRequestInternal interface
|
||||
*/
|
||||
typedef std::shared_ptr<IInferRequestInternal> Ptr;
|
||||
using Ptr = std::shared_ptr<IInferRequestInternal>;
|
||||
|
||||
IInferRequestInternal() = default;
|
||||
|
||||
/**
|
||||
* @brief Destroys the object.
|
||||
* @brief Constructs a new instance.
|
||||
* @param[in] networkInputs The network inputs info
|
||||
* @param[in] networkOutputs The network outputs data
|
||||
*/
|
||||
virtual ~IInferRequestInternal() = default;
|
||||
IInferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs);
|
||||
|
||||
/**
|
||||
* @brief Infers specified input(s) in synchronous mode
|
||||
* @note blocks all method of IInferRequest while request is ongoing (running or waiting in queue)
|
||||
*/
|
||||
virtual void Infer() = 0;
|
||||
virtual void Infer();
|
||||
|
||||
/**
|
||||
* @brief The minimal infer function to be implemented by plugins. It infers specified input(s) in synchronous mode
|
||||
* @note
|
||||
* * This method is used in IInferRequestInternal::Infer, which calls the common code first and after uses this
|
||||
* plugin dependent implementation.
|
||||
* * Blocks all method of IInferRequest while request is ongoing (running or waiting in queue)
|
||||
*/
|
||||
virtual void InferImpl();
|
||||
|
||||
/**
|
||||
* @brief Cancel current inference request execution
|
||||
*/
|
||||
virtual void Cancel() = 0;
|
||||
virtual void Cancel();
|
||||
|
||||
/**
|
||||
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
|
||||
* Note: not all plugins may provide meaningful data
|
||||
* @return Returns a map of layer names to profiling information for that layer.
|
||||
* @return - a map of layer names to profiling information for that layer.
|
||||
*/
|
||||
virtual std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const = 0;
|
||||
virtual std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const;
|
||||
|
||||
/**
|
||||
* @brief Set input/output data to infer
|
||||
@ -58,16 +73,16 @@ public:
|
||||
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input
|
||||
* precision and size.
|
||||
*/
|
||||
virtual void SetBlob(const std::string& name, const Blob::Ptr& data) = 0;
|
||||
virtual void SetBlob(const std::string& name, const Blob::Ptr& data);
|
||||
|
||||
/**
|
||||
* @brief Get input/output data to infer
|
||||
* @note Memory allocation doesn't happen
|
||||
* @param name - a name of input or output blob.
|
||||
* @return Returns input or output blob. The type of Blob must correspond to the network input
|
||||
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input
|
||||
* precision and size.
|
||||
*/
|
||||
virtual Blob::Ptr GetBlob(const std::string& name) = 0;
|
||||
virtual Blob::Ptr GetBlob(const std::string& name);
|
||||
|
||||
/**
|
||||
* @brief Sets pre-process for input data
|
||||
@ -75,26 +90,138 @@ public:
|
||||
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input precision and size.
|
||||
* @param info Preprocess info for blob.
|
||||
*/
|
||||
virtual void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) = 0;
|
||||
virtual void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info);
|
||||
|
||||
/**
|
||||
* @brief Gets pre-process for input data
|
||||
* @param name Name of input blob.
|
||||
* @return Returns constant reference to PreProcessInfo structure
|
||||
* @param info pointer to a pointer to PreProcessInfo structure
|
||||
*/
|
||||
virtual const PreProcessInfo& GetPreProcess(const std::string& name) const = 0;
|
||||
virtual const PreProcessInfo& GetPreProcess(const std::string& name) const;
|
||||
|
||||
/**
|
||||
* @brief Sets new batch size when dynamic batching is enabled in executable network that created this request.
|
||||
* @param batch - new batch size to be used by all the following inference calls for this request.
|
||||
*/
|
||||
virtual void SetBatch(int batch) = 0;
|
||||
virtual void SetBatch(int batch);
|
||||
|
||||
/**
|
||||
* @brief Queries memory states.
|
||||
* @return Returns memory states
|
||||
*/
|
||||
virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
|
||||
virtual std::vector<std::shared_ptr<IVariableStateInternal>> QueryState();
|
||||
|
||||
/**
|
||||
* @brief Start inference of specified input(s) in asynchronous mode
|
||||
* @note The method returns immediately. Inference starts also immediately.
|
||||
*/
|
||||
virtual void StartAsync();
|
||||
|
||||
/**
|
||||
* @brief The minimal asynchronous inference function to be implemented by plugins.
|
||||
* It starts inference of specified input(s) in asynchronous mode
|
||||
* @note
|
||||
* * The methos is used in AsyncInferRequestInternal::StartAsync which performs common steps first and
|
||||
* calls plugin dependent implementation of this method after.
|
||||
* * It returns immediately. Inference starts also immediately.
|
||||
*/
|
||||
virtual void StartAsyncImpl();
|
||||
|
||||
/**
|
||||
* @brief Waits for the result to become available. Blocks until specified millis_timeout has elapsed or the result
|
||||
* becomes available, whichever comes first.
|
||||
* @param millis_timeout - maximum duration in milliseconds to block for
|
||||
* @note There are special cases when millis_timeout is equal some value of WaitMode enum:
|
||||
* * STATUS_ONLY - immediately returns request status (IInferRequest::RequestStatus). It doesn't block or interrupt
|
||||
* current thread.
|
||||
* * RESULT_READY - waits until inference result becomes available
|
||||
* @return A status code
|
||||
*/
|
||||
virtual StatusCode Wait(int64_t millis_timeout);
|
||||
|
||||
/**
|
||||
* @brief Alias for callback type
|
||||
*/
|
||||
using Callback = std::function<void(std::exception_ptr)>;
|
||||
|
||||
/**
|
||||
* @brief Set callback function which will be called on success or failure of asynchronous request
|
||||
* @param callback - function to be called with the following description:
|
||||
*/
|
||||
virtual void SetCallback(Callback callback);
|
||||
|
||||
/**
|
||||
* @brief Check that @p blob is valid. Throws an exception if it's not.
|
||||
*
|
||||
* @param[in] blob The blob to check
|
||||
* @param[in] name The name of input or output depending of if the @p blob is input or output
|
||||
* @param[in] isInput Indicates if @p is input
|
||||
* @param[in] refDims The reference dims, empty if not specified
|
||||
*/
|
||||
void checkBlob(const Blob::Ptr& blob, const std::string& name, bool isInput, const SizeVector& refDims = {}) const;
|
||||
|
||||
/**
|
||||
* @brief Check that all of the blobs is valid. Throws an exception if it's not.
|
||||
*/
|
||||
virtual void checkBlobs();
|
||||
|
||||
/**
|
||||
* @brief Sets the pointer to executable network internal.
|
||||
* @note Needed to correctly handle ownership between objects.
|
||||
* @param[in] exeNetwork The executable network
|
||||
*/
|
||||
void setPointerToExecutableNetworkInternal(const std::shared_ptr<IExecutableNetworkInternal>& exeNetwork);
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Checks and executes input data pre-processing if needed.
|
||||
* @param inputs Inputs blobs to perform preprocessing on
|
||||
* @param serial Whether to use multiple threads to execute the step
|
||||
*/
|
||||
void execDataPreprocessing(InferenceEngine::BlobMap& preprocessedBlobs, bool serial = false);
|
||||
|
||||
/**
|
||||
* @brief Helper function to find input or output blob by name
|
||||
* @param name A name of input or output blob.
|
||||
* @param foundInput A pointer to input information if found.
|
||||
* @param foundOutput A pointer to output DataPtr if found.
|
||||
* @return `True` - if loaded network has input with provided name,
|
||||
* `false` - if loaded network has output with provided name
|
||||
* @throws [parameter_mismatch] exception if input and output has the same name
|
||||
* @throws [not_found] exception if there is no input and output layers with given name
|
||||
*/
|
||||
bool findInputAndOutputBlobByName(const std::string& name, InputInfo::Ptr& foundInput, DataPtr& foundOutput) const;
|
||||
|
||||
/**
|
||||
* @brief Checks whether pre-processing step is required for a given input
|
||||
* @param info InputInfo corresponding to input blob
|
||||
* @param userBlob Input Blob object corresponding to input info
|
||||
* @param deviceBlob Blob object in plugin's desired format
|
||||
* @return `True` if pre-processing is required, `false` otherwise
|
||||
*/
|
||||
bool preProcessingRequired(const InputInfo::Ptr& info, const Blob::Ptr& userBlob, const Blob::Ptr& deviceBlob = nullptr);
|
||||
|
||||
void addInputPreProcessingFor(const std::string& name, Blob::Ptr const& from, const Blob::Ptr& to);
|
||||
|
||||
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
|
||||
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data
|
||||
InferenceEngine::BlobMap _inputs; //!< A map of user passed blobs for network inputs
|
||||
InferenceEngine::BlobMap _deviceInputs; //!< A map of actual network inputs, in plugin specific format
|
||||
InferenceEngine::BlobMap _outputs; //!< A map of user passed blobs for network outputs
|
||||
std::map<std::string, PreProcessDataPtr> _preProcData; //!< A map of pre-process data per input
|
||||
int m_curBatch = -1; //!< Current batch value used in dynamic batching
|
||||
|
||||
/**
|
||||
* @brief A shared pointer to ExecutableNetworkInternal interface
|
||||
* @note Needed to correctly handle ownership between objects.
|
||||
*/
|
||||
std::shared_ptr<IExecutableNetworkInternal> _exeNetwork;
|
||||
Callback _callback; //!< A callback
|
||||
|
||||
/**
|
||||
* @brief Destroys the object.
|
||||
*/
|
||||
~IInferRequestInternal();
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
@ -17,7 +17,7 @@ public:
|
||||
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr &taskExecutorGetResult);
|
||||
|
||||
~MyriadAsyncInferRequest() override;
|
||||
~MyriadAsyncInferRequest();
|
||||
private:
|
||||
MyriadInferRequest::Ptr _request;
|
||||
InferenceEngine::ITaskExecutor::Ptr _taskExecutorGetResult;
|
||||
|
@ -61,7 +61,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
ie::InferRequestInternal::Ptr CreateInferRequestImpl(ie::InputsDataMap networkInputs,
|
||||
ie::IInferRequestInternal::Ptr CreateInferRequestImpl(ie::InputsDataMap networkInputs,
|
||||
ie::OutputsDataMap networkOutputs) override {
|
||||
if (_device == nullptr || !_device->isBooted()) {
|
||||
IE_THROW() << "Can not create infer request: there is no available devices with platform "
|
||||
@ -73,7 +73,7 @@ public:
|
||||
_graphMetaData.stagesMeta, _config, _log, _executor);
|
||||
}
|
||||
|
||||
ie::IInferRequest::Ptr CreateInferRequest() override {
|
||||
ie::IInferRequestInternal::Ptr CreateInferRequest() override {
|
||||
ie::IInferRequest::Ptr asyncRequest;
|
||||
if (_device == nullptr || !_device->isBooted()) {
|
||||
IE_THROW() << "Can not create infer request: there is no available devices with platform "
|
||||
@ -86,11 +86,8 @@ public:
|
||||
_executor);
|
||||
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
||||
auto taskExecutorGetResult = getNextTaskExecutor();
|
||||
auto asyncThreadSafeImpl = std::make_shared<MyriadAsyncInferRequest>(
|
||||
return std::make_shared<MyriadAsyncInferRequest>(
|
||||
syncRequestImpl, _taskExecutor, _callbackExecutor, taskExecutorGetResult);
|
||||
asyncRequest.reset(new ie::InferRequestBase(asyncThreadSafeImpl));
|
||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
return asyncRequest;
|
||||
}
|
||||
|
||||
void Export(std::ostream& model) override {
|
||||
|
@ -33,7 +33,7 @@ MyriadInferRequest::MyriadInferRequest(GraphDesc &graphDesc,
|
||||
const MyriadConfig& myriadConfig,
|
||||
const Logger::Ptr &log,
|
||||
const MyriadExecutorPtr &executor) :
|
||||
InferRequestInternal(networkInputs, networkOutputs), _executor(executor),
|
||||
IInferRequestInternal(networkInputs, networkOutputs), _executor(executor),
|
||||
_log(log), _stagesMetaData(blobMetaData), _config(myriadConfig),
|
||||
_inputInfo(compilerInputsInfo), _outputInfo(compilerOutputsInfo),
|
||||
_graphDesc(graphDesc) {
|
||||
|
@ -10,7 +10,6 @@
|
||||
#include <memory>
|
||||
|
||||
#include <ie_common.h>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
|
||||
|
||||
#include <vpu/utils/logger.hpp>
|
||||
@ -22,7 +21,7 @@
|
||||
namespace vpu {
|
||||
namespace MyriadPlugin {
|
||||
|
||||
class MyriadInferRequest : public InferenceEngine::InferRequestInternal {
|
||||
class MyriadInferRequest : public InferenceEngine::IInferRequestInternal {
|
||||
MyriadExecutorPtr _executor;
|
||||
Logger::Ptr _log;
|
||||
std::vector<StageMetaInfo> _stagesMetaData;
|
||||
|
@ -13,11 +13,6 @@ using namespace InferenceEngine;
|
||||
using namespace InferenceEngine::details;
|
||||
|
||||
|
||||
TEST(InferRequestCPPTests, throwsOnInitWithNull) {
|
||||
IInferRequest::Ptr nlptr = nullptr;
|
||||
ASSERT_THROW(InferRequest req(nlptr), InferenceEngine::Exception);
|
||||
}
|
||||
|
||||
TEST(InferRequestCPPTests, throwsOnUninitializedSetBlob) {
|
||||
InferRequest req;
|
||||
ASSERT_THROW(req.SetBlob({}, {}), InferenceEngine::Exception);
|
||||
@ -81,7 +76,7 @@ TEST(InferRequestCPPTests, throwsOnUninitializedSetCompletionCallback) {
|
||||
|
||||
TEST(InferRequestCPPTests, throwsOnUninitializedCast) {
|
||||
InferRequest req;
|
||||
ASSERT_THROW((void)static_cast<IInferRequest::Ptr &>(req), InferenceEngine::Exception);
|
||||
ASSERT_THROW((void)static_cast<IInferRequest::Ptr>(req), InferenceEngine::Exception);
|
||||
}
|
||||
|
||||
TEST(InferRequestCPPTests, throwsOnUninitializedQueryState) {
|
||||
|
@ -98,11 +98,12 @@ class MockExecutableNetwork : public ExecutableNetworkInternal {
|
||||
public:
|
||||
MockExecutableNetwork() {}
|
||||
MOCK_METHOD1(ExportImpl, void(std::ostream& networkModel));
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr());
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr());
|
||||
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap());
|
||||
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap());
|
||||
MOCK_CONST_METHOD1(GetConfig, Parameter(const std::string& name));
|
||||
MOCK_CONST_METHOD1(GetMetric, Parameter(const std::string& name));
|
||||
MOCK_METHOD2(CreateInferRequestImpl, IInferRequestInternal::Ptr(InputsDataMap, OutputsDataMap));
|
||||
};
|
||||
|
||||
//------------------------------------------------------
|
||||
@ -251,9 +252,8 @@ public:
|
||||
EXPECT_CALL(*mock, GetOutputsInfo()).Times(AnyNumber()).WillRepeatedly(Return(ConstOutputsDataMap{}));
|
||||
EXPECT_CALL(*mock, GetConfig(PluginConfigParams::KEY_PERF_COUNT)).Times(AnyNumber()).WillRepeatedly(Return(Parameter{PluginConfigParams::NO}));
|
||||
EXPECT_CALL(*mock, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS))).Times(AnyNumber()).WillRepeatedly(Return(Parameter{1u}));
|
||||
auto ptr = std::make_shared<MockIInferRequest>();
|
||||
EXPECT_CALL(*ptr, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*ptr, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
auto ptr = std::make_shared<MockIInferRequestInternal>();
|
||||
EXPECT_CALL(*ptr, SetCallback(_)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mock, CreateInferRequest()).Times(AnyNumber()).WillRepeatedly(Return(ptr));
|
||||
return mock;
|
||||
}
|
||||
@ -345,10 +345,8 @@ private:
|
||||
}));
|
||||
EXPECT_CALL(*net, CreateInferRequest()).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&]() {
|
||||
std::vector<std::string> res;
|
||||
auto inferReq = std::make_shared<MockIInferRequest>();
|
||||
EXPECT_CALL(*inferReq, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*inferReq, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
auto inferReq = std::make_shared<MockIInferRequestInternal>();
|
||||
EXPECT_CALL(*inferReq, SetCallback(_)).Times(AnyNumber());
|
||||
return inferReq;
|
||||
}));
|
||||
}
|
||||
|
@ -122,20 +122,12 @@ TEST_P(CallbackTests, returnGeneralErrorIfCallbackThrowException) {
|
||||
// Load CNNNetwork to target plugins
|
||||
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
|
||||
// Create InferRequest
|
||||
InferenceEngine::IInferRequest::Ptr req = static_cast<InferenceEngine::IInferRequest::Ptr &>(execNet.CreateInferRequest());
|
||||
req->SetCompletionCallback(
|
||||
[](InferenceEngine::IInferRequest::Ptr, InferenceEngine::StatusCode status) {
|
||||
IE_THROW() << "returnGeneralErrorIfCallbackThrowException";
|
||||
});
|
||||
auto req = execNet.CreateInferRequest();
|
||||
req.SetCompletionCallback([] {
|
||||
IE_THROW(GeneralError);
|
||||
});
|
||||
|
||||
InferenceEngine::ResponseDesc resp;
|
||||
req->StartAsync(&resp);
|
||||
InferenceEngine::StatusCode waitStatus = InferenceEngine::StatusCode::INFER_NOT_STARTED;
|
||||
while (InferenceEngine::StatusCode::RESULT_NOT_READY == waitStatus ||
|
||||
InferenceEngine::StatusCode::INFER_NOT_STARTED == waitStatus) {
|
||||
waitStatus = req->Wait(InferenceEngine::IInferRequest::WaitMode::STATUS_ONLY, &resp);
|
||||
}
|
||||
ASSERT_EQ(InferenceEngine::StatusCode::GENERAL_ERROR, waitStatus);
|
||||
ASSERT_NE(std::string(resp.msg).find("returnGeneralErrorIfCallbackThrowException"), std::string::npos);
|
||||
ASSERT_NO_THROW(req.StartAsync());
|
||||
ASSERT_THROW(req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY), InferenceEngine::GeneralError);
|
||||
}
|
||||
} // namespace BehaviorTestsDefinitions
|
@ -131,11 +131,9 @@ TEST_P(InferRequestOutputTests, canStartAsyncInferWithGetInOut) {
|
||||
InferenceEngine::InferRequest req;
|
||||
ASSERT_NO_THROW(req = execNet.CreateInferRequest());
|
||||
InferenceEngine::Blob::Ptr inputBlob = req.GetBlob(cnnNet.getInputsInfo().begin()->first);
|
||||
InferenceEngine::StatusCode sts;
|
||||
ASSERT_NO_THROW(req.Infer());
|
||||
ASSERT_NO_THROW(req.StartAsync());
|
||||
sts = req.Wait(500);
|
||||
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts);
|
||||
ASSERT_NO_THROW(req.Wait());
|
||||
InferenceEngine::Blob::Ptr outputBlob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first);
|
||||
}
|
||||
} // namespace BehaviorTestsDefinitions
|
@ -12,14 +12,10 @@
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/mock_task_executor.hpp"
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_network_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_async_only.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_default.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_inference_plugin_internal.hpp"
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_ivariable_state_internal.hpp"
|
||||
|
@ -12,13 +12,11 @@
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class MockAsyncInferRequestDefault : public AsyncInferRequestThreadSafeDefault {
|
||||
public:
|
||||
MockAsyncInferRequestDefault(InferRequestInternal::Ptr request,
|
||||
MockAsyncInferRequestDefault(IInferRequestInternal::Ptr request,
|
||||
const ITaskExecutor::Ptr &taskExecutor,
|
||||
const ITaskExecutor::Ptr &callbackExecutor)
|
||||
: AsyncInferRequestThreadSafeDefault(request, taskExecutor, callbackExecutor) {}
|
||||
|
@ -1,37 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ie_iinfer_request.hpp>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_async_request_internal.hpp>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class MockAsyncInferRequestInternal : public AsyncInferRequestInternal {
|
||||
public:
|
||||
using AsyncInferRequestInternal::SetBlob;
|
||||
MockAsyncInferRequestInternal(InputsDataMap networkInputs, OutputsDataMap networkOutputs)
|
||||
: AsyncInferRequestInternal(networkInputs, networkOutputs) {}
|
||||
|
||||
MOCK_METHOD0(StartAsyncImpl, void());
|
||||
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
|
||||
MOCK_METHOD1(GetUserData, void(void **));
|
||||
MOCK_METHOD1(SetUserData, void(void *));
|
||||
MOCK_METHOD0(InferImpl, void());
|
||||
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
|
||||
MOCK_METHOD1(setNetworkInputs, void(InputsDataMap));
|
||||
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
|
||||
MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&));
|
||||
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
|
||||
MOCK_METHOD0(Cancel, void());
|
||||
MOCK_METHOD0(Cancel_ThreadUnsafe, void());
|
||||
};
|
@ -13,12 +13,10 @@
|
||||
#include "ie_iexecutable_network.hpp"
|
||||
|
||||
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
@ -26,7 +24,7 @@ class MockExecutableNetworkInternal : public ExecutableNetworkInternal {
|
||||
public:
|
||||
MOCK_METHOD1(setNetworkInputs, void(InputsDataMap));
|
||||
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void));
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
|
||||
MOCK_METHOD1(Export, void(const std::string &));
|
||||
MOCK_METHOD0(GetExecGraphInfo, CNNNetwork(void));
|
||||
void WrapOstreamExport(std::ostream& networkModel) {
|
||||
@ -36,4 +34,5 @@ public:
|
||||
void ExportImpl(std::ostream& networkModel) override {
|
||||
networkModel << exportString << std::endl;
|
||||
}
|
||||
MOCK_METHOD2(CreateInferRequestImpl, IInferRequestInternal::Ptr(InputsDataMap, OutputsDataMap));
|
||||
};
|
||||
|
@ -1,23 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "ie_iexecutable_network.hpp"
|
||||
#include <gmock/gmock.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cpp_interfaces/impl/ie_executable_network_thread_safe_async_only.hpp>
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class MockExecutableNetworkThreadSafeAsyncOnly : public ExecutableNetworkThreadSafeAsyncOnly {
|
||||
public:
|
||||
MOCK_METHOD2(CreateAsyncInferRequestImpl,
|
||||
AsyncInferRequestInternal::Ptr(InputsDataMap networkInputs, OutputsDataMap networkOutputs));
|
||||
MOCK_METHOD1(Export, void(const std::string &));
|
||||
void Export(std::ostream&) override {}
|
||||
};
|
@ -15,7 +15,7 @@ using namespace InferenceEngine;
|
||||
class MockExecutableNetworkThreadSafe : public ExecutableNetworkThreadSafeDefault {
|
||||
public:
|
||||
MOCK_METHOD2(CreateInferRequestImpl,
|
||||
std::shared_ptr<InferRequestInternal>(InputsDataMap networkInputs, OutputsDataMap networkOutputs));
|
||||
std::shared_ptr<IInferRequestInternal>(InputsDataMap networkInputs, OutputsDataMap networkOutputs));
|
||||
MOCK_METHOD1(Export, void(const std::string &));
|
||||
void Export(std::ostream &) override {}
|
||||
};
|
||||
|
@ -1,27 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "ie_iexecutable_network.hpp"
|
||||
#include <gmock/gmock.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class MockInferRequestInternal : public InferRequestInternal {
|
||||
public:
|
||||
MockInferRequestInternal(InputsDataMap networkInputs, OutputsDataMap networkOutputs)
|
||||
: InferRequestInternal(networkInputs, networkOutputs) {}
|
||||
using InferRequestInternal::SetBlob;
|
||||
using InferRequestInternal::GetBlob;
|
||||
MOCK_METHOD0(InferImpl, void());
|
||||
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
|
||||
MOCK_METHOD0(checkBlobs, void());
|
||||
MOCK_METHOD0(Cancel, void());
|
||||
};
|
@ -1,32 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
|
||||
|
||||
class MockIAsyncInferRequestInternal : public InferenceEngine::IAsyncInferRequestInternal {
|
||||
public:
|
||||
MOCK_METHOD0(StartAsync, void());
|
||||
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
|
||||
MOCK_METHOD1(GetUserData, void(void **));
|
||||
MOCK_METHOD1(SetUserData, void(void *));
|
||||
MOCK_METHOD0(Infer, void());
|
||||
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngine::InferenceEngineProfileInfo>());
|
||||
MOCK_METHOD2(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &));
|
||||
MOCK_METHOD1(GetBlob, InferenceEngine::Blob::Ptr(const std::string&));
|
||||
MOCK_METHOD3(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &, const InferenceEngine::PreProcessInfo&));
|
||||
MOCK_CONST_METHOD1(GetPreProcess, const InferenceEngine::PreProcessInfo&(const std::string&));
|
||||
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
|
||||
MOCK_METHOD1(SetBatch, void(int));
|
||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
|
||||
MOCK_METHOD0(Cancel, void());
|
||||
};
|
@ -14,10 +14,8 @@
|
||||
#include "ie_icnn_network.hpp"
|
||||
#include "ie_iexecutable_network.hpp"
|
||||
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
@ -25,7 +23,7 @@ class MockIExecutableNetworkInternal : public IExecutableNetworkInternal {
|
||||
public:
|
||||
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap(void));
|
||||
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap(void));
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void));
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
|
||||
MOCK_METHOD1(Export, void(const std::string &));
|
||||
void Export(std::ostream &) override {};
|
||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));
|
||||
|
@ -10,16 +10,24 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_variable_state_internal.hpp>
|
||||
|
||||
class MockIInferRequestInternal : public InferenceEngine::IInferRequestInternal {
|
||||
public:
|
||||
using InferenceEngine::IInferRequestInternal::IInferRequestInternal;
|
||||
MOCK_METHOD0(StartAsync, void());
|
||||
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
|
||||
MOCK_METHOD0(Infer, void());
|
||||
MOCK_CONST_METHOD1(GetPerformanceCounts, void(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &));
|
||||
MOCK_METHOD2(SetBlob, void(const char *name, const InferenceEngine::Blob::Ptr &));
|
||||
MOCK_METHOD2(GetBlob, void(const char *name, InferenceEngine::Blob::Ptr &));
|
||||
MOCK_METHOD3(SetBlob, void(const char*, const InferenceEngine::Blob::Ptr&, const InferenceEngine::PreProcessInfo&));
|
||||
MOCK_METHOD2(GetPreProcess, void(const char*, const InferenceEngine::PreProcessInfo**));
|
||||
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngine::InferenceEngineProfileInfo>());
|
||||
MOCK_METHOD2(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &));
|
||||
MOCK_METHOD1(GetBlob, InferenceEngine::Blob::Ptr(const std::string&));
|
||||
MOCK_METHOD3(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &, const InferenceEngine::PreProcessInfo&));
|
||||
MOCK_CONST_METHOD1(GetPreProcess, const InferenceEngine::PreProcessInfo&(const std::string&));
|
||||
MOCK_METHOD1(SetCallback, void(std::function<void(std::exception_ptr)>));
|
||||
MOCK_METHOD1(SetBatch, void(int));
|
||||
MOCK_METHOD0(QueryState, std::vector<InferenceEngine::IVariableStateInternal::Ptr>());
|
||||
MOCK_METHOD0(Cancel, void());
|
||||
MOCK_METHOD0(StartAsyncImpl, void());
|
||||
MOCK_METHOD0(InferImpl, void());
|
||||
MOCK_METHOD0(checkBlobs, void());
|
||||
};
|
||||
|
@ -14,7 +14,7 @@ class MockIInferencePlugin : public InferenceEngine::IInferencePlugin {
|
||||
public:
|
||||
MOCK_METHOD1(AddExtension, void(InferenceEngine::IExtensionPtr));
|
||||
MOCK_METHOD2(LoadNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
|
||||
const CNNNetwork&, const std::map<std::string, std::string>&));
|
||||
const InferenceEngine::CNNNetwork&, const std::map<std::string, std::string>&));
|
||||
MOCK_METHOD2(ImportNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
|
||||
const std::string&, const std::map<std::string, std::string>&));
|
||||
MOCK_METHOD1(SetConfig, void(const std::map<std::string, std::string> &));
|
||||
@ -38,4 +38,7 @@ public:
|
||||
MOCK_METHOD3(ImportNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
|
||||
std::istream&, const InferenceEngine::RemoteContext::Ptr&,
|
||||
const std::map<std::string, std::string>&));
|
||||
MOCK_QUALIFIED_METHOD2(QueryNetwork, const,
|
||||
InferenceEngine::QueryNetworkResult(const InferenceEngine::CNNNetwork&,
|
||||
const std::map<std::string, std::string>&));
|
||||
};
|
||||
|
@ -7,9 +7,8 @@
|
||||
|
||||
#include <cpp_interfaces/base/ie_executable_network_base.hpp>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_async_only.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_default.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace std;
|
||||
@ -18,85 +17,11 @@ using namespace InferenceEngine::details;
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
class ExecutableNetworkThreadSafeAsyncOnlyTests : public ::testing::Test {
|
||||
protected:
|
||||
shared_ptr<MockExecutableNetworkThreadSafeAsyncOnly> mockExeNetwork;
|
||||
shared_ptr<MockAsyncInferRequestInternal> mockAsyncInferRequestInternal;
|
||||
shared_ptr<IExecutableNetwork> exeNetwork;
|
||||
ResponseDesc dsc;
|
||||
StatusCode sts;
|
||||
|
||||
virtual void TearDown() {
|
||||
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockAsyncInferRequestInternal.get()));
|
||||
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetwork.get()));
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
|
||||
exeNetwork = std::make_shared<ExecutableNetworkBase>(mockExeNetwork);
|
||||
InputsDataMap networkInputs;
|
||||
OutputsDataMap networkOutputs;
|
||||
mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, createAsyncInferRequestCallsThreadSafeImplAndSetNetworkIO) {
|
||||
IInferRequest::Ptr req;
|
||||
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
||||
Return(mockAsyncInferRequestInternal));
|
||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
|
||||
ASSERT_NE(threadSafeReq, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, returnErrorIfInferThrowsException) {
|
||||
IInferRequest::Ptr req;
|
||||
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
||||
Return(mockAsyncInferRequestInternal));
|
||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::runtime_error("")));
|
||||
EXPECT_NO_THROW(sts = req->Infer(&dsc));
|
||||
ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, returnErrorIfStartAsyncThrowsException) {
|
||||
IInferRequest::Ptr req;
|
||||
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
||||
Return(mockAsyncInferRequestInternal));
|
||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).WillOnce(Throw(std::runtime_error("")));
|
||||
EXPECT_NO_THROW(sts = req->StartAsync(&dsc));
|
||||
ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, canForwardStartAsyncAndInfer) {
|
||||
IInferRequest::Ptr req;
|
||||
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
||||
Return(mockAsyncInferRequestInternal));
|
||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).Times(1);
|
||||
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).Times(1);
|
||||
|
||||
EXPECT_NO_THROW(req->StartAsync(&dsc)) << dsc.msg;
|
||||
EXPECT_NO_THROW(req->Infer(&dsc)) << dsc.msg;
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, canForwardInferAndStartAsync) {
|
||||
IInferRequest::Ptr req;
|
||||
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
||||
Return(mockAsyncInferRequestInternal));
|
||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).Times(1);
|
||||
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).Times(1);
|
||||
EXPECT_NO_THROW(req->Infer(&dsc)) << dsc.msg;
|
||||
EXPECT_NO_THROW(req->StartAsync(&dsc)) << dsc.msg;
|
||||
}
|
||||
|
||||
class ExecutableNetworkThreadSafeTests : public ::testing::Test {
|
||||
protected:
|
||||
shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetwork;
|
||||
shared_ptr<IExecutableNetwork> exeNetwork;
|
||||
shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
|
||||
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
|
||||
ResponseDesc dsc;
|
||||
StatusCode sts;
|
||||
|
||||
@ -110,7 +35,7 @@ protected:
|
||||
exeNetwork = std::make_shared<ExecutableNetworkBase>(mockExeNetwork);
|
||||
InputsDataMap networkInputs;
|
||||
OutputsDataMap networkOutputs;
|
||||
mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
|
||||
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(networkInputs, networkOutputs);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -7,13 +7,16 @@
|
||||
#include <gmock/gmock-generated-actions.h>
|
||||
|
||||
#include <cpp/ie_infer_request.hpp>
|
||||
#include <cpp/ie_executable_network.hpp>
|
||||
#include <ie_plugin_cpp.hpp>
|
||||
#include <cpp_interfaces/exception2status.hpp>
|
||||
#include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinference_plugin.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
|
||||
#include "unit_test_utils/mocks/mock_iinfer_request.hpp"
|
||||
#include "unit_test_utils/mocks/mock_not_empty_icnn_network.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace std;
|
||||
@ -25,7 +28,7 @@ constexpr const char* MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME;
|
||||
|
||||
class InferRequestBaseTests : public ::testing::Test {
|
||||
protected:
|
||||
std::shared_ptr<MockIAsyncInferRequestInternal> mock_impl;
|
||||
std::shared_ptr<MockIInferRequestInternal> mock_impl;
|
||||
shared_ptr<IInferRequest> request;
|
||||
ResponseDesc dsc;
|
||||
|
||||
@ -33,7 +36,7 @@ protected:
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
mock_impl.reset(new MockIAsyncInferRequestInternal());
|
||||
mock_impl.reset(new MockIInferRequestInternal());
|
||||
request = std::make_shared<InferRequestBase>(mock_impl);
|
||||
}
|
||||
};
|
||||
@ -55,42 +58,6 @@ TEST_F(InferRequestBaseTests, canCatchUnknownErrorInStartAsync) {
|
||||
ASSERT_EQ(UNEXPECTED, request->StartAsync(nullptr));
|
||||
}
|
||||
|
||||
// GetUserData
|
||||
TEST_F(InferRequestBaseTests, canForwardGetUserData) {
|
||||
void **data = nullptr;
|
||||
EXPECT_CALL(*mock_impl.get(), GetUserData(data)).Times(1);
|
||||
ASSERT_EQ(OK, request->GetUserData(data, &dsc));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestBaseTests, canReportErrorInGetUserData) {
|
||||
EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
|
||||
ASSERT_NE(request->GetUserData(nullptr, &dsc), OK);
|
||||
ASSERT_STREQ(dsc.msg, "compare");
|
||||
}
|
||||
|
||||
TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetUserData) {
|
||||
EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(5));
|
||||
ASSERT_EQ(UNEXPECTED, request->GetUserData(nullptr, nullptr));
|
||||
}
|
||||
|
||||
// SetUserData
|
||||
TEST_F(InferRequestBaseTests, canForwardSetUserData) {
|
||||
void *data = nullptr;
|
||||
EXPECT_CALL(*mock_impl.get(), SetUserData(data)).Times(1);
|
||||
ASSERT_EQ(OK, request->SetUserData(data, &dsc));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestBaseTests, canReportErrorInSetUserData) {
|
||||
EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
|
||||
ASSERT_NE(request->SetUserData(nullptr, &dsc), OK);
|
||||
ASSERT_STREQ(dsc.msg, "compare");
|
||||
}
|
||||
|
||||
TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetUserData) {
|
||||
EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(5));
|
||||
ASSERT_EQ(UNEXPECTED, request->SetUserData(nullptr, nullptr));
|
||||
}
|
||||
|
||||
// Wait
|
||||
TEST_F(InferRequestBaseTests, canForwardWait) {
|
||||
int64_t ms = 0;
|
||||
@ -192,23 +159,27 @@ TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetBlob) {
|
||||
// SetCompletionCallback
|
||||
TEST_F(InferRequestBaseTests, canForwardSetCompletionCallback) {
|
||||
InferenceEngine::IInferRequest::CompletionCallback callback = nullptr;
|
||||
EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(callback)).Times(1);
|
||||
EXPECT_CALL(*mock_impl.get(), SetCallback(_)).Times(1);
|
||||
ASSERT_NO_THROW(request->SetCompletionCallback(callback));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestBaseTests, canReportErrorInSetCompletionCallback) {
|
||||
EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(_)).WillOnce(Throw(std::runtime_error("compare")));
|
||||
ASSERT_NO_THROW(request->SetCompletionCallback(nullptr));
|
||||
EXPECT_CALL(*mock_impl.get(), SetCallback(_)).WillOnce(Throw(std::runtime_error("compare")));
|
||||
ASSERT_NE(request->SetCompletionCallback(nullptr), OK);
|
||||
}
|
||||
|
||||
|
||||
class InferRequestTests : public ::testing::Test {
|
||||
protected:
|
||||
std::shared_ptr<MockIInferRequest> mock_request;
|
||||
InferRequest::Ptr requestWrapper;
|
||||
std::shared_ptr<MockIInferRequestInternal> mock_request;
|
||||
InferRequest request;
|
||||
ResponseDesc dsc;
|
||||
|
||||
shared_ptr<MockAsyncInferRequestInternal> mockInferRequestInternal;
|
||||
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
|
||||
InferenceEngine::ExecutableNetwork exeNetwork;
|
||||
MockIInferencePlugin* mockIPlugin;
|
||||
InferencePlugin plugin;
|
||||
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
|
||||
MockNotEmptyICNNNetwork mockNotEmptyNet;
|
||||
std::string _incorrectName;
|
||||
std::string _inputName;
|
||||
@ -216,15 +187,21 @@ protected:
|
||||
std::string _inputDataNotAllocatedError;
|
||||
std::string _inputDataIsEmptyError;
|
||||
|
||||
virtual void TearDown() {
|
||||
void TearDown() override {
|
||||
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
|
||||
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mock_request.get()));
|
||||
EXPECT_TRUE(Mock::VerifyAndClearExpectations(requestWrapper.get()));
|
||||
request = {};
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
mock_request = make_shared<MockIInferRequest>();
|
||||
requestWrapper = make_shared<InferRequest>(mock_request);
|
||||
void SetUp() override {
|
||||
mock_request = make_shared<MockIInferRequestInternal>();
|
||||
mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
|
||||
ON_CALL(*mockIExeNet, CreateInferRequest()).WillByDefault(Return(mock_request));
|
||||
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
|
||||
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
|
||||
plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
|
||||
exeNetwork = plugin.LoadNetwork({}, {});
|
||||
request = exeNetwork.CreateInferRequest();
|
||||
_incorrectName = "incorrect_name";
|
||||
_inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
|
||||
_failedToFindInOutError =
|
||||
@ -235,15 +212,20 @@ protected:
|
||||
+ _inputName + "\'";
|
||||
}
|
||||
|
||||
InferRequest::Ptr getInferRequestWithMockImplInside() {
|
||||
InferRequest getInferRequestWithMockImplInside() {
|
||||
IInferRequest::Ptr inferRequest;
|
||||
InputsDataMap inputsInfo;
|
||||
mockNotEmptyNet.getInputsInfo(inputsInfo);
|
||||
OutputsDataMap outputsInfo;
|
||||
mockNotEmptyNet.getOutputsInfo(outputsInfo);
|
||||
mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
|
||||
inferRequest = std::make_shared<InferRequestBase>(mockInferRequestInternal);
|
||||
return make_shared<InferRequest>(inferRequest);
|
||||
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(inputsInfo, outputsInfo);
|
||||
auto mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
|
||||
ON_CALL(*mockIExeNet, CreateInferRequest()).WillByDefault(Return(mockInferRequestInternal));
|
||||
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
|
||||
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
|
||||
auto plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
|
||||
auto exeNetwork = plugin.LoadNetwork({}, {});
|
||||
return exeNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
std::string getExceptionMessage(std::function<void()> function) {
|
||||
@ -274,60 +256,52 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
// constructor tests
|
||||
TEST_F(InferRequestTests, constructorsTests) {
|
||||
// construction from the non-null should not throw
|
||||
ASSERT_NO_THROW(InferRequest req(mock_request));
|
||||
IInferRequest::Ptr tmp;
|
||||
// InferRequest's "actual" is nullptr, let's check it throws on construction
|
||||
ASSERT_THROW(InferRequest req(tmp), Exception);
|
||||
}
|
||||
|
||||
// StartAsync
|
||||
TEST_F(InferRequestTests, canForwardStartAsync) {
|
||||
EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(OK));
|
||||
ASSERT_NO_THROW(requestWrapper->StartAsync());
|
||||
EXPECT_CALL(*mock_request.get(), StartAsync());
|
||||
ASSERT_NO_THROW(request.StartAsync());
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, throwsIfStartAsyncReturnNotOK) {
|
||||
EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(GENERAL_ERROR));
|
||||
ASSERT_THROW(requestWrapper->StartAsync(), Exception);
|
||||
EXPECT_CALL(*mock_request.get(), StartAsync()).WillOnce(Throw(GeneralError{""}));
|
||||
ASSERT_THROW(request.StartAsync(), Exception);
|
||||
}
|
||||
|
||||
// Wait
|
||||
TEST_F(InferRequestTests, canForwardWait) {
|
||||
int64_t ms = 0;
|
||||
EXPECT_CALL(*mock_request.get(), Wait(ms, _)).WillOnce(Return(OK));
|
||||
ASSERT_TRUE(OK == requestWrapper->Wait(ms));
|
||||
EXPECT_CALL(*mock_request.get(), Wait(_)).WillOnce(Return(OK));
|
||||
ASSERT_TRUE(OK == request.Wait(ms));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, canForwardStatusFromWait) {
|
||||
EXPECT_CALL(*mock_request.get(), Wait(_, _)).WillOnce(Return(RESULT_NOT_READY));
|
||||
ASSERT_EQ(requestWrapper->Wait(0), RESULT_NOT_READY);
|
||||
EXPECT_CALL(*mock_request.get(), Wait(_)).WillOnce(Return(RESULT_NOT_READY));
|
||||
ASSERT_EQ(request.Wait(0), RESULT_NOT_READY);
|
||||
}
|
||||
|
||||
// Infer
|
||||
TEST_F(InferRequestTests, canForwardInfer) {
|
||||
EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(OK));
|
||||
ASSERT_NO_THROW(requestWrapper->Infer());
|
||||
EXPECT_CALL(*mock_request.get(), Infer());
|
||||
ASSERT_NO_THROW(request.Infer());
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, throwsIfInferReturnNotOK) {
|
||||
EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(GENERAL_ERROR));
|
||||
ASSERT_THROW(requestWrapper->Infer(), Exception);
|
||||
EXPECT_CALL(*mock_request.get(), Infer()).WillOnce(Throw(GeneralError{""}));
|
||||
ASSERT_THROW(request.Infer(), Exception);
|
||||
}
|
||||
|
||||
// GetPerformanceCounts
|
||||
TEST_F(InferRequestTests, canForwardGetPerformanceCounts) {
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
|
||||
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(OK));
|
||||
ASSERT_NO_THROW(info = requestWrapper->GetPerformanceCounts());
|
||||
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts()).WillOnce(Return(info));
|
||||
ASSERT_NO_THROW(info = request.GetPerformanceCounts());
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, throwsIfGetPerformanceCountsReturnNotOK) {
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
|
||||
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(GENERAL_ERROR));
|
||||
ASSERT_THROW(info = requestWrapper->GetPerformanceCounts(), Exception);
|
||||
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts()).WillOnce(Throw(GeneralError{""}));
|
||||
ASSERT_THROW(info = request.GetPerformanceCounts(), Exception);
|
||||
}
|
||||
|
||||
MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
|
||||
@ -343,15 +317,15 @@ TEST_F(InferRequestTests, getInputCallsSetBlob) {
|
||||
BlobMap blobMap{{blobName1, inblob},
|
||||
{blobName2, inblob}};
|
||||
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
|
||||
ASSERT_NO_THROW(requestWrapper->SetInput(blobMap));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(blobName1, inblob));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(blobName2, inblob));
|
||||
ASSERT_NO_THROW(request.SetInput(blobMap));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, throwsIfSetInputReturnNotOK) {
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
|
||||
BlobMap blobMap{{{}, {}}};
|
||||
ASSERT_THROW(requestWrapper->SetInput(blobMap), Exception);
|
||||
ASSERT_THROW(request.SetInput(blobMap), Exception);
|
||||
}
|
||||
|
||||
// SetOutput
|
||||
@ -362,9 +336,9 @@ TEST_F(InferRequestTests, getOutputCallsSetBlob) {
|
||||
BlobMap blobMap{{blobName1, inblob},
|
||||
{blobName2, inblob}};
|
||||
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
|
||||
ASSERT_NO_THROW(requestWrapper->SetOutput(blobMap));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(blobName1, inblob));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(blobName2, inblob));
|
||||
ASSERT_NO_THROW(request.SetOutput(blobMap));
|
||||
}
|
||||
|
||||
// GetBlob
|
||||
@ -373,16 +347,16 @@ TEST_F(InferRequestTests, canForwardGetBlob) {
|
||||
blob->allocate();
|
||||
std::string name = "blob1";
|
||||
|
||||
EXPECT_CALL(*mock_request.get(), GetBlob(StrEq(name.c_str()), _, _)).WillOnce(DoAll(SetArgReferee<1>(blob), Return(OK)));
|
||||
ASSERT_NO_THROW(requestWrapper->GetBlob(name));
|
||||
EXPECT_CALL(*mock_request.get(), GetBlob(_)).WillOnce(Return(blob));
|
||||
ASSERT_NO_THROW(request.GetBlob(name));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, throwsIfGetBlobReturnNotOK) {
|
||||
Blob::Ptr blob;
|
||||
std::string name = "blob1";
|
||||
|
||||
EXPECT_CALL(*mock_request.get(), GetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
|
||||
ASSERT_THROW(blob = requestWrapper->GetBlob(name), Exception);
|
||||
EXPECT_CALL(*mock_request.get(), GetBlob(_)).WillOnce(Throw(GeneralError{""}));
|
||||
ASSERT_THROW(blob = request.GetBlob(name), Exception);
|
||||
}
|
||||
|
||||
// SetBlob
|
||||
@ -390,79 +364,49 @@ TEST_F(InferRequestTests, canForwardSetBlob) {
|
||||
Blob::Ptr blob;
|
||||
std::string name = "blob1";
|
||||
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(name.c_str()), blob, _)).WillOnce(Return(OK));
|
||||
ASSERT_NO_THROW(requestWrapper->SetBlob(name, blob));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(name, blob));
|
||||
ASSERT_NO_THROW(request.SetBlob(name, blob));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, throwsIfSetBlobReturnNotOK) {
|
||||
Blob::Ptr blob;
|
||||
std::string name = "blob1";
|
||||
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
|
||||
ASSERT_THROW(requestWrapper->SetBlob(name, blob), Exception);
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
|
||||
ASSERT_THROW(request.SetBlob(name, blob), Exception);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, throwsIfSetOutputReturnNotOK) {
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
|
||||
BlobMap blobMap{{{}, {}}};
|
||||
ASSERT_THROW(requestWrapper->SetOutput(blobMap), Exception);
|
||||
}
|
||||
|
||||
// SetCompletionCallback API
|
||||
void callme(InferenceEngine::IInferRequest::Ptr p, InferenceEngine::StatusCode) {
|
||||
void *data = nullptr;
|
||||
p->GetUserData(&data, nullptr);
|
||||
ASSERT_NE(nullptr, data);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, canForwardCompletionCallback) {
|
||||
void *data = nullptr;
|
||||
EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
|
||||
DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
|
||||
EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
|
||||
DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
|
||||
*pData = data;
|
||||
}), Return(OK)));
|
||||
EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
|
||||
ASSERT_NO_THROW(requestWrapper->SetCompletionCallback(&callme));
|
||||
ASSERT_THROW(request.SetOutput(blobMap), Exception);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, canForwardAnyCallback) {
|
||||
void *data = nullptr;
|
||||
EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
|
||||
DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
|
||||
EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
|
||||
DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
|
||||
*pData = data;
|
||||
}), Return(OK)));
|
||||
EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
|
||||
|
||||
ASSERT_NO_THROW(requestWrapper->SetCompletionCallback([&]() {
|
||||
// data used to store callback pointer
|
||||
ASSERT_NE(data, nullptr);
|
||||
}));
|
||||
EXPECT_CALL(*mock_request.get(), SetCallback(_));
|
||||
ASSERT_NO_THROW(request.SetCompletionCallback([] {}));
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, failToSetInputWithInCorrectName) {
|
||||
auto InferRequest = getInferRequestWithMockImplInside();
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(NotFound{""}));
|
||||
auto blobMap = getBlobMapWithIncorrectName();
|
||||
ASSERT_THROW(InferRequest->SetInput(blobMap), NotFound);
|
||||
ASSERT_THROW(request.SetInput(blobMap), NotFound);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, failToSetOutputWithInCorrectName) {
|
||||
auto InferRequest = getInferRequestWithMockImplInside();
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(NotFound{""}));
|
||||
auto blobMap = getBlobMapWithIncorrectName();
|
||||
ASSERT_THROW(InferRequest->SetOutput(blobMap), NotFound);
|
||||
ASSERT_THROW(request.SetOutput(blobMap), NotFound);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, failToSetInputWithNotAllocatedInput) {
|
||||
auto InferRequest = getInferRequestWithMockImplInside();
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(NotAllocated{""}));
|
||||
auto blobMap = getBlobMapWithNotAllocatedInput();
|
||||
ASSERT_THROW(InferRequest->SetInput(blobMap), NotAllocated);
|
||||
ASSERT_THROW(request.SetInput(blobMap), NotAllocated);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestTests, failToSetInputWithEmptyDimensions) {
|
||||
auto InferRequest = getInferRequestWithMockImplInside();
|
||||
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
|
||||
auto blobMap = getBlobMapWithEmptyDimensions();
|
||||
ASSERT_THROW(InferRequest->SetInput(blobMap), GeneralError);
|
||||
ASSERT_THROW(request.SetInput(blobMap), GeneralError);
|
||||
}
|
||||
|
@ -13,7 +13,7 @@
|
||||
#include <threading/ie_cpu_streams_executor.hpp>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/mock_task_executor.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
|
||||
|
||||
using namespace ::testing;
|
||||
@ -52,7 +52,7 @@ protected:
|
||||
shared_ptr<AsyncInferRequestThreadSafeDefault> testRequest;
|
||||
ResponseDesc dsc;
|
||||
|
||||
shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
|
||||
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
|
||||
MockTaskExecutor::Ptr mockTaskExecutor;
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ protected:
|
||||
InputsDataMap inputsInfo;
|
||||
OutputsDataMap outputsInfo;
|
||||
mockTaskExecutor = make_shared<MockTaskExecutor>();
|
||||
mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
|
||||
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(inputsInfo, outputsInfo);
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, mockTaskExecutor, mockTaskExecutor);
|
||||
}
|
||||
};
|
||||
@ -90,26 +90,6 @@ TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfStartAsyncFails)
|
||||
taskExecutor->executeAll();
|
||||
}
|
||||
|
||||
// GetUserData
|
||||
TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetUserData) {
|
||||
auto taskExecutor = std::make_shared<DeferedExecutor>();
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(1).WillOnce(Return());
|
||||
ASSERT_NO_THROW(testRequest->StartAsync());
|
||||
ASSERT_THROW(testRequest->GetUserData(nullptr), RequestBusy);
|
||||
taskExecutor->executeAll();
|
||||
}
|
||||
|
||||
// SetUserData
|
||||
TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetUserData) {
|
||||
auto taskExecutor = std::make_shared<DeferedExecutor>();
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(1).WillOnce(Return());
|
||||
ASSERT_NO_THROW(testRequest->StartAsync());
|
||||
ASSERT_THROW(testRequest->SetUserData(nullptr), RequestBusy);
|
||||
taskExecutor->executeAll();
|
||||
}
|
||||
|
||||
// Wait
|
||||
TEST_F(InferRequestThreadSafeDefaultTests, returnInferNotStartedOnWait) {
|
||||
int64_t ms = 0;
|
||||
@ -175,58 +155,42 @@ TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetCompletionCallb
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(1).WillOnce(Return());
|
||||
ASSERT_NO_THROW(testRequest->StartAsync());
|
||||
ASSERT_THROW(testRequest->SetCompletionCallback({}), RequestBusy);
|
||||
ASSERT_THROW(testRequest->SetCallback({}), RequestBusy);
|
||||
taskExecutor->executeAll();
|
||||
}
|
||||
|
||||
TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
|
||||
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
||||
auto taskExecutor = std::make_shared<DeferedExecutor>();
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
asyncRequest.reset(new InferRequestBase(testRequest));
|
||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
|
||||
ASSERT_EQ((int) StatusCode::OK, status);
|
||||
std::exception_ptr exceptionPtr;
|
||||
testRequest->SetCallback([&](std::exception_ptr exceptionPtr_) {
|
||||
exceptionPtr = exceptionPtr_;
|
||||
});
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).Times(1);
|
||||
|
||||
testRequest->StartAsync();
|
||||
taskExecutor->executeAll();
|
||||
testRequest->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||
ASSERT_EQ(nullptr, exceptionPtr);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed) {
|
||||
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
||||
auto taskExecutor = std::make_shared<DeferedExecutor>();
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
asyncRequest.reset(new InferRequestBase(testRequest));
|
||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
bool wasCalled = false;
|
||||
InferRequest cppRequest(asyncRequest);
|
||||
std::function<void(InferRequest, StatusCode)> callback =
|
||||
[&](InferRequest request, StatusCode status) {
|
||||
wasCalled = true;
|
||||
ASSERT_EQ(StatusCode::GENERAL_ERROR, status);
|
||||
};
|
||||
cppRequest.SetCompletionCallback(callback);
|
||||
std::exception_ptr exceptionPtr;
|
||||
testRequest->SetCallback([&](std::exception_ptr exceptionPtr_) {
|
||||
exceptionPtr = exceptionPtr_;
|
||||
});
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
|
||||
|
||||
testRequest->StartAsync();
|
||||
taskExecutor->executeAll();
|
||||
EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
|
||||
ASSERT_TRUE(wasCalled);
|
||||
ASSERT_NE(nullptr, exceptionPtr);
|
||||
}
|
||||
|
||||
TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailedAndNoCallback) {
|
||||
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
asyncRequest.reset(new InferRequestBase(testRequest));
|
||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
|
||||
|
||||
testRequest->StartAsync();
|
||||
EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
|
||||
}
|
||||
|
@ -11,7 +11,6 @@
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_ivariable_state_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinference_plugin.hpp"
|
||||
#include "ie_plugin_cpp.hpp"
|
||||
|
||||
@ -20,43 +19,32 @@ using namespace std;
|
||||
using namespace InferenceEngine;
|
||||
using namespace InferenceEngine::details;
|
||||
|
||||
template <class T>
|
||||
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
|
||||
typename InferRequestBase::Ptr req(new InferRequestBase(impl));
|
||||
return InferenceEngine::InferRequest(req);
|
||||
}
|
||||
|
||||
|
||||
class VariableStateTests : public ::testing::Test {
|
||||
protected:
|
||||
shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
|
||||
shared_ptr<MockIAsyncInferRequestInternal> mockInferRequestInternal;
|
||||
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
|
||||
shared_ptr<MockIVariableStateInternal> mockVariableStateInternal;
|
||||
|
||||
struct TestPluginInternal : public MockIInferencePlugin {
|
||||
TestPluginInternal(const std::shared_ptr<MockIExecutableNetworkInternal>& mockIExeNet_) : mockIExeNet{mockIExeNet_} {}
|
||||
std::shared_ptr<IExecutableNetworkInternal> LoadNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) override {
|
||||
return mockIExeNet;
|
||||
}
|
||||
QueryNetworkResult QueryNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) const override {return {};}
|
||||
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
|
||||
};
|
||||
struct TestPlugin : public InferenceEngine::InferencePlugin {
|
||||
TestPlugin(std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet) :
|
||||
InferenceEngine::InferencePlugin(InferenceEngine::details::SOPointer<TestPluginInternal>{
|
||||
new TestPluginInternal{mockIExeNet}}) {}
|
||||
};
|
||||
MockIInferencePlugin* mockIPlugin;
|
||||
InferencePlugin plugin;
|
||||
ExecutableNetwork net;
|
||||
InferRequest req;
|
||||
|
||||
virtual void SetUp() {
|
||||
mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
|
||||
mockInferRequestInternal = make_shared<MockIAsyncInferRequestInternal>();
|
||||
mockInferRequestInternal = make_shared<MockIInferRequestInternal>();
|
||||
mockVariableStateInternal = make_shared<MockIVariableStateInternal>();
|
||||
ON_CALL(*mockExeNetworkInternal, CreateInferRequest()).WillByDefault(Return(mockInferRequestInternal));
|
||||
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
|
||||
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockExeNetworkInternal));
|
||||
plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
|
||||
net = plugin.LoadNetwork({}, {});
|
||||
req = net.CreateInferRequest();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn(1);
|
||||
toReturn[0] = mockVariableStateInternal;
|
||||
|
||||
@ -69,7 +57,6 @@ TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToA
|
||||
|
||||
TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
|
||||
@ -81,7 +68,6 @@ TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppTo
|
||||
|
||||
TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
@ -95,7 +81,6 @@ TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAP
|
||||
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesReset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -109,7 +94,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesReset) {
|
||||
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -123,7 +107,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) {
|
||||
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetName) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -137,7 +120,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetName) {
|
||||
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -152,7 +134,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
|
||||
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -168,7 +149,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
|
||||
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -184,7 +164,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
|
||||
|
||||
TEST_F(VariableStateTests, VariableStateCanPropagateSetState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
Blob::Ptr saver;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
@ -204,7 +183,6 @@ TEST_F(VariableStateTests, VariableStateCanPropagateSetState) {
|
||||
|
||||
TEST_F(VariableStateTests, VariableStateCanPropagateGetLastState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
@ -269,18 +247,16 @@ TEST_F(VariableStateTests, VariableStateInternalCanSaveStateByReference) {
|
||||
|
||||
// Tests for InferRequest::QueryState
|
||||
TEST_F(VariableStateTests, InferRequestCanConvertOneVariableStateFromCppToAPI) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn(1);
|
||||
toReturn[0] = mockVariableStateInternal;
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = req.QueryState();
|
||||
ASSERT_EQ(state.size(), 1);
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillOnce(Return(toReturn));
|
||||
@ -290,23 +266,21 @@ TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI)
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InferRequestCanConvert2VariableStatesFromCPPtoAPI) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = req.QueryState();
|
||||
ASSERT_EQ(state.size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
|
||||
|
||||
auto state = req.QueryState();
|
||||
@ -314,11 +288,10 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) {
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
|
||||
|
||||
auto state = req.QueryState();
|
||||
@ -326,11 +299,10 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) {
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetName) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
auto state = req.QueryState();
|
||||
@ -339,7 +311,6 @@ auto req = make_infer_request(mockInferRequestInternal);
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -356,7 +327,6 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) {
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -374,7 +344,6 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) {
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
@ -391,7 +360,6 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) {
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
Blob::Ptr saver;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
@ -409,7 +377,6 @@ TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) {
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateGetLastState) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
|
@ -27,7 +27,7 @@ protected:
|
||||
shared_ptr<MockInferencePluginInternal> mock_plugin_impl;
|
||||
shared_ptr<MockExecutableNetworkInternal> mockExeNetworkInternal;
|
||||
shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetworkTS;
|
||||
shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
|
||||
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
|
||||
std::shared_ptr<MockNotEmptyICNNNetwork> mockNotEmptyNet = std::make_shared<MockNotEmptyICNNNetwork>();
|
||||
std::string pluginId;
|
||||
|
||||
@ -50,13 +50,13 @@ protected:
|
||||
mockExeNetworkInternal->SetPointerToPlugin(mock_plugin_impl);
|
||||
}
|
||||
|
||||
void getInferRequestWithMockImplInside(IInferRequest::Ptr &request) {
|
||||
void getInferRequestWithMockImplInside(IInferRequestInternal::Ptr &request) {
|
||||
IExecutableNetworkInternal::Ptr exeNetwork;
|
||||
InputsDataMap inputsInfo;
|
||||
mockNotEmptyNet->getInputsInfo(inputsInfo);
|
||||
OutputsDataMap outputsInfo;
|
||||
mockNotEmptyNet->getOutputsInfo(outputsInfo);
|
||||
mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
|
||||
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(inputsInfo, outputsInfo);
|
||||
mockExeNetworkTS = make_shared<MockExecutableNetworkThreadSafe>();
|
||||
EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(_, _)).WillOnce(Return(mockExeNetworkTS));
|
||||
EXPECT_CALL(*mockExeNetworkTS.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
|
||||
@ -74,14 +74,15 @@ TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithInCorrectName) {
|
||||
inBlob->allocate();
|
||||
string inputName = "not_input";
|
||||
std::string refError = "[ NOT_FOUND ] Failed to find input or output with name: \'" + inputName + "\'";
|
||||
IInferRequest::Ptr inferRequest;
|
||||
IInferRequestInternal::Ptr inferRequest;
|
||||
getInferRequestWithMockImplInside(inferRequest);
|
||||
|
||||
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
|
||||
ASSERT_EQ(StatusCode::NOT_FOUND, sts);
|
||||
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << dsc.msg;
|
||||
try {
|
||||
inferRequest->SetBlob(inputName, inBlob);
|
||||
} catch(InferenceEngine::NotFound& ex) {
|
||||
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithEmptyName) {
|
||||
@ -89,56 +90,60 @@ TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithEmptyName) {
|
||||
inBlob->allocate();
|
||||
string inputName = "not_input";
|
||||
std::string refError = "[ NOT_FOUND ] Failed to set blob with empty name";
|
||||
IInferRequest::Ptr inferRequest;
|
||||
IInferRequestInternal::Ptr inferRequest;
|
||||
getInferRequestWithMockImplInside(inferRequest);
|
||||
|
||||
ASSERT_NO_THROW(sts = inferRequest->SetBlob("", inBlob, &dsc));
|
||||
ASSERT_EQ(StatusCode::NOT_FOUND, sts);
|
||||
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << dsc.msg;
|
||||
try {
|
||||
inferRequest->SetBlob(inputName, inBlob);
|
||||
} catch(InferenceEngine::NotFound& ex) {
|
||||
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(InferenceEnginePluginInternalTest, failToSetNullPtr) {
|
||||
string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
|
||||
std::string refError = "[ NOT_ALLOCATED ] Failed to set empty blob with name: \'" + inputName + "\'";
|
||||
IInferRequest::Ptr inferRequest;
|
||||
IInferRequestInternal::Ptr inferRequest;
|
||||
getInferRequestWithMockImplInside(inferRequest);
|
||||
Blob::Ptr inBlob = nullptr;
|
||||
|
||||
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
|
||||
ASSERT_EQ(StatusCode::NOT_ALLOCATED, sts);
|
||||
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << dsc.msg;
|
||||
try {
|
||||
inferRequest->SetBlob(inputName, inBlob);
|
||||
} catch(InferenceEngine::NotAllocated& ex) {
|
||||
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(InferenceEnginePluginInternalTest, failToSetEmptyBlob) {
|
||||
Blob::Ptr inBlob;
|
||||
string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
|
||||
std::string refError = "[ NOT_ALLOCATED ] Failed to set empty blob with name: \'" + inputName + "\'";
|
||||
IInferRequest::Ptr inferRequest;
|
||||
IInferRequestInternal::Ptr inferRequest;
|
||||
getInferRequestWithMockImplInside(inferRequest);
|
||||
|
||||
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
|
||||
ASSERT_EQ(StatusCode::NOT_ALLOCATED, sts);
|
||||
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << dsc.msg;
|
||||
try {
|
||||
inferRequest->SetBlob(inputName, inBlob);
|
||||
} catch(InferenceEngine::NotAllocated& ex) {
|
||||
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(InferenceEnginePluginInternalTest, failToSetNotAllocatedBlob) {
|
||||
string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
|
||||
std::string refError = "[ NOT_ALLOCATED ] Input data was not allocated. Input name: \'" + inputName + "\'";
|
||||
IInferRequest::Ptr inferRequest;
|
||||
IInferRequestInternal::Ptr inferRequest;
|
||||
getInferRequestWithMockImplInside(inferRequest);
|
||||
Blob::Ptr blob = make_shared_blob<float>({ Precision::FP32, {}, NCHW });
|
||||
|
||||
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), blob, &dsc));
|
||||
ASSERT_EQ(StatusCode::NOT_ALLOCATED, sts);
|
||||
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << dsc.msg;
|
||||
try {
|
||||
inferRequest->SetBlob(inputName, blob);
|
||||
} catch(InferenceEngine::NotAllocated& ex) {
|
||||
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
|
||||
<< "\tExpected: " << refError
|
||||
<< "\n\tActual: " << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(InferenceEnginePluginInternalTest, executableNetworkInternalExportsMagicAndName) {
|
||||
|
@ -58,15 +58,6 @@ TEST(ExceptionTests, ExceptionCanBeCaughtAsStandard) {
|
||||
ASSERT_THROW(IE_THROW(), std::exception);
|
||||
}
|
||||
|
||||
TEST(ExceptionTests, CanThrowStatusCode) {
|
||||
try {
|
||||
IE_THROW(InferNotStarted);
|
||||
}
|
||||
catch (const InferenceEngine::InferNotStarted& iex) {
|
||||
ASSERT_EQ(InferenceEngine::ExceptionToStatus(iex), InferenceEngine::StatusCode::INFER_NOT_STARTED);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef NDEBUG // disabled for debug as macros calls assert()
|
||||
TEST(ExceptionTests, ExceptionWithAssertThrowsNothingIfTrue) {
|
||||
ASSERT_NO_THROW(IE_ASSERT(true) << "shouldn't assert if true");
|
||||
|
@ -39,31 +39,22 @@ class ExecutableNetworkTests : public ::testing::Test {
|
||||
protected:
|
||||
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
|
||||
InferenceEngine::ExecutableNetwork exeNetwork;
|
||||
MockIInferencePlugin* mockIPlugin;
|
||||
InferencePlugin plugin;
|
||||
|
||||
struct TestPluginInternal : public MockIInferencePlugin {
|
||||
TestPluginInternal(const std::shared_ptr<MockIExecutableNetworkInternal>& mockIExeNet_) : mockIExeNet{mockIExeNet_} {}
|
||||
std::shared_ptr<IExecutableNetworkInternal> LoadNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) override {
|
||||
return mockIExeNet;
|
||||
}
|
||||
QueryNetworkResult QueryNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) const override {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
|
||||
};
|
||||
struct TestPlugin : public InferenceEngine::InferencePlugin {
|
||||
TestPlugin(std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet) :
|
||||
InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<TestPluginInternal>{
|
||||
new TestPluginInternal{mockIExeNet}}} {}
|
||||
};
|
||||
|
||||
virtual void TearDown() {
|
||||
mockIExeNet.reset();
|
||||
exeNetwork = {};
|
||||
plugin = {};
|
||||
}
|
||||
|
||||
virtual void SetUp() {
|
||||
mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
|
||||
exeNetwork = TestPlugin{mockIExeNet}.LoadNetwork({}, {});
|
||||
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
|
||||
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
|
||||
plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
|
||||
exeNetwork = plugin.LoadNetwork({}, {});
|
||||
}
|
||||
};
|
||||
|
||||
@ -126,7 +117,7 @@ IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
|
||||
protected:
|
||||
std::shared_ptr<MockIInferRequest> mockIInferReq_p;
|
||||
std::shared_ptr<MockIInferRequestInternal> mockIInferReq_p;
|
||||
|
||||
virtual void TearDown() {
|
||||
ExecutableNetworkTests::TearDown();
|
||||
@ -135,7 +126,7 @@ protected:
|
||||
|
||||
virtual void SetUp() {
|
||||
ExecutableNetworkTests::SetUp();
|
||||
mockIInferReq_p = std::make_shared<MockIInferRequest>();
|
||||
mockIInferReq_p = std::make_shared<MockIInferRequestInternal>();
|
||||
}
|
||||
};
|
||||
|
||||
@ -143,7 +134,6 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CanCreateInferRequest) {
|
||||
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
|
||||
InferRequest actualInferReq;
|
||||
ASSERT_NO_THROW(actualInferReq = exeNetwork.CreateInferRequest());
|
||||
ASSERT_EQ(mockIInferReq_p, static_cast<IInferRequest::Ptr &>(actualInferReq));
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfReturnNotOK) {
|
||||
@ -153,16 +143,14 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfReturnNotO
|
||||
|
||||
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfSetRequestToNullptr) {
|
||||
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest())
|
||||
.WillOnce(Return(std::shared_ptr<MockIInferRequest>{}));
|
||||
.WillOnce(Return(std::shared_ptr<MockIInferRequestInternal>{}));
|
||||
ASSERT_THROW(exeNetwork.CreateInferRequest(), InferenceEngine::Exception);
|
||||
}
|
||||
|
||||
// CreateInferRequestPtr
|
||||
TEST_F(ExecutableNetworkWithIInferReqTests, CanCreateInferRequestPtr) {
|
||||
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
|
||||
InferRequest::Ptr actualInferReq;
|
||||
ASSERT_NO_THROW(actualInferReq = exeNetwork.CreateInferRequestPtr());
|
||||
ASSERT_EQ(mockIInferReq_p, static_cast<IInferRequest::Ptr &>(*actualInferReq.get()));
|
||||
ASSERT_NO_THROW(exeNetwork.CreateInferRequest());
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestPtrThrowsIfReturnNotOK) {
|
||||
@ -171,7 +159,7 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestPtrThrowsIfReturnN
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestPtrThrowsIfSetRequestToNullptr) {
|
||||
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(std::shared_ptr<MockIInferRequest>{}));
|
||||
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(std::shared_ptr<MockIInferRequestInternal>{}));
|
||||
ASSERT_THROW(exeNetwork.CreateInferRequestPtr(), InferenceEngine::Exception);
|
||||
}
|
||||
|
||||
@ -194,9 +182,10 @@ protected:
|
||||
|
||||
// CreateInferRequest
|
||||
TEST_F(ExecutableNetworkBaseTests, canForwardCreateInferRequest) {
|
||||
auto inferReqInternal = std::make_shared<MockIInferRequestInternal>();
|
||||
EXPECT_CALL(*mock_impl.get(), CreateInferRequest()).Times(1).WillRepeatedly(Return(inferReqInternal));
|
||||
IInferRequest::Ptr req;
|
||||
EXPECT_CALL(*mock_impl.get(), CreateInferRequest()).Times(1).WillRepeatedly(Return(req));
|
||||
ASSERT_EQ(OK, exeNetwork->CreateInferRequest(req, &dsc));
|
||||
ASSERT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkBaseTests, canReportErrorInCreateInferRequest) {
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "matchers/fill_with_data.hpp"
|
||||
#include "matchers/weights_matcher.hpp"
|
||||
#include <gmock/gmock-generated-actions.h>
|
||||
#include <debug.h>
|
||||
|
||||
#include <gmock/gmock-more-actions.h>
|
||||
#include "gmock/gmock.h"
|
||||
|
@ -1202,7 +1202,7 @@ TEST_F(MKLDNNGraphStructureTests, TestOutputAfterInplacePlusConcat) {
|
||||
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
|
||||
execNetwork->setNetworkInputs(_networkInputs);
|
||||
execNetwork->setNetworkOutputs(_networkOutputs);
|
||||
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
|
||||
InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32, {1, 3, 2, 2}, InferenceEngine::NCHW);
|
||||
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>(desc);
|
||||
@ -1211,8 +1211,7 @@ TEST_F(MKLDNNGraphStructureTests, TestOutputAfterInplacePlusConcat) {
|
||||
|
||||
InferenceEngine::ResponseDesc resp;
|
||||
|
||||
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data", src, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data", src));
|
||||
|
||||
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
|
||||
|
||||
@ -1222,11 +1221,8 @@ TEST_F(MKLDNNGraphStructureTests, TestOutputAfterInplacePlusConcat) {
|
||||
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
|
||||
output->allocate();
|
||||
|
||||
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
|
||||
sts = inferRequest->Infer(&resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob(item.first, output));
|
||||
ASSERT_NO_THROW(inferRequest->Infer());
|
||||
|
||||
compare(*output, *src);
|
||||
}
|
||||
@ -1717,7 +1713,7 @@ TEST_F(MKLDNNGraphStructureTests, TestResnetPart) {
|
||||
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
|
||||
execNetwork->setNetworkInputs(_networkInputs);
|
||||
execNetwork->setNetworkOutputs(_networkOutputs);
|
||||
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
|
||||
InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32, {1, 3, 224, 224}, InferenceEngine::NCHW);
|
||||
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>(desc);
|
||||
@ -1726,8 +1722,7 @@ TEST_F(MKLDNNGraphStructureTests, TestResnetPart) {
|
||||
|
||||
InferenceEngine::ResponseDesc resp;
|
||||
|
||||
InferenceEngine::StatusCode sts = inferRequest->SetBlob("input", src, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("input", src));
|
||||
|
||||
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
|
||||
|
||||
@ -1737,18 +1732,16 @@ TEST_F(MKLDNNGraphStructureTests, TestResnetPart) {
|
||||
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
|
||||
output->allocate();
|
||||
|
||||
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob(item.first.c_str(), output));
|
||||
|
||||
sts = inferRequest->Infer(&resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->Infer());
|
||||
}
|
||||
|
||||
TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
|
||||
std::string model = R"V0G0N(
|
||||
<net batch="1" name="model" version="2">
|
||||
<layers>
|
||||
<layer id="0" name="data" precision="FP32" type="Input">
|
||||
<layer id="0" name="data1" precision="FP32" type="Input">
|
||||
<output>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
@ -1866,7 +1859,7 @@ TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
|
||||
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
|
||||
execNetwork->setNetworkInputs(_networkInputs);
|
||||
execNetwork->setNetworkOutputs(_networkOutputs);
|
||||
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
|
||||
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 20, 20}, InferenceEngine::NCHW);
|
||||
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
|
||||
@ -1885,10 +1878,9 @@ TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
|
||||
|
||||
InferenceEngine::ResponseDesc resp;
|
||||
|
||||
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data1", src1, &resp);
|
||||
sts = inferRequest->SetBlob("data2", src2, &resp);
|
||||
sts = inferRequest->SetBlob("data3", src3, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data1", src1));
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data2", src2));
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data3", src3));
|
||||
|
||||
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
|
||||
|
||||
@ -1898,11 +1890,9 @@ TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
|
||||
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
|
||||
output->allocate();
|
||||
|
||||
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob(item.first, output));
|
||||
|
||||
sts = inferRequest->Infer(&resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->Infer());
|
||||
|
||||
// compare(*output, *src);
|
||||
}
|
||||
@ -2046,7 +2036,7 @@ TEST_F(MKLDNNGraphStructureTests, Test2ConcatFromConcat) {
|
||||
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
|
||||
execNetwork->setNetworkInputs(_networkInputs);
|
||||
execNetwork->setNetworkOutputs(_networkOutputs);
|
||||
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
|
||||
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 2, 2}, InferenceEngine::NCHW);
|
||||
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
|
||||
@ -2070,14 +2060,10 @@ TEST_F(MKLDNNGraphStructureTests, Test2ConcatFromConcat) {
|
||||
|
||||
InferenceEngine::ResponseDesc resp;
|
||||
|
||||
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data1", src1, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
sts = inferRequest->SetBlob("data2", src2, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
sts = inferRequest->SetBlob("data3", src3, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
sts = inferRequest->SetBlob("data4", src4, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data1", src1));
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data2", src2));
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data3", src3));
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data4", src4));
|
||||
|
||||
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
|
||||
|
||||
@ -2127,12 +2113,10 @@ TEST_F(MKLDNNGraphStructureTests, Test2ConcatFromConcat) {
|
||||
}
|
||||
refOutputs.push_back(refOutput);
|
||||
|
||||
sts = inferRequest->SetBlob(it.first.c_str(), output, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob(it.first, output));
|
||||
}
|
||||
|
||||
sts = inferRequest->Infer(&resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->Infer());
|
||||
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
compare(*outputs[i], *refOutputs[i]);
|
||||
@ -2376,7 +2360,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithConstLayer) {
|
||||
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
|
||||
execNetwork->setNetworkInputs(_networkInputs);
|
||||
execNetwork->setNetworkOutputs(_networkOutputs);
|
||||
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
|
||||
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 20, 20}, InferenceEngine::NCHW);
|
||||
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
|
||||
@ -2385,8 +2369,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithConstLayer) {
|
||||
|
||||
InferenceEngine::ResponseDesc resp;
|
||||
|
||||
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data", src1, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data", src1));
|
||||
|
||||
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
|
||||
|
||||
@ -2396,11 +2379,9 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithConstLayer) {
|
||||
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
|
||||
output->allocate();
|
||||
|
||||
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob(item.first.c_str(), output));
|
||||
|
||||
sts = inferRequest->Infer(&resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->Infer());
|
||||
}
|
||||
|
||||
TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
|
||||
@ -2523,7 +2504,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
|
||||
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
|
||||
execNetwork->setNetworkInputs(_networkInputs);
|
||||
execNetwork->setNetworkOutputs(_networkOutputs);
|
||||
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
|
||||
|
||||
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 20, 20}, InferenceEngine::NCHW);
|
||||
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
|
||||
@ -2535,8 +2516,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
|
||||
|
||||
InferenceEngine::ResponseDesc resp;
|
||||
|
||||
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data", src1, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob("data", src1));
|
||||
|
||||
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
|
||||
|
||||
@ -2546,11 +2526,9 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
|
||||
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
|
||||
output->allocate();
|
||||
|
||||
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->SetBlob(item.first.c_str(), output));
|
||||
|
||||
sts = inferRequest->Infer(&resp);
|
||||
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
|
||||
ASSERT_NO_THROW(inferRequest->Infer());
|
||||
|
||||
auto *res_ptr = output->buffer().as<float*>();
|
||||
size_t res_size = output->size();
|
||||
|
Loading…
Reference in New Issue
Block a user