Merged internal Infer Request implementation (#5125)

Anton Pankratov 2021-04-19 15:16:47 +03:00 committed by GitHub
parent ef70e5187c
commit 46987def54
84 changed files with 1312 additions and 1704 deletions

View File

@ -8,7 +8,7 @@
`InferRequest` Class
------------------------
Inference Engine Plugin API provides the helper InferenceEngine::InferRequestInternal class recommended
Inference Engine Plugin API provides the helper InferenceEngine::IInferRequestInternal class recommended
to use as a base class for a synchronous inference request implementation. Based on that, a declaration
of a synchronous request class can look as follows:
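The actual declaration is pulled in by an `@snippet` directive outside this hunk. As a hedged illustration only, a synchronous request class for a hypothetical device could be declared along these lines (class and member names are illustrative, mirroring the `TemplateInferRequest` changes later in this commit):

```cpp
#include <map>
#include <memory>
#include <string>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>

// Hypothetical synchronous request for a "MyDevice" plugin.
class MyDeviceSyncInferRequest : public InferenceEngine::IInferRequestInternal {
public:
    using Ptr = std::shared_ptr<MyDeviceSyncInferRequest>;

    MyDeviceSyncInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
                             const InferenceEngine::OutputsDataMap& networkOutputs)
        : IInferRequestInternal(networkInputs, networkOutputs) {}

    // Called by the base class Infer() after the user-set blobs are validated.
    void InferImpl() override;

    std::map<std::string, InferenceEngine::InferenceEngineProfileInfo>
        GetPerformanceCounts() const override;
};
```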
@ -46,7 +46,7 @@ Decrements a number of created inference requests:
### `InferImpl()`
**Implementation details:** Base InferRequestInternal class implements the public InferenceEngine::InferRequestInternal::Infer method as follows:
**Implementation details:** Base IInferRequestInternal class implements the public InferenceEngine::IInferRequestInternal::Infer method as follows:
- Checks blobs set by users
- Calls the `InferImpl` method defined in a derived class to call actual pipeline stages synchronously
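As a minimal sketch of that behavior (matching the `IInferRequestInternal::Infer` implementation added later in this commit):

```cpp
void IInferRequestInternal::Infer() {
    checkBlobs();   // validate the blobs set by the user
    InferImpl();    // run the derived class pipeline synchronously
}
```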
@ -59,7 +59,7 @@ Below is the code of the `inferPreprocess` method to demonstrate Inference E
@snippet src/template_infer_request.cpp infer_request:infer_preprocess
**Details:**
* `InferImpl` must call the InferenceEngine::InferRequestInternal::execDataPreprocessing function, which executes the common Inference Engine preprocessing step (for example, applies resize or color conversion operations) if it is set by the user. The output dimensions, layout, and precision match the input information set via InferenceEngine::CNNNetwork::getInputsInfo.
* `InferImpl` must call the InferenceEngine::IInferRequestInternal::execDataPreprocessing function, which executes the common Inference Engine preprocessing step (for example, applies resize or color conversion operations) if it is set by the user. The output dimensions, layout, and precision match the input information set via InferenceEngine::CNNNetwork::getInputsInfo.
* If the `inputBlob` passed by the user differs in precision from the precision expected by the plugin, `blobCopy` is performed, which does the actual precision conversion.
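A hedged sketch of such a pre-processing stage (method name is illustrative, added to the hypothetical class above; the Template plugin's `inferPreprocess`, shown later in this commit, follows the same pattern):

```cpp
void MyDeviceSyncInferRequest::inferPreprocess() {
    // Runs the common Inference Engine pre-processing (resize, color conversion, ...)
    // if the user configured it; _deviceInputs is a protected member of the base class.
    IInferRequestInternal::execDataPreprocessing(_deviceInputs);

    // If a user blob's precision differs from the precision the device expects,
    // a blobCopy-style conversion into the device blob would be performed here.
}
```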
#### 2. `startPipeline`

View File

@ -8,7 +8,7 @@
using namespace InferenceEngine;
class AcceleratorSyncRequest : public InferRequestInternal {
class AcceleratorSyncRequest : public IInferRequestInternal {
public:
using Ptr = std::shared_ptr<AcceleratorSyncRequest>;

View File

@ -2,7 +2,7 @@
int main() {
InferenceEngine::Core core;
InferenceEngine::IInferRequest::CompletionCallback callback;
InferenceEngine::IInferRequest::CompletionCallback callback = nullptr;
int numRequests = 42;
int i = 1;
auto network = core.ReadNetwork("sample.xml");

View File

@ -18,7 +18,7 @@ public:
const InferenceEngine::ITaskExecutor::Ptr& waitExecutor,
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
~TemplateAsyncInferRequest() override;
~TemplateAsyncInferRequest();
private:
TemplateInferRequest::Ptr _inferRequest;

View File

@ -131,21 +131,18 @@ void TemplatePlugin::ExecutableNetwork::InitExecutor() {
// ! [executable_network:create_infer_request_impl]
InferenceEngine::InferRequestInternal::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::IInferRequestInternal::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) {
return std::make_shared<TemplateInferRequest>(networkInputs, networkOutputs, std::static_pointer_cast<ExecutableNetwork>(shared_from_this()));
}
// ! [executable_network:create_infer_request_impl]
// ! [executable_network:create_infer_request]
InferenceEngine::IInferRequest::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequest() {
InferenceEngine::IInferRequestInternal::Ptr TemplatePlugin::ExecutableNetwork::CreateInferRequest() {
InferenceEngine::IInferRequest::Ptr asyncRequest;
auto internalRequest = CreateInferRequestImpl(_networkInputs, _networkOutputs);
auto asyncThreadSafeImpl = std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
return std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
_taskExecutor, _plugin->_waitExecutor, _callbackExecutor);
asyncRequest.reset(new InferenceEngine::InferRequestBase(asyncThreadSafeImpl));
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
}
// ! [executable_network:create_infer_request]

View File

@ -36,9 +36,9 @@ public:
// Methods from a base class ExecutableNetworkThreadSafeDefault
void ExportImpl(std::ostream& model) override;
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
InferenceEngine::Parameter GetMetric(const std::string &name) const override;
InferenceEngine::Parameter GetConfig(const std::string &name) const override;

View File

@ -33,7 +33,7 @@ using Time = std::chrono::high_resolution_clock;
TemplateInferRequest::TemplateInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
const InferenceEngine::OutputsDataMap& networkOutputs,
const std::shared_ptr<TemplatePlugin::ExecutableNetwork>& executableNetwork) :
InferRequestInternal(networkInputs, networkOutputs),
IInferRequestInternal(networkInputs, networkOutputs),
_executableNetwork(executableNetwork) {
// TODO: allocate infer request device and host buffers if needed, fill actual list of profiling tasks
@ -178,9 +178,9 @@ static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
void TemplateInferRequest::inferPreprocess() {
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, _profilingTask[Preprocess]);
auto start = Time::now();
// NOTE: After InferRequestInternal::execDataPreprocessing call
// NOTE: After IInferRequestInternal::execDataPreprocessing call
// input can points to other memory region than it was allocated in constructor.
InferRequestInternal::execDataPreprocessing(_deviceInputs);
IInferRequestInternal::execDataPreprocessing(_deviceInputs);
for (auto&& networkInput : _deviceInputs) {
auto index = _executableNetwork->_inputIndex[networkInput.first];
const auto& parameter = _parameters[index];

View File

@ -11,7 +11,6 @@
#include <unordered_map>
#include <ie_common.h>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
#include <threading/ie_itask_executor.hpp>
#include <openvino/itt.hpp>
@ -27,14 +26,14 @@ namespace TemplatePlugin {
class ExecutableNetwork;
// ! [infer_request:header]
class TemplateInferRequest : public InferenceEngine::InferRequestInternal {
class TemplateInferRequest : public InferenceEngine::IInferRequestInternal {
public:
typedef std::shared_ptr<TemplateInferRequest> Ptr;
TemplateInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
const InferenceEngine::OutputsDataMap& networkOutputs,
const std::shared_ptr<ExecutableNetwork>& executableNetwork);
~TemplateInferRequest() override;
~TemplateInferRequest();
void InferImpl() override;
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;

View File

@ -4,6 +4,7 @@
#include <ie_metric_helpers.hpp>
#include <ie_plugin_config.hpp>
#include <ie_algorithm.hpp>
#include <hetero/hetero_plugin_config.hpp>
#include <threading/ie_executor_manager.hpp>

View File

@ -4,7 +4,7 @@
/**
* @brief A header file that provides ExecutableNetwork class
*
*
* @file ie_executable_network.hpp
*/

View File

@ -18,44 +18,13 @@
#include "ie_iinfer_request.hpp"
#include "details/ie_so_loader.h"
#include "ie_blob.h"
#include "ie_iinfer_request.hpp"
namespace InferenceEngine {
namespace details {
class ICompletionCallbackWrapper {
public:
virtual ~ICompletionCallbackWrapper() = default;
virtual void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept = 0;
};
template <class T>
class CompletionCallbackWrapper : public ICompletionCallbackWrapper {
T lambda;
public:
explicit CompletionCallbackWrapper(const T& lambda): lambda(lambda) {}
void call(InferenceEngine::IInferRequest::Ptr /*request*/, InferenceEngine::StatusCode /*code*/) const
noexcept override {
lambda();
}
};
template <>
class CompletionCallbackWrapper<IInferRequest::CompletionCallback> : public ICompletionCallbackWrapper {
IInferRequest::CompletionCallback callBack;
public:
explicit CompletionCallbackWrapper(const IInferRequest::CompletionCallback& callBack): callBack(callBack) {}
void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept override {
callBack(request, code);
}
};
} // namespace details
class SharedObjectLoader;
}
class IInferRequestInternal;
/**
* @copybrief IInferRequest
@ -63,42 +32,41 @@ public:
* Wraps IInferRequest
* It can throw exceptions safely for the application, where it is properly handled.
*/
class InferRequest {
IInferRequest::Ptr actual;
InferenceEngine::details::SharedObjectLoader::Ptr plg;
std::shared_ptr<details::ICompletionCallbackWrapper> callback;
class INFERENCE_ENGINE_API_CLASS(InferRequest) {
std::shared_ptr<IInferRequestInternal> _impl;
std::shared_ptr<details::SharedObjectLoader> _so;
static void callWrapper(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) {
details::ICompletionCallbackWrapper* pWrapper = nullptr;
ResponseDesc dsc;
request->GetUserData(reinterpret_cast<void**>(&pWrapper), &dsc);
pWrapper->call(request, code);
}
explicit InferRequest(const std::shared_ptr<IInferRequestInternal>& impl,
const std::shared_ptr<details::SharedObjectLoader>& so);
friend class ExecutableNetwork;
public:
/**
* @enum WaitMode
* @brief Enumeration to hold wait mode for IInferRequest
*/
enum WaitMode : int64_t {
/** Wait until inference result becomes available */
RESULT_READY = -1,
/** IInferRequest doesn't block or interrupt current thread and immediately returns inference status */
STATUS_ONLY = 0,
};
/**
* @brief A smart pointer to the InferRequest object
*/
using Ptr = std::shared_ptr<InferRequest>;
/**
* @brief Default constructor
*/
InferRequest() = default;
/**
* constructs InferRequest from the initialized shared_pointer
* @param request Initialized shared pointer to IInferRequest interface
* @param splg Plugin to use. This is required to ensure that InferRequest can work properly even if plugin object is destroyed.
*/
explicit InferRequest(IInferRequest::Ptr request,
InferenceEngine::details::SharedObjectLoader::Ptr splg = {}):
actual(request), plg(splg) {
// plg can be null, but not the actual
if (actual == nullptr) IE_THROW() << "InferRequest was not initialized.";
}
/**
* @brief Destructor
*/
~InferRequest() {
actual = nullptr;
}
~InferRequest();
/**
* @brief Sets input/output data to infer
@ -108,27 +76,16 @@ public:
* @param data Reference to input or output blob. The type of a blob must match the network input precision and
* size.
*/
void SetBlob(const std::string& name, const Blob::Ptr& data) {
CALL_STATUS_FNC(SetBlob, name.c_str(), data);
}
void SetBlob(const std::string& name, const Blob::Ptr& data);
/**
* @copybrief IInferRequest::GetBlob
* @brief Gets input/output data for inference
*
* Wraps IInferRequest::GetBlob
* @note Memory allocation does not happen
* @param name A name of Blob to get
* @return A shared pointer to a Blob with a name @p name. If a blob is not found, an exception is thrown.
*/
Blob::Ptr GetBlob(const std::string& name) {
Blob::Ptr data;
CALL_STATUS_FNC(GetBlob, name.c_str(), data);
std::string error = "Internal error: blob with name `" + name + "` is not allocated!";
auto blobPtr = data.get();
const bool remoteBlobPassed = blobPtr->is<RemoteBlob>();
if (blobPtr == nullptr) IE_THROW() << error;
if (!remoteBlobPassed && blobPtr->buffer() == nullptr) IE_THROW() << error;
return data;
}
Blob::Ptr GetBlob(const std::string& name);
/**
* @brief Sets blob with a pre-process information
@ -137,51 +94,37 @@ public:
* @param data A reference to input. The type of Blob must correspond to the network input precision and size.
* @param info Preprocess info for blob.
*/
void SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo& info) {
CALL_STATUS_FNC(SetBlob, name.c_str(), data, info);
}
void SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo& info);
/**
* @brief Gets pre-process for input data
* @param name Name of input blob.
* @return pointer to pre-process info of blob with name
*/
const PreProcessInfo& GetPreProcess(const std::string& name) const {
const PreProcessInfo* info = nullptr;
CALL_STATUS_FNC(GetPreProcess, name.c_str(), &info);
return *info;
}
const PreProcessInfo& GetPreProcess(const std::string& name) const;
/**
* @copybrief IInferRequest::Infer
* @brief Infers specified input(s) in synchronous mode
*
* @note blocks all methods of InferRequest while request is ongoing (running or waiting in queue)
*
* Wraps IInferRequest::Infer
*/
void Infer() {
CALL_STATUS_FNC_NO_ARGS(Infer);
}
void Infer();
/**
* @copybrief IInferRequest::Cancel
*
* Wraps IInferRequest::Cancel
* @brief Cancels the inference request
*/
void Cancel() {
CALL_STATUS_FNC_NO_ARGS(Cancel);
}
void Cancel();
/**
* @copybrief IInferRequest::GetPerformanceCounts
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer
*
* Wraps IInferRequest::GetPerformanceCounts
* @note not all plugins provide meaningful data
* @return Map of layer names to profiling information for that layer
*/
std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const {
std::map<std::string, InferenceEngineProfileInfo> perfMap;
CALL_STATUS_FNC(GetPerformanceCounts, perfMap);
return perfMap;
}
std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const;
/**
* @brief Sets input data to infer
@ -190,11 +133,7 @@ public:
* @param inputs A reference to a map of input blobs accessed by input names.
* The type of Blob must correspond to the network input precision and size.
*/
void SetInput(const BlobMap& inputs) {
for (auto&& input : inputs) {
CALL_STATUS_FNC(SetBlob, input.first.c_str(), input.second);
}
}
void SetInput(const BlobMap& inputs);
/**
* @brief Sets data that will contain result of the inference
@ -203,34 +142,27 @@ public:
* @param results - a reference to a map of result blobs accessed by output names.
* The type of Blob must correspond to the network output precision and size.
*/
void SetOutput(const BlobMap& results) {
for (auto&& result : results) {
CALL_STATUS_FNC(SetBlob, result.first.c_str(), result.second);
}
}
void SetOutput(const BlobMap& results);
/**
* @brief Sets new batch size when dynamic batching is enabled in executable network that created this request.
*
* @param batch new batch size to be used by all the following inference calls for this request.
*/
void SetBatch(const int batch) {
CALL_STATUS_FNC(SetBatch, batch);
}
void SetBatch(const int batch);
/**
* @brief Start inference of specified input(s) in asynchronous mode
*
* @note It returns immediately. Inference starts also immediately.
*/
void StartAsync() {
CALL_STATUS_FNC_NO_ARGS(StartAsync);
}
void StartAsync();
/**
* @copybrief IInferRequest::Wait
* @brief Waits for the result to become available. Blocks until specified millis_timeout has elapsed or the result
* becomes available, whichever comes first.
*
*
* Wraps IInferRequest::Wait
* @param millis_timeout Maximum duration in milliseconds to block for
* @note There are special cases when millis_timeout is equal to some value of the WaitMode enum:
* * STATUS_ONLY - immediately returns inference status (IInferRequest::RequestStatus). It does not block or
@ -238,105 +170,69 @@ public:
* * RESULT_READY - waits until inference result becomes available
* @return A status code of operation
*/
StatusCode Wait(int64_t millis_timeout) {
ResponseDesc resp;
if (actual == nullptr) IE_THROW() << "InferRequest was not initialized.";
auto res = actual->Wait(millis_timeout, &resp);
if (res != OK && res != RESULT_NOT_READY &&
res != INFER_NOT_STARTED && res != INFER_CANCELLED) {
IE_EXCEPTION_SWITCH(res, ExceptionType,
InferenceEngine::details::ThrowNow<ExceptionType>{}
<<= std::stringstream{} << IE_LOCATION << resp.msg)
}
return res;
}
StatusCode Wait(int64_t millis_timeout = RESULT_READY);
private:
void SetCompletionCallbackImpl(std::function<void()>);
void SetCompletionCallbackImpl(std::function<void(InferRequest, StatusCode)>);
void SetCompletionCallbackImpl(IInferRequest::CompletionCallback);
template<typename T>
struct SetCallback {
void operator()(std::function<void()> f) {_this.SetCompletionCallbackImpl(std::move(f));}
InferRequest& _this;
};
public:
/**
* @copybrief IInferRequest::SetCompletionCallback
* @brief Sets a callback function that will be called on success or failure of asynchronous request
*
* Wraps IInferRequest::SetCompletionCallback
*
* @param callbackToSet Lambda callback object which will be called on processing finish.
* @param callbackToSet callback object which will be called when inference finishes.
*/
template <class T>
void SetCompletionCallback(const T& callbackToSet) {
callback.reset(new details::CompletionCallbackWrapper<T>(callbackToSet));
CALL_STATUS_FNC(SetUserData, callback.get());
actual->SetCompletionCallback(callWrapper);
template<typename F>
void SetCompletionCallback(F callbackToSet) {
return SetCallback<F>{*this}(std::move(callbackToSet));
}
/**
* @copybrief IExecutableNetwork::QueryState
* @brief Gets state control interface for given infer request.
*
* Wraps IExecutableNetwork::QueryState
* State control essential for recurrent networks
* @return A vector of Memory State objects
*/
std::vector<VariableState> QueryState() {
IE_SUPPRESS_DEPRECATED_START
if (actual == nullptr) IE_THROW() << "ExecutableNetwork was not initialized.";
IVariableState::Ptr pState = nullptr;
auto res = OK;
std::vector<VariableState> controller;
for (size_t idx = 0; res == OK; ++idx) {
ResponseDesc resp;
res = actual->QueryState(pState, idx, &resp);
if (res != OK && res != OUT_OF_BOUNDS) {
IE_THROW() << resp.msg;
}
if (res != OUT_OF_BOUNDS) {
controller.push_back(VariableState(pState, plg));
}
}
IE_SUPPRESS_DEPRECATED_END
return controller;
}
std::vector<VariableState> QueryState();
/**
* @brief IInferRequest pointer to be used directly in CreateInferRequest functions
* @return A shared pointer to underlying IInferRequest interface
* @return A shared pointer to IInferRequest interface
*/
operator IInferRequest::Ptr&() {
if (actual == nullptr) IE_THROW() << "InferRequest was not initialized.";
return actual;
}
INFERENCE_ENGINE_DEPRECATED("Will be removed")
operator std::shared_ptr<IInferRequest> ();
/**
* @brief Checks if current InferRequest object is not initialized
* @return true if current InferRequest object is not initialized, false - otherwise
*/
bool operator!() const noexcept {
return !actual;
}
bool operator!() const noexcept;
/**
* @brief Checks if current InferRequest object is initialized
* @return true if current InferRequest object is initialized, false - otherwise
*/
explicit operator bool() const noexcept {
return !!actual;
}
/**
* @brief A smart pointer to the InferRequest object
*/
using Ptr = std::shared_ptr<InferRequest>;
explicit operator bool() const noexcept;
};
namespace details {
template <>
class CompletionCallbackWrapper<std::function<void(InferRequest, StatusCode)>> : public ICompletionCallbackWrapper {
std::function<void(InferRequest, StatusCode)> lambda;
public:
explicit CompletionCallbackWrapper(const std::function<void(InferRequest, InferenceEngine::StatusCode)>& lambda)
: lambda(lambda) {}
void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept override {
lambda(InferRequest(request), code);
template<>
struct InferRequest::SetCallback<std::function<void(InferRequest, StatusCode)>> {
void operator()(std::function<void(InferRequest, StatusCode)> f) {
_this.SetCompletionCallbackImpl(std::move(f));
}
InferRequest& _this;
};
template<>
struct InferRequest::SetCallback<IInferRequest::CompletionCallback> {
void operator()(IInferRequest::CompletionCallback f) {
_this.SetCompletionCallbackImpl(std::move(f));
}
InferRequest& _this;
};
} // namespace details
} // namespace InferenceEngine
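The refactored wrapper keeps the user-facing API unchanged, so typical application code is unaffected. A minimal usage sketch (the model file name and the choice of a CPU device are illustrative):

```cpp
#include <ie_core.hpp>

int main() {
    InferenceEngine::Core core;
    auto network = core.ReadNetwork("sample.xml");
    auto execNetwork = core.LoadNetwork(network, "CPU");

    // CreateInferRequest() now constructs InferRequest directly from the
    // plugin's IInferRequestInternal instead of an IInferRequest proxy.
    InferenceEngine::InferRequest request = execNetwork.CreateInferRequest();

    // A no-argument callback goes through the generic SetCallback<F> path and
    // ends up in the std::function<void()> overload of SetCompletionCallbackImpl.
    request.SetCompletionCallback([] {
        // called when inference finishes, successfully or not
    });

    request.StartAsync();
    request.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY);

    auto output = request.GetBlob(network.getOutputsInfo().begin()->first);
    return 0;
}
```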

View File

@ -25,7 +25,8 @@ namespace InferenceEngine {
* @brief This is an interface of asynchronous infer request
*
*/
class IInferRequest : public std::enable_shared_from_this<IInferRequest> {
IE_SUPPRESS_DEPRECATED_START
class INFERENCE_ENGINE_DEPRECATED("Do not use IInferRequest API") IInferRequest : public std::enable_shared_from_this<IInferRequest> {
public:
/**
* @enum WaitMode
@ -201,5 +202,6 @@ public:
protected:
~IInferRequest() = default;
};
IE_SUPPRESS_DEPRECATED_END
} // namespace InferenceEngine

View File

@ -5,7 +5,7 @@
#include "cldnn_async_infer_request.h"
#include <memory>
CLDNNPlugin::CLDNNAsyncInferRequest::CLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest,
CLDNNPlugin::CLDNNAsyncInferRequest::CLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr &inferRequest,
const InferenceEngine::ITaskExecutor::Ptr &taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor)
: InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor)

View File

@ -13,13 +13,13 @@ namespace CLDNNPlugin {
class CLDNNAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
public:
CLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest,
CLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr &inferRequest,
const InferenceEngine::ITaskExecutor::Ptr &taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor);
void Infer_ThreadUnsafe() override;
~CLDNNAsyncInferRequest() override;
~CLDNNAsyncInferRequest();
};
} // namespace CLDNNPlugin

View File

@ -22,6 +22,7 @@
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ie_ngraph_utils.hpp>
#include <ie_algorithm.hpp>
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>

View File

@ -26,6 +26,7 @@
#include "cldnn_executable_network.h"
#include "threading/ie_cpu_streams_executor.hpp"
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
using namespace InferenceEngine;
@ -62,8 +63,8 @@ CLDNNExecNetwork::CLDNNExecNetwork(InferenceEngine::CNNNetwork &network, RemoteC
}
}
InferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequestImpl(InputsDataMap networkInputs,
OutputsDataMap networkOutputs) {
IInferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequestImpl(InputsDataMap networkInputs,
OutputsDataMap networkOutputs) {
OV_ITT_SCOPED_TASK(itt::domains::CLDNNPlugin, "CLDNNExecNetwork::CreateInferRequestImpl");
if (m_graphs.empty()) {
IE_THROW(NetworkNotLoaded);
@ -91,7 +92,7 @@ InferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequestImpl(InputsDataMap
return ptr;
}
IInferRequest::Ptr CLDNNExecNetwork::CreateInferRequest() {
IInferRequestInternal::Ptr CLDNNExecNetwork::CreateInferRequest() {
OV_ITT_SCOPED_TASK(itt::domains::CLDNNPlugin, "CLDNNExecNetwork::CreateInferRequest");
return CreateAsyncInferRequestFromSync<CLDNNAsyncInferRequest>();
}

View File

@ -26,9 +26,9 @@ public:
CLDNNExecNetwork(InferenceEngine::CNNNetwork &network, InferenceEngine::RemoteContext::Ptr context, Config config);
InferenceEngine::CNNNetwork GetExecGraphInfo() override;
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::Parameter GetMetric(const std::string &name) const override;
InferenceEngine::Parameter GetConfig(const std::string &name) const override;

View File

@ -12,6 +12,8 @@
#include "cldnn_remote_context.h"
#include "cldnn_executable_network.h"
#include "cldnn_itt.h"
#include <ie_algorithm.hpp>
#include <debug.h>
using namespace InferenceEngine;
@ -812,7 +814,7 @@ void CLDNNInferRequest::SetBatch(int new_batch) {
CLDNNInferRequest::CLDNNInferRequest(InputsDataMap networkInputs, OutputsDataMap networkOutputs,
const CLDNNExecNetwork::Ptr& execNetwork)
: InferRequestInternal(networkInputs, networkOutputs)
: IInferRequestInternal(networkInputs, networkOutputs)
, m_useProfiling(false)
, m_useStreams(false) {
IE_ASSERT(nullptr != execNetwork);

View File

@ -9,7 +9,6 @@
#include <vector>
#include <memory>
#include <atomic>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include "cldnn_graph.h"
#include <threading/ie_istreams_executor.hpp>
@ -22,7 +21,7 @@ struct buf_info {
class CLDNNExecNetwork;
class CLDNNInferRequest : public InferenceEngine::InferRequestInternal {
class CLDNNInferRequest : public InferenceEngine::IInferRequestInternal {
public:
// make sure all blobs and cldnn::memory objects
// are in place and valid

View File

@ -8,16 +8,15 @@
#include <map>
#include <vector>
#include <cpp_interfaces/impl/ie_executable_network_thread_safe_default.hpp>
#include "gna_infer_request.hpp"
#include "gna_plugin.hpp"
#include <gna/gna_config.hpp>
#include <threading/ie_executor_manager.hpp>
#include <cpp_interfaces/impl/ie_executable_network_thread_safe_async_only.hpp>
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
namespace GNAPluginNS {
class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafeAsyncOnly {
class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkInternal {
std::shared_ptr<GNAPlugin> plg;
public:
@ -53,9 +52,9 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
: GNAExecutableNetwork(network, std::make_shared<GNAPlugin>(config)) {
}
InferenceEngine::AsyncInferRequestInternal::Ptr
CreateAsyncInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override {
InferenceEngine::IInferRequestInternal::Ptr
CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override {
return std::make_shared<GNAInferRequest>(plg, networkInputs, networkOutputs);
}

View File

@ -8,13 +8,12 @@
#include <string>
#include <map>
#include "cpp_interfaces/impl/ie_infer_async_request_internal.hpp"
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
#include "gna_plugin.hpp"
namespace GNAPluginNS {
class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
class GNAInferRequest : public InferenceEngine::IInferRequestInternal {
protected:
std::shared_ptr<GNAPlugin> plg;
uint32_t inferRequestIdx = -1;
@ -23,7 +22,7 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs)
: InferenceEngine::AsyncInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
: InferenceEngine::IInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
// TODO: internal connection API - better to generalize
if (networkOutputs.empty()) {
THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
@ -78,10 +77,19 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
inferRequestIdx = plg->QueueInference(_inputs, _outputs);
// workaround to unblock callback-based flows
if (_callback) {
auto infer_request = _publicInterface.lock();
IE_ASSERT(infer_request != nullptr);
auto res = Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
_callback(infer_request, res);
std::exception_ptr exceptionPtr;
if (res != InferenceEngine::StatusCode::OK) {
try {
IE_EXCEPTION_SWITCH(res, ExceptionType,
InferenceEngine::details::ThrowNow<ExceptionType>{}
<<= std::stringstream{} << IE_LOCATION
<< InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
} catch (...) {
exceptionPtr = std::current_exception();
}
}
_callback(exceptionPtr);
}
}

View File

@ -9,9 +9,9 @@
using namespace HeteroPlugin;
using namespace InferenceEngine;
HeteroAsyncInferRequest::HeteroAsyncInferRequest(const InferRequestInternal::Ptr& request,
const ITaskExecutor::Ptr& taskExecutor,
const ITaskExecutor::Ptr& callbackExecutor) :
HeteroAsyncInferRequest::HeteroAsyncInferRequest(const IInferRequestInternal::Ptr& request,
const ITaskExecutor::Ptr& taskExecutor,
const ITaskExecutor::Ptr& callbackExecutor) :
AsyncInferRequestThreadSafeDefault(request, taskExecutor, callbackExecutor),
_heteroInferRequest(std::static_pointer_cast<HeteroInferRequest>(request)),
_statusCodes{_heteroInferRequest->_inferRequests.size(), StatusCode::OK} {

View File

@ -19,10 +19,10 @@ namespace HeteroPlugin {
class HeteroAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
public:
using Ptr = std::shared_ptr<HeteroAsyncInferRequest>;
HeteroAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr& request,
HeteroAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr& request,
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
~HeteroAsyncInferRequest() override;
~HeteroAsyncInferRequest();
void StartAsync_ThreadUnsafe() override;
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;

View File

@ -24,9 +24,11 @@
#include "transformations/serialize.hpp"
#include "ie_ngraph_utils.hpp"
#include "ie_plugin_config.hpp"
#include "ie_algorithm.hpp"
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
#include "hetero/hetero_plugin_config.hpp"
#include "hetero_plugin.hpp"
#include <ie_algorithm.hpp>
#include <ngraph/function.hpp>
#include <ngraph/variant.hpp>
@ -638,7 +640,7 @@ void HeteroExecutableNetwork::ExportImpl(std::ostream& heteroModel) {
}
}
InferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequestImpl(
IInferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequestImpl(
InputsDataMap networkInputs,
OutputsDataMap networkOutputs) {
HeteroInferRequest::SubRequestsList inferRequests;
@ -655,7 +657,7 @@ InferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequestImpl(
_blobNameMap);
}
IInferRequest::Ptr HeteroExecutableNetwork::CreateInferRequest() {
IInferRequestInternal::Ptr HeteroExecutableNetwork::CreateInferRequest() {
return CreateAsyncInferRequestFromSync<HeteroAsyncInferRequest>();
}

View File

@ -49,10 +49,10 @@ public:
~HeteroExecutableNetwork() override = default;
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
InferenceEngine::Parameter GetConfig(const std::string &name) const override;

View File

@ -20,7 +20,7 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
InferenceEngine::OutputsDataMap networkOutputs,
const SubRequestsList& inferRequests,
const std::unordered_map<std::string, std::string>& subgraphInputToOutputBlobNames) :
InferRequestInternal(networkInputs, networkOutputs),
IInferRequestInternal(networkInputs, networkOutputs),
_inferRequests(inferRequests) {
if (_networkOutputs.empty() || _networkInputs.empty()) {
IE_THROW() << "Internal error: no information about network's output/input";
@ -65,7 +65,7 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
}
void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) {
InferenceEngine::InferRequestInternal::SetBlob(name, data);
InferenceEngine::IInferRequestInternal::SetBlob(name, data);
assert(!_inferRequests.empty());
for (auto &&desc : _inferRequests) {
auto &r = desc._request;

View File

@ -15,14 +15,15 @@
#include <memory>
#include <unordered_map>
#include <ie_common.h>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
#include <cpp/ie_infer_request.hpp>
#include <cpp/ie_executable_network.hpp>
#include <openvino/itt.hpp>
namespace HeteroPlugin {
class HeteroInferRequest : public InferenceEngine::InferRequestInternal {
class HeteroInferRequest : public InferenceEngine::IInferRequestInternal {
public:
typedef std::shared_ptr<HeteroInferRequest> Ptr;

View File

@ -9,6 +9,7 @@ file (GLOB LIBRARY_SRC
${CMAKE_CURRENT_SOURCE_DIR}/cpp/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threading/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp_interfaces/interface/*.cpp
)
# TODO: WA for OneHot pass usage in reshape

View File

@ -54,7 +54,7 @@ InferRequest ExecutableNetwork::CreateInferRequest() {
}
InferRequest::Ptr ExecutableNetwork::CreateInferRequestPtr() {
CALL_STATEMENT(return std::make_shared<InferRequest>(_impl->CreateInferRequest(), _so));
CALL_STATEMENT(return std::make_shared<InferRequest>(InferRequest{_impl->CreateInferRequest(), _so}));
}
void ExecutableNetwork::Export(const std::string& modelFileName) {

View File

@ -0,0 +1,209 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <map>
#include <memory>
#include <string>
#include "cpp/ie_infer_request.hpp"
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
#include "ie_remote_context.hpp"
namespace InferenceEngine {
#define CATCH_IE_EXCEPTION(ExceptionType) catch (const InferenceEngine::ExceptionType& e) {throw e;}
#define CATCH_IE_EXCEPTIONS \
CATCH_IE_EXCEPTION(GeneralError) \
CATCH_IE_EXCEPTION(NotImplemented) \
CATCH_IE_EXCEPTION(NetworkNotLoaded) \
CATCH_IE_EXCEPTION(ParameterMismatch) \
CATCH_IE_EXCEPTION(NotFound) \
CATCH_IE_EXCEPTION(OutOfBounds) \
CATCH_IE_EXCEPTION(Unexpected) \
CATCH_IE_EXCEPTION(RequestBusy) \
CATCH_IE_EXCEPTION(ResultNotReady) \
CATCH_IE_EXCEPTION(NotAllocated) \
CATCH_IE_EXCEPTION(InferNotStarted) \
CATCH_IE_EXCEPTION(NetworkNotRead) \
CATCH_IE_EXCEPTION(InferCancelled)
#define CALL_STATEMENT(...) \
if (_impl == nullptr) IE_THROW() << "Inference Request is not initialized"; \
try { \
__VA_ARGS__ \
} CATCH_IE_EXCEPTIONS catch (const std::exception& ex) { \
IE_THROW() << ex.what(); \
} catch (...) { \
IE_THROW(Unexpected); \
}
InferRequest::InferRequest(const std::shared_ptr<IInferRequestInternal>& impl,
const std::shared_ptr<details::SharedObjectLoader>& so) :
_impl{impl},
_so{so} {
if (_impl == nullptr) IE_THROW() << "Inference Request is not initialized";
}
InferRequest::~InferRequest() {
_impl = {};
}
void InferRequest::SetBlob(const std::string& name, const Blob::Ptr& data) {
CALL_STATEMENT(_impl->SetBlob(name, data);)
}
Blob::Ptr InferRequest::GetBlob(const std::string& name) {
Blob::Ptr blobPtr;
CALL_STATEMENT(blobPtr = _impl->GetBlob(name);)
std::string error = "Internal error: blob with name `" + name + "` is not allocated!";
const bool remoteBlobPassed = blobPtr->is<RemoteBlob>();
if (blobPtr == nullptr) IE_THROW() << error;
if (!remoteBlobPassed && blobPtr->buffer() == nullptr) IE_THROW() << error;
return blobPtr;
}
void InferRequest::SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo& info) {
CALL_STATEMENT(_impl->SetBlob(name, data, info);)
}
const PreProcessInfo& InferRequest::GetPreProcess(const std::string& name) const {
CALL_STATEMENT(return _impl->GetPreProcess(name);)
}
void InferRequest::Infer() {
CALL_STATEMENT(_impl->Infer();)
}
void InferRequest::Cancel() {
CALL_STATEMENT(_impl->Cancel();)
}
std::map<std::string, InferenceEngineProfileInfo> InferRequest::GetPerformanceCounts() const {
CALL_STATEMENT(return _impl->GetPerformanceCounts();)
}
void InferRequest::SetInput(const BlobMap& inputs) {
CALL_STATEMENT(
for (auto&& input : inputs) {
_impl->SetBlob(input.first, input.second);
}
)
}
void InferRequest::SetOutput(const BlobMap& results) {
CALL_STATEMENT(
for (auto&& result : results) {
_impl->SetBlob(result.first, result.second);
}
)
}
void InferRequest::SetBatch(const int batch) {
CALL_STATEMENT(_impl->SetBatch(batch);)
}
void InferRequest::StartAsync() {
CALL_STATEMENT(_impl->StartAsync();)
}
StatusCode InferRequest::Wait(int64_t millis_timeout) {
CALL_STATEMENT(return _impl->Wait(millis_timeout);)
}
void InferRequest::SetCompletionCallbackImpl(std::function<void()> callback) {
CALL_STATEMENT(
_impl->SetCallback([callback] (std::exception_ptr) {
callback();
});
)
}
#define CATCH_IE_EXCEPTION_RETURN(StatusCode, ExceptionType) catch (const ExceptionType&) {return StatusCode;}
#define CATCH_IE_EXCEPTIONS_RETURN \
CATCH_IE_EXCEPTION_RETURN(GENERAL_ERROR, GeneralError) \
CATCH_IE_EXCEPTION_RETURN(NOT_IMPLEMENTED, NotImplemented) \
CATCH_IE_EXCEPTION_RETURN(NETWORK_NOT_LOADED, NetworkNotLoaded) \
CATCH_IE_EXCEPTION_RETURN(PARAMETER_MISMATCH, ParameterMismatch) \
CATCH_IE_EXCEPTION_RETURN(NOT_FOUND, NotFound) \
CATCH_IE_EXCEPTION_RETURN(OUT_OF_BOUNDS, OutOfBounds) \
CATCH_IE_EXCEPTION_RETURN(UNEXPECTED, Unexpected) \
CATCH_IE_EXCEPTION_RETURN(REQUEST_BUSY, RequestBusy) \
CATCH_IE_EXCEPTION_RETURN(RESULT_NOT_READY, ResultNotReady) \
CATCH_IE_EXCEPTION_RETURN(NOT_ALLOCATED, NotAllocated) \
CATCH_IE_EXCEPTION_RETURN(INFER_NOT_STARTED, InferNotStarted) \
CATCH_IE_EXCEPTION_RETURN(NETWORK_NOT_READ, NetworkNotRead) \
CATCH_IE_EXCEPTION_RETURN(INFER_CANCELLED, InferCancelled)
void InferRequest::SetCompletionCallbackImpl(std::function<void(InferRequest, StatusCode)> callback) {
CALL_STATEMENT(
auto weakThis = InferRequest{std::shared_ptr<IInferRequestInternal>{_impl.get(), [](IInferRequestInternal*){}}, _so};
_impl->SetCallback([callback, weakThis] (std::exception_ptr exceptionPtr) {
StatusCode statusCode = StatusCode::OK;
if (exceptionPtr != nullptr) {
statusCode = [&] {
try {
std::rethrow_exception(exceptionPtr);
} CATCH_IE_EXCEPTIONS_RETURN catch (const std::exception& ex) {
return GENERAL_ERROR;
} catch (...) {
return UNEXPECTED;
}
} ();
}
callback(weakThis, statusCode);
});
)
}
void InferRequest::SetCompletionCallbackImpl(IInferRequest::CompletionCallback callback) {
CALL_STATEMENT(
IInferRequest::Ptr weakThis = InferRequest{std::shared_ptr<IInferRequestInternal>{_impl.get(), [](IInferRequestInternal*){}}, _so};
_impl->SetCallback([callback, weakThis] (std::exception_ptr exceptionPtr) {
StatusCode statusCode = StatusCode::OK;
if (exceptionPtr != nullptr) {
statusCode = [&] {
try {
std::rethrow_exception(exceptionPtr);
} CATCH_IE_EXCEPTIONS_RETURN catch (const std::exception& ex) {
return GENERAL_ERROR;
} catch (...) {
return UNEXPECTED;
}
} ();
}
callback(weakThis, statusCode);
});
)
}
std::vector<VariableState> InferRequest::QueryState() {
std::vector<VariableState> controller;
CALL_STATEMENT(
for (auto&& state : _impl->QueryState()) {
controller.emplace_back(std::make_shared<VariableStateBase>(state), _so);
}
)
return controller;
}
InferRequest::operator IInferRequest::Ptr () {
CALL_STATEMENT(
return std::make_shared<InferRequestBase>(_impl);
)
}
bool InferRequest::operator!() const noexcept {
return !_impl;
}
InferRequest::operator bool() const noexcept {
return !!_impl;
}
} // namespace InferenceEngine
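For the legacy two-argument callback, an explicit `std::function` type selects the `SetCallback` specialization declared in the header, and an exception raised on the internal callback path is translated back into a `StatusCode` by the `CATCH_IE_EXCEPTIONS_RETURN` mapping above. A hedged usage sketch (function and variable names are illustrative):

```cpp
#include <functional>
#include <cpp/ie_infer_request.hpp>

void runAsync(InferenceEngine::InferRequest& request) {
    using InferenceEngine::InferRequest;
    using InferenceEngine::StatusCode;

    // The explicit std::function type picks the
    // SetCallback<std::function<void(InferRequest, StatusCode)>> specialization.
    request.SetCompletionCallback(std::function<void(InferRequest, StatusCode)>{
        [](InferRequest completedRequest, StatusCode status) {
            if (status == StatusCode::OK) {
                // read outputs via completedRequest.GetBlob(...)
            }
            // a failure inside the pipeline arrives here as a non-OK status,
            // produced from the stored std::exception_ptr
        }});

    request.StartAsync();
}
```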

View File

@ -0,0 +1,325 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <map>
#include <memory>
#include <string>
#include <ie_blob.h>
#include <ie_common.h>
#include <ie_preprocess.hpp>
#include <ie_compound_blob.h>
#include <ie_algorithm.hpp>
#include <debug.h>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iplugin_internal.hpp>
#include <cpp_interfaces/plugin_itt.hpp>
namespace InferenceEngine {
IInferRequestInternal::~IInferRequestInternal() {}
IInferRequestInternal::IInferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs) {
// We should copy maps since they can be overridden in SetBlob with preprocess
copyInputOutputInfo(networkInputs, networkOutputs, _networkInputs, _networkOutputs);
}
void IInferRequestInternal::Infer() {
checkBlobs();
InferImpl();
}
void IInferRequestInternal::InferImpl() {
IE_THROW(NotImplemented);
}
void IInferRequestInternal::Cancel() {
IE_THROW(NotImplemented);
}
std::map<std::string, InferenceEngineProfileInfo> IInferRequestInternal::GetPerformanceCounts() const {
IE_THROW(NotImplemented);
}
void IInferRequestInternal::SetBlob(const std::string& name, const Blob::Ptr& userBlob) {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "SetBlob");
if (name.empty()) {
IE_THROW(NotFound) << "Failed to set blob with empty name";
}
if (!userBlob) IE_THROW(NotAllocated) << "Failed to set empty blob with name: \'" << name << "\'";
const bool compoundBlobPassed = userBlob->is<CompoundBlob>();
const bool remoteBlobPassed = userBlob->is<RemoteBlob>();
if (!compoundBlobPassed && !remoteBlobPassed && userBlob->buffer() == nullptr)
IE_THROW(NotAllocated) << "Input data was not allocated. Input name: \'" << name << "\'";
if (userBlob->size() == 0) {
IE_THROW() << "Input data is empty. Input name: \'" << name << "\'";
}
InputInfo::Ptr foundInput;
DataPtr foundOutput;
size_t dataSize = userBlob->size();
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
if (foundInput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
IE_THROW(ParameterMismatch) << "Failed to set Blob with precision not corresponding to user input precision";
}
auto& devBlob = _deviceInputs[name];
const bool preProcRequired = preProcessingRequired(foundInput, userBlob, devBlob);
if (compoundBlobPassed && !preProcRequired) {
IE_THROW(NotImplemented) << "cannot set compound blob: supported only for input pre-processing";
}
if (preProcRequired) {
addInputPreProcessingFor(name, userBlob, devBlob ? devBlob : _inputs[name]);
} else {
size_t inputSize = foundInput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
? InferenceEngine::details::product(foundInput->getTensorDesc().getDims())
: 1;
if (dataSize != inputSize) {
IE_THROW() << "Input blob size is not equal network input size (" << dataSize << "!=" << inputSize << ").";
}
_inputs[name] = userBlob;
devBlob = userBlob;
}
} else {
if (compoundBlobPassed) {
IE_THROW(NotImplemented) << "cannot set compound blob: supported only for input pre-processing";
}
size_t outputSize = foundOutput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
? details::product(foundOutput->getTensorDesc().getDims()) :
1;
if (dataSize != outputSize) {
IE_THROW() << "Output blob size is not equal network output size (" << dataSize << "!=" << outputSize << ").";
}
if (foundOutput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
IE_THROW(ParameterMismatch) << "Failed to set Blob with precision not corresponding to user output precision";
}
_outputs[name] = userBlob;
}
}
Blob::Ptr IInferRequestInternal::GetBlob(const std::string& name) {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "GetBlob");
Blob::Ptr data;
InputInfo::Ptr foundInput;
DataPtr foundOutput;
const SizeVector oneVector = { 1 };
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
// ROI blob is returned only if it was set previously. Otherwise default blob is returned.
auto it = _preProcData.find(name);
if (it != _preProcData.end()) {
data = it->second->getRoiBlob();
} else {
data = _inputs[name];
checkBlob(data, name, true,
foundInput->getTensorDesc().getLayout() != SCALAR
? foundInput->getTensorDesc().getDims()
: oneVector);
auto& devBlob = _deviceInputs[name];
if (preProcessingRequired(foundInput, data, devBlob)) {
// if no devBlob, performs inplace
addInputPreProcessingFor(name, data, devBlob ? devBlob : _inputs[name]);
}
}
} else {
data = _outputs[name];
checkBlob(data, name, false,
foundOutput->getTensorDesc().getLayout() != SCALAR
? foundOutput->getTensorDesc().getDims()
: oneVector);
}
return data;
}
void IInferRequestInternal::SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) {
InputInfo::Ptr foundInput;
DataPtr foundOutput;
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
copyPreProcess(info, foundInput->getPreProcess());
} else {
IE_THROW() << "Pre-process can't be set to output blob";
}
SetBlob(name, data);
}
const PreProcessInfo& IInferRequestInternal::GetPreProcess(const std::string& name) const {
InputInfo::Ptr foundInput;
DataPtr foundOutput;
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
return foundInput->getPreProcess();
} else {
IE_THROW() << "Output blob can't have pre-processing";
}
}
void IInferRequestInternal::SetBatch(int batch) {
IE_THROW(NotImplemented);
}
std::vector<std::shared_ptr<IVariableStateInternal>> IInferRequestInternal::QueryState() {
IE_THROW(NotImplemented);
}
void IInferRequestInternal::StartAsync() {
checkBlobs();
StartAsyncImpl();
}
void IInferRequestInternal::StartAsyncImpl() {
IE_THROW(NotImplemented);
}
StatusCode IInferRequestInternal::Wait(int64_t millis_timeout) {
IE_THROW(NotImplemented);
}
void IInferRequestInternal::SetCallback(Callback callback) {
_callback = std::move(callback);
}
void IInferRequestInternal::execDataPreprocessing(InferenceEngine::BlobMap& preprocessedBlobs, bool serial) {
for (auto& input : preprocessedBlobs) {
// If there is a pre-process entry for an input then it must be pre-processed
// using preconfigured resize algorithm.
auto it = _preProcData.find(input.first);
if (it != _preProcData.end()) {
it->second->execute(input.second, _networkInputs[input.first]->getPreProcess(), serial, m_curBatch);
}
}
}
bool IInferRequestInternal::findInputAndOutputBlobByName(const std::string& name, InputInfo::Ptr& foundInput, DataPtr& foundOutput) const {
foundInput = nullptr;
foundOutput = nullptr;
if (_networkOutputs.empty()) {
IE_THROW() << "Internal error: network outputs is not set";
}
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
return pair.first == name;
});
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
[&](const std::pair<std::string, DataPtr>& pair) {
return pair.first == name;
});
if (foundOutputPair == std::end(_networkOutputs) && (foundInputPair == std::end(_networkInputs))) {
IE_THROW(NotFound) << "Failed to find input or output with name: \'" << name << "\'";
}
if (foundInputPair != std::end(_networkInputs)) {
foundInput = foundInputPair->second;
return true;
} else {
foundOutput = foundOutputPair->second;
return false;
}
}
void IInferRequestInternal::checkBlob(const Blob::Ptr& blob, const std::string& name, bool isInput, const SizeVector& refDims) const {
std::string bType = isInput ? "Input" : "Output";
std::string sType = isInput ? "input" : "output";
std::string strNotAllocated(bType + " data was not allocated.");
std::string strNotMatched("The " + sType + " blob size is not equal to the network " + sType + " size");
if (!blob) {
IE_THROW(NotAllocated) << strNotAllocated;
}
size_t refSize;
if (refDims.empty()) {
SizeVector dims;
if (isInput) {
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
return pair.first == name;
});
if (foundInputPair == std::end(_networkInputs)) {
IE_THROW(NotFound) << "Failed to find input with name: \'" << name << "\'";
}
dims = foundInputPair->second->getTensorDesc().getDims();
refSize = foundInputPair->second->getTensorDesc().getLayout() != SCALAR
? details::product(dims)
: 1;
} else {
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
[&](const std::pair<std::string, DataPtr>& pair) {
return pair.first == name;
});
if (foundOutputPair == std::end(_networkOutputs)) {
IE_THROW(NotFound) << "Failed to find output with name: \'" << name << "\'";
}
dims = foundOutputPair->second->getTensorDesc().getDims();
refSize = foundOutputPair->second->getTensorDesc().getLayout() != SCALAR
? details::product(dims)
: 1;
}
} else {
refSize = details::product(refDims);
}
if (refSize != blob->size()) {
IE_THROW() << strNotMatched + ": got " << blob->size() << " expecting " << refSize;
}
const bool remoteBlobPassed = blob->is<RemoteBlob>();
if (!remoteBlobPassed && blob->buffer() == nullptr) IE_THROW() << strNotAllocated;
}
void IInferRequestInternal::checkBlobs() {
for (auto const& input : _inputs) {
checkBlob(input.second, input.first, true);
}
for (auto const& output : _outputs) {
checkBlob(output.second, output.first, false);
}
}
void IInferRequestInternal::setPointerToExecutableNetworkInternal(const std::shared_ptr<IExecutableNetworkInternal>& exeNetwork) {
_exeNetwork = exeNetwork;
}
bool IInferRequestInternal::preProcessingRequired(const InputInfo::Ptr& info, const Blob::Ptr& userBlob, const Blob::Ptr& deviceBlob) {
// pre-processing is required if:
// 1. resize algorithm is specified (resize required)
// 2. color format specified:
// 2.a. color format is not equal to network's expected (color conversion required)
// 2.b. network's layout != blob's layout (reorder required)
// 3. precision conversion is required
const auto& preProcessInfo = info->getPreProcess();
const auto inputColorFormat = preProcessInfo.getColorFormat();
// FIXME: support other network's input formats once the API is ready. Assuming input is in
// the BGR format by default
const auto networkColorFormat = ColorFormat::BGR;
const bool colorFormatSpecified = inputColorFormat != ColorFormat::RAW;
auto blob_layout = [](const Blob::Ptr& b) { return b->getTensorDesc().getLayout(); };
auto blob_prec = [](const Blob::Ptr& b) { return b->getTensorDesc().getPrecision();};
auto dst_layout = deviceBlob ? blob_layout(deviceBlob) : info->getLayout();
auto dst_prec = deviceBlob ? blob_prec(deviceBlob) : info->getPrecision();
//FIXME: remove the first part to allow any needed conversion?
const bool need_layout_conv = (colorFormatSpecified || deviceBlob) &&
(blob_layout(userBlob) != dst_layout);
return preProcessInfo.getResizeAlgorithm() != ResizeAlgorithm::NO_RESIZE ||
(colorFormatSpecified && inputColorFormat != networkColorFormat) ||
need_layout_conv ||
(blob_prec(userBlob) != dst_prec);
}
void IInferRequestInternal::addInputPreProcessingFor(const std::string& name, Blob::Ptr const& from, const Blob::Ptr& to) {
auto ppDataIt = _preProcData.find(name);
if (ppDataIt == _preProcData.end()) {
ppDataIt = (_preProcData.emplace(name, CreatePreprocDataHelper())).first;
}
auto& preproc_ptr = ppDataIt->second;
preproc_ptr->isApplicable(from, to);
// Stores the given blob as ROI blob. It will be used to fill in network input
// during pre-processing
preproc_ptr->setRoiBlob(from);
}
} // namespace InferenceEngine

View File

@ -57,42 +57,38 @@ namespace details {
IE_SUPPRESS_DEPRECATED_START
StatusCode InferenceEngineException::getStatus() const {
return ExceptionToStatus(dynamic_cast<const Exception&>(*this));
}
} // namespace details
IE_SUPPRESS_DEPRECATED_END
INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exception) {
if (dynamic_cast<const GeneralError*>(&exception) != nullptr) {
if (dynamic_cast<const GeneralError*>(this) != nullptr) {
return GENERAL_ERROR;
} else if (dynamic_cast<const NotImplemented*>(&exception) != nullptr) {
} else if (dynamic_cast<const NotImplemented*>(this) != nullptr) {
return NOT_IMPLEMENTED;
} else if (dynamic_cast<const NetworkNotLoaded*>(&exception) != nullptr) {
} else if (dynamic_cast<const NetworkNotLoaded*>(this) != nullptr) {
return NETWORK_NOT_LOADED;
} else if (dynamic_cast<const ParameterMismatch*>(&exception) != nullptr) {
} else if (dynamic_cast<const ParameterMismatch*>(this) != nullptr) {
return PARAMETER_MISMATCH;
} else if (dynamic_cast<const NotFound*>(&exception) != nullptr) {
} else if (dynamic_cast<const NotFound*>(this) != nullptr) {
return NOT_FOUND;
} else if (dynamic_cast<const OutOfBounds*>(&exception) != nullptr) {
} else if (dynamic_cast<const OutOfBounds*>(this) != nullptr) {
return OUT_OF_BOUNDS;
} else if (dynamic_cast<const Unexpected*>(&exception) != nullptr) {
} else if (dynamic_cast<const Unexpected*>(this) != nullptr) {
return UNEXPECTED;
} else if (dynamic_cast<const RequestBusy*>(&exception) != nullptr) {
} else if (dynamic_cast<const RequestBusy*>(this) != nullptr) {
return REQUEST_BUSY;
} else if (dynamic_cast<const ResultNotReady*>(&exception) != nullptr) {
} else if (dynamic_cast<const ResultNotReady*>(this) != nullptr) {
return RESULT_NOT_READY;
} else if (dynamic_cast<const NotAllocated*>(&exception) != nullptr) {
} else if (dynamic_cast<const NotAllocated*>(this) != nullptr) {
return NOT_ALLOCATED;
} else if (dynamic_cast<const InferNotStarted*>(&exception) != nullptr) {
} else if (dynamic_cast<const InferNotStarted*>(this) != nullptr) {
return INFER_NOT_STARTED;
} else if (dynamic_cast<const NetworkNotRead*>(&exception) != nullptr) {
} else if (dynamic_cast<const NetworkNotRead*>(this) != nullptr) {
return NETWORK_NOT_READ;
} else if (dynamic_cast<const InferCancelled*>(&exception) != nullptr) {
} else if (dynamic_cast<const InferCancelled*>(this) != nullptr) {
return INFER_CANCELLED;
} else {
assert(!"Unreachable"); return OK;
}
}
} // namespace details
IE_SUPPRESS_DEPRECATED_END
//
// ie_parameter.hpp

View File

@ -5,7 +5,7 @@
#include "mkldnn_async_infer_request.h"
#include <memory>
MKLDNNPlugin::MKLDNNAsyncInferRequest::MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr& inferRequest,
MKLDNNPlugin::MKLDNNAsyncInferRequest::MKLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr& inferRequest,
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor)
: InferenceEngine::AsyncInferRequestThreadSafeDefault(inferRequest, taskExecutor, callbackExecutor) {

View File

@ -13,10 +13,10 @@ namespace MKLDNNPlugin {
class MKLDNNAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadSafeDefault {
public:
MKLDNNAsyncInferRequest(const InferenceEngine::InferRequestInternal::Ptr &inferRequest,
MKLDNNAsyncInferRequest(const InferenceEngine::IInferRequestInternal::Ptr &inferRequest,
const InferenceEngine::ITaskExecutor::Ptr &taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor);
~MKLDNNAsyncInferRequest() override;
~MKLDNNAsyncInferRequest();
};
} // namespace MKLDNNPlugin

View File

@ -29,7 +29,7 @@ using namespace MKLDNNPlugin;
using namespace InferenceEngine;
using namespace InferenceEngine::details;
InferenceEngine::InferRequestInternal::Ptr
InferenceEngine::IInferRequestInternal::Ptr
MKLDNNExecNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) {
return std::make_shared<MKLDNNInferRequest>(networkInputs, networkOutputs, std::static_pointer_cast<MKLDNNExecNetwork>(shared_from_this()));
@ -323,7 +323,7 @@ void MKLDNNExecNetwork::setProperty(const std::map<std::string, std::string> &pr
}
}
InferenceEngine::IInferRequest::Ptr MKLDNNExecNetwork::CreateInferRequest() {
InferenceEngine::IInferRequestInternal::Ptr MKLDNNExecNetwork::CreateInferRequest() {
return CreateAsyncInferRequestFromSync<MKLDNNAsyncInferRequest>();
}

View File

@ -23,11 +23,11 @@ class MKLDNNExecNetwork: public InferenceEngine::ExecutableNetworkThreadSafeDefa
public:
typedef std::shared_ptr<MKLDNNExecNetwork> Ptr;
InferenceEngine::InferRequestInternal::Ptr
InferenceEngine::IInferRequestInternal::Ptr
CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
MKLDNNExecNetwork(const InferenceEngine::CNNNetwork &network, const Config &cfg,
const MKLDNNExtensionManager::Ptr &extMgr, NumaNodesWeights &weightsSharing);

View File

@ -19,11 +19,13 @@
#include "nodes/mkldnn_memory_node.hpp"
#include "nodes/common/cpu_memcpy.h"
#include "mkldnn_async_infer_request.h"
#include <debug.h>
MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs,
MKLDNNExecNetwork::Ptr execNetwork_)
: InferRequestInternal(networkInputs, networkOutputs)
: IInferRequestInternal(networkInputs, networkOutputs)
, execNetwork(execNetwork_) {
auto id = (execNetwork->_numRequests)++;
profilingTask = openvino::itt::handle("MKLDNN_INFER_" + execNetwork->_name + "_" + std::to_string(id));

View File

@ -8,21 +8,21 @@
#include <memory>
#include <string>
#include <map>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
namespace MKLDNNPlugin {
class MKLDNNExecNetwork;
class MKLDNNAsyncInferRequest;
class MKLDNNInferRequest : public InferenceEngine::InferRequestInternal {
class MKLDNNInferRequest : public InferenceEngine::IInferRequestInternal {
public:
typedef std::shared_ptr<MKLDNNInferRequest> Ptr;
explicit MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs,
std::shared_ptr<MKLDNNExecNetwork> execNetwork);
~MKLDNNInferRequest() override;
~MKLDNNInferRequest();
void InferImpl() override;
@ -34,7 +34,7 @@ public:
void SetBatch(int batch = -1) override;
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override;
/**
* @brief Sets the pointer to asynchronous inference request that holds this request
@ -59,7 +59,7 @@ private:
MKLDNNGraph* graph = nullptr;
std::map<std::string, void*> externalPtr;
openvino::itt::handle_t profilingTask;
std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> memoryStates;
MKLDNNAsyncInferRequest* _asyncRequest = nullptr;
};
} // namespace MKLDNNPlugin

View File

@ -27,7 +27,7 @@ public:
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
void Infer_ThreadUnsafe() override;
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
~MultiDeviceAsyncInferRequest() override;
~MultiDeviceAsyncInferRequest();
protected:
MultiDeviceExecutableNetwork::Ptr _multiDeviceExecutableNetwork;

View File

@ -174,7 +174,7 @@ RemoteContext::Ptr MultiDeviceExecutableNetwork::GetContext() const {
<< " Current list of devices allowed via the DEVICE_PRIORITIES config: " << devices_names;
}
InferenceEngine::InferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::IInferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) {
auto num = _numRequestsCreated++;
size_t sum = 0;
@ -192,17 +192,13 @@ InferenceEngine::InferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateI
return std::make_shared<MultiDeviceInferRequest>(networkInputs, networkOutputs, request_to_share_blobs_with);
}
IInferRequest::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
IInferRequest::Ptr asyncRequest;
IInferRequestInternal::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
auto syncRequestImpl = CreateInferRequestImpl(_networkInputs, _networkOutputs);
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
auto asyncTreadSafeImpl = std::make_shared<MultiDeviceAsyncInferRequest>(std::static_pointer_cast<MultiDeviceInferRequest>(syncRequestImpl),
_needPerfCounters,
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
_callbackExecutor);
asyncRequest.reset(new InferRequestBase(asyncTreadSafeImpl));
asyncTreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
return std::make_shared<MultiDeviceAsyncInferRequest>(std::static_pointer_cast<MultiDeviceInferRequest>(syncRequestImpl),
_needPerfCounters,
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
_callbackExecutor);
}
void MultiDeviceExecutableNetwork::SetConfig(const std::map<std::string, InferenceEngine::Parameter> &config) {

View File

@ -114,9 +114,9 @@ public:
InferenceEngine::Parameter GetConfig(const std::string &name) const override;
InferenceEngine::Parameter GetMetric(const std::string &name) const override;
void run(InferenceEngine::Task inferTask) override;
InferenceEngine::IInferRequest::Ptr CreateInferRequest() override;
InferenceEngine::InferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequest() override;
InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs) override;
InferenceEngine::RemoteContext::Ptr GetContext() const override;
~MultiDeviceExecutableNetwork() override;

View File

@ -5,6 +5,10 @@
///////////////////////////////////////////////////////////////////////////////////////////////////
#include "multi_device_infer_request.hpp"
#include <ie_input_info.hpp>
#include <ie_icnn_network.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include <blob_factory.hpp>
namespace MultiDevicePlugin {
using namespace InferenceEngine;
@ -12,7 +16,7 @@ namespace MultiDevicePlugin {
MultiDeviceInferRequest::MultiDeviceInferRequest(const InputsDataMap& networkInputs,
const OutputsDataMap& networkOutputs,
InferRequest request_to_share_blobs_with)
: InferRequestInternal(networkInputs, networkOutputs) {
: IInferRequestInternal(networkInputs, networkOutputs) {
if (request_to_share_blobs_with) {
// borrow device-friendly blobs from the request
for (const auto &it : _networkInputs)

View File

@ -14,12 +14,13 @@
#include <utility>
#include <memory>
#include <string>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include <cpp/ie_executable_network.hpp>
#include <cpp/ie_infer_request.hpp>
namespace MultiDevicePlugin {
class MultiDeviceInferRequest : public InferenceEngine::InferRequestInternal {
class MultiDeviceInferRequest : public InferenceEngine::IInferRequestInternal {
public:
using Ptr = std::shared_ptr<MultiDeviceInferRequest>;
explicit MultiDeviceInferRequest(const InferenceEngine::InputsDataMap& networkInputs,

View File

@ -15,6 +15,7 @@
#include <multi-device/multi_device_config.hpp>
#include <threading/ie_executor_manager.hpp>
#include "multi_device_plugin.hpp"
#include <ie_algorithm.hpp>
// ------------------------------MultiDeviceInferencePlugin----------------------------
namespace MultiDevicePlugin {

View File

@ -20,6 +20,7 @@
#include <vector>
#include "cpp_interfaces/exception2status.hpp"
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
namespace InferenceEngine {
@ -53,7 +54,7 @@ public:
}
StatusCode CreateInferRequest(IInferRequest::Ptr& req, ResponseDesc* resp) noexcept override {
TO_STATUS(req = _impl->CreateInferRequest());
TO_STATUS(req = std::make_shared<InferRequestBase>(_impl->CreateInferRequest()));
}
StatusCode Export(const std::string& modelFileName, ResponseDesc* resp) noexcept override {

View File

@ -11,25 +11,92 @@
#include "cpp_interfaces/exception2status.hpp"
#include "cpp_interfaces/plugin_itt.hpp"
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include "ie_iinfer_request.hpp"
#include "ie_preprocess.hpp"
namespace InferenceEngine {
#define CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(StatusCode, ExceptionType) catch (const ExceptionType& ex) { \
return InferenceEngine::DescriptionBuffer(StatusCode) << ex.what(); \
}
#define CATCH_IE_EXCEPTIONS_TO_STATUS_NO_RESP \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(GENERAL_ERROR, GeneralError) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NOT_IMPLEMENTED, NotImplemented) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NETWORK_NOT_LOADED, NetworkNotLoaded) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(PARAMETER_MISMATCH, ParameterMismatch) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NOT_FOUND, NotFound) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(OUT_OF_BOUNDS, OutOfBounds) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(UNEXPECTED, Unexpected) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(REQUEST_BUSY, RequestBusy) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(RESULT_NOT_READY, ResultNotReady) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NOT_ALLOCATED, NotAllocated) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(INFER_NOT_STARTED, InferNotStarted) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(NETWORK_NOT_READ, NetworkNotRead) \
CATCH_IE_EXCEPTION_TO_STATUS_NO_RESP(INFER_CANCELLED, InferCancelled)
/**
* @brief Inference request `noexcept` wrapper which accepts IAsyncInferRequestInternal derived instance which can throw exceptions
* @ingroup ie_dev_api_async_infer_request_api
* @def TO_STATUS_NO_RESP(x)
* @brief Converts C++ exceptioned function call into a status code. Does not work with a ResponseDesc object
* @ingroup ie_dev_api_error_debug
*/
#define TO_STATUS_NO_RESP(x) \
try { \
x; \
return OK; \
} CATCH_IE_EXCEPTIONS_TO_STATUS_NO_RESP catch (const std::exception& ex) { \
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR) << ex.what(); \
} catch (...) { \
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
}
#define CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(StatusCode, ExceptionType) catch (const ExceptionType& ex) { \
return InferenceEngine::DescriptionBuffer(StatusCode, resp) << ex.what(); \
}
#define CATCH_IE_EXCEPTIONS_CALL_RETURN_STATUS \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(GENERAL_ERROR, GeneralError) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NOT_IMPLEMENTED, NotImplemented) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NETWORK_NOT_LOADED, NetworkNotLoaded) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(PARAMETER_MISMATCH, ParameterMismatch) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NOT_FOUND, NotFound) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(OUT_OF_BOUNDS, OutOfBounds) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(UNEXPECTED, Unexpected) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(REQUEST_BUSY, RequestBusy) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(RESULT_NOT_READY, ResultNotReady) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NOT_ALLOCATED, NotAllocated) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(INFER_NOT_STARTED, InferNotStarted) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(NETWORK_NOT_READ, NetworkNotRead) \
CATCH_IE_EXCEPTION_CALL_RETURN_STATUS(INFER_CANCELLED, InferCancelled)
/**
* @def NO_EXCEPT_CALL_RETURN_STATUS(x)
 * @brief Returns a status code of a called function, handles exceptions and converts them to a status code.
* @ingroup ie_dev_api_error_debug
*/
#define NO_EXCEPT_CALL_RETURN_STATUS(x) \
try { \
return x; \
} CATCH_IE_EXCEPTIONS_CALL_RETURN_STATUS catch (const std::exception& ex) { \
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); \
} catch (...) { \
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
}
/**
* @brief Inference request `noexcept` wrapper which accepts IInferRequestInternal derived instance which can throw exceptions
* @ingroup ie_dev_api_infer_request_api
*/
class InferRequestBase : public IInferRequest {
std::shared_ptr<IAsyncInferRequestInternal> _impl;
std::shared_ptr<IInferRequestInternal> _impl;
public:
/**
* @brief Constructor with actual underlying implementation.
* @param impl Underlying implementation of type IAsyncInferRequestInternal
* @param impl Underlying implementation of type IInferRequestInternal
*/
explicit InferRequestBase(std::shared_ptr<IAsyncInferRequestInternal> impl): _impl(impl) {}
explicit InferRequestBase(std::shared_ptr<IInferRequestInternal> impl): _impl(impl) {}
StatusCode Infer(ResponseDesc* resp) noexcept override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Infer");
@ -73,15 +140,30 @@ public:
}
StatusCode SetCompletionCallback(CompletionCallback callback) noexcept override {
TO_STATUS_NO_RESP(_impl->SetCompletionCallback(callback));
TO_STATUS_NO_RESP(_impl->SetCallback([callback, this] (std::exception_ptr exceptionPtr) {
StatusCode statusCode = [&] ()-> StatusCode {
if (exceptionPtr) {
TO_STATUS_NO_RESP(std::rethrow_exception(exceptionPtr));
} else {
return OK;
}
} ();
callback(std::shared_ptr<InferRequestBase>{this, [](InferRequestBase*){}}, statusCode);
}));
}
StatusCode GetUserData(void** data, ResponseDesc* resp) noexcept override {
TO_STATUS(_impl->GetUserData(data));
StatusCode GetUserData(void** data, ResponseDesc*) noexcept override {
if (data != nullptr) {
*data = _data;
return OK;
} else {
return GENERAL_ERROR;
}
}
StatusCode SetUserData(void* data, ResponseDesc* resp) noexcept override {
TO_STATUS(_impl->SetUserData(data));
StatusCode SetUserData(void* data, ResponseDesc*) noexcept override {
_data = data;
return OK;
}
StatusCode SetBatch(int batch_size, ResponseDesc* resp) noexcept override {
@ -104,6 +186,8 @@ public:
}
}
IE_SUPPRESS_DEPRECATED_END
void* _data = nullptr;
};
} // namespace InferenceEngine
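// A sketch (not part of the header above) of how these pieces fit together after the merge:
// the plugin hands out an IInferRequestInternal, the new exception_ptr-based callback is
// attached directly to it, and InferRequestBase merely adapts it to the legacy C-style
// IInferRequest API. `execNetworkImpl` and `MakeLegacyRequest` are illustrative names only.
inline InferenceEngine::IInferRequest::Ptr MakeLegacyRequest(
        const std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>& execNetworkImpl) {
    auto internalRequest = execNetworkImpl->CreateInferRequest();  // IInferRequestInternal::Ptr
    internalRequest->SetCallback([](std::exception_ptr exceptionPtr) {
        if (!exceptionPtr)
            return;  // inference finished successfully
        try {
            std::rethrow_exception(exceptionPtr);  // recover the concrete exception type
        } catch (const InferenceEngine::InferCancelled&) {
            // the request was cancelled
        } catch (const std::exception&) {
            // any other failure reported by the pipeline
        }
    });
    return std::make_shared<InferenceEngine::InferRequestBase>(internalRequest);
}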

View File

@ -13,8 +13,24 @@
#include "description_buffer.hpp"
namespace InferenceEngine {
#define CATCH_IE_EXCEPTION_TO_STATUS(StatusCode, ExceptionType) catch (const ExceptionType& ex) { \
return InferenceEngine::DescriptionBuffer(StatusCode, resp) << ex.what(); \
}
INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exception);
#define CATCH_IE_EXCEPTIONS_TO_STATUS \
CATCH_IE_EXCEPTION_TO_STATUS(GENERAL_ERROR, GeneralError) \
CATCH_IE_EXCEPTION_TO_STATUS(NOT_IMPLEMENTED, NotImplemented) \
CATCH_IE_EXCEPTION_TO_STATUS(NETWORK_NOT_LOADED, NetworkNotLoaded) \
CATCH_IE_EXCEPTION_TO_STATUS(PARAMETER_MISMATCH, ParameterMismatch) \
CATCH_IE_EXCEPTION_TO_STATUS(NOT_FOUND, NotFound) \
CATCH_IE_EXCEPTION_TO_STATUS(OUT_OF_BOUNDS, OutOfBounds) \
CATCH_IE_EXCEPTION_TO_STATUS(UNEXPECTED, Unexpected) \
CATCH_IE_EXCEPTION_TO_STATUS(REQUEST_BUSY, RequestBusy) \
CATCH_IE_EXCEPTION_TO_STATUS(RESULT_NOT_READY, ResultNotReady) \
CATCH_IE_EXCEPTION_TO_STATUS(NOT_ALLOCATED, NotAllocated) \
CATCH_IE_EXCEPTION_TO_STATUS(INFER_NOT_STARTED, InferNotStarted) \
CATCH_IE_EXCEPTION_TO_STATUS(NETWORK_NOT_READ, NetworkNotRead) \
CATCH_IE_EXCEPTION_TO_STATUS(INFER_CANCELLED, InferCancelled)
/**
* @def TO_STATUS(x)
@ -25,42 +41,7 @@ INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exceptio
try { \
x; \
return OK; \
} catch (const ::InferenceEngine::Exception& iex) { \
return InferenceEngine::DescriptionBuffer(InferenceEngine::ExceptionToStatus(iex), resp) << iex.what(); \
} catch (const std::exception& ex) { \
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); \
} catch (...) { \
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
}
/**
* @def TO_STATUS_NO_RESP(x)
* @brief Converts C++ exceptioned function call into a status code. Does not work with a ResponseDesc object
* @ingroup ie_dev_api_error_debug
*/
#define TO_STATUS_NO_RESP(x) \
try { \
x; \
return OK; \
} catch (const ::InferenceEngine::Exception& iex) { \
return InferenceEngine::DescriptionBuffer(InferenceEngine::ExceptionToStatus(iex)) << iex.what(); \
} catch (const std::exception& ex) { \
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR) << ex.what(); \
} catch (...) { \
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
}
/**
* @def NO_EXCEPT_CALL_RETURN_STATUS(x)
 * @brief Returns a status code of a called function, handles exceptions and converts them to a status code.
* @ingroup ie_dev_api_error_debug
*/
#define NO_EXCEPT_CALL_RETURN_STATUS(x) \
try { \
return x; \
} catch (const ::InferenceEngine::Exception& iex) { \
return InferenceEngine::DescriptionBuffer(InferenceEngine::ExceptionToStatus(iex), resp) << iex.what(); \
} catch (const std::exception& ex) { \
} CATCH_IE_EXCEPTIONS_TO_STATUS catch (const std::exception& ex) { \
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); \
} catch (...) { \
return InferenceEngine::DescriptionBuffer(UNEXPECTED); \
@ -82,5 +63,4 @@ INFERENCE_ENGINE_API_CPP(StatusCode) ExceptionToStatus(const Exception& exceptio
CATCH_IE_EXCEPTION(InferNotStarted) \
CATCH_IE_EXCEPTION(NetworkNotRead) \
CATCH_IE_EXCEPTION(InferCancelled)
} // namespace InferenceEngine
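// Sketch: TO_STATUS turns any throwing call into a StatusCode inside a noexcept function,
// provided a ResponseDesc* named `resp` is in scope. `SafeInfer` is a hypothetical helper
// and assumes the IInferRequestInternal header is included by the caller.
namespace InferenceEngine {
inline StatusCode SafeInfer(const std::shared_ptr<IInferRequestInternal>& impl, ResponseDesc* resp) noexcept {
    TO_STATUS(impl->Infer());
}
}  // namespace InferenceEngine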

View File

@ -11,8 +11,6 @@
#include <vector>
#include <fstream>
#include "cpp_interfaces/impl/ie_infer_async_request_internal.hpp"
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
#include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
@ -118,7 +116,29 @@ public:
IE_THROW(NotImplemented);
}
/**
* @brief Creates an inference request public implementation.
* @return The request public implementation
*/
IInferRequestInternal::Ptr CreateInferRequest() override {
auto asyncRequestImpl = this->CreateInferRequestImpl(_networkInputs, _networkOutputs);
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
return asyncRequestImpl;
}
protected:
/**
* @brief Creates an asynchronous inference request internal implementation.
* @note The method is called by ExecutableNetworkInternal::CreateInferRequest as
* plugin-specific implementation.
* @param[in] networkInputs The network inputs
* @param[in] networkOutputs The network outputs
 * @return A shared pointer to an asynchronous inference request object.
*/
virtual IInferRequestInternal::Ptr CreateInferRequestImpl(InputsDataMap networkInputs,
OutputsDataMap networkOutputs) = 0;
/**
* @brief Exports an internal hardware-dependent model to a stream.
* @note The function is called from ExecutableNetworkInternal::Export(std::ostream&),
@ -130,7 +150,7 @@ protected:
IE_THROW(NotImplemented);
}
InferenceEngine::InputsDataMap _networkInputs; //!< Holds infromation about network inputs info
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data
/**

View File

@ -1,59 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ie_iinfer_request.hpp>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
#include "cpp_interfaces/impl/ie_executable_network_internal.hpp"
#include "cpp_interfaces/impl/ie_infer_async_request_internal.hpp"
#include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
namespace InferenceEngine {
/**
* @brief This class describes an executable network thread safe asynchronous only implementation.
* @ingroup ie_dev_api_exec_network_api
*/
class ExecutableNetworkThreadSafeAsyncOnly : public ExecutableNetworkInternal {
public:
/**
* @brief A shared pointer to a ExecutableNetworkThreadSafeAsyncOnly object
*/
typedef std::shared_ptr<ExecutableNetworkThreadSafeAsyncOnly> Ptr;
/**
* @brief Creates an asynchronous inference request public implementation.
* @return The asynchronous request public implementation
*/
IInferRequest::Ptr CreateInferRequest() override {
IInferRequest::Ptr asyncRequest;
auto asyncRequestImpl = this->CreateAsyncInferRequestImpl(_networkInputs, _networkOutputs);
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
asyncRequest.reset(new InferRequestBase(asyncRequestImpl));
asyncRequestImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
}
protected:
/**
* @brief Creates an asynchronous inference request internal implementation.
* @note The method is called by ExecutableNetworkThreadSafeAsyncOnly::CreateInferRequest as
* plugin-specific implementation.
* @param[in] networkInputs The network inputs
* @param[in] networkOutputs The network outputs
 * @return A shared pointer to an asynchronous inference request object.
*/
virtual AsyncInferRequestInternal::Ptr CreateAsyncInferRequestImpl(InputsDataMap networkInputs,
OutputsDataMap networkOutputs) = 0;
};
} // namespace InferenceEngine

View File

@ -12,7 +12,6 @@
#include "cpp_interfaces/base/ie_infer_async_request_base.hpp"
#include "cpp_interfaces/impl/ie_executable_network_internal.hpp"
#include "cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp"
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
#include "threading/ie_cpu_streams_executor.hpp"
namespace InferenceEngine {
@ -49,7 +48,7 @@ public:
* need for it to be implemented by plugin
* @return shared_ptr for the created asynchronous inference request
*/
IInferRequest::Ptr CreateInferRequest() override {
IInferRequestInternal::Ptr CreateInferRequest() override {
return CreateAsyncInferRequestFromSync();
}
@ -60,28 +59,12 @@ protected:
* @return A shared pointer to an asynchronous inference request
*/
template <typename AsyncInferRequestType = AsyncInferRequestThreadSafeDefault>
IInferRequest::Ptr CreateAsyncInferRequestFromSync() {
IInferRequestInternal::Ptr CreateAsyncInferRequestFromSync() {
auto syncRequestImpl = this->CreateInferRequestImpl(_networkInputs, _networkOutputs);
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
auto asyncThreadSafeImpl = std::make_shared<AsyncInferRequestType>(
syncRequestImpl, _taskExecutor, _callbackExecutor);
IInferRequest::Ptr asyncRequest = std::make_shared<InferRequestBase>(asyncThreadSafeImpl);
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
return std::make_shared<AsyncInferRequestType>(syncRequestImpl, _taskExecutor, _callbackExecutor);
}
/**
* @brief Creates a synchronous inference request object used to infer the network
* @note Used by ExecutableNetworkThreadSafeDefault::CreateInferRequest as a plugin-specific implementation
* @param networkInputs An input info map needed to create input blobs
* @param networkOutputs An output data map needed to create output blobs
* @return Synchronous inference request object
*/
virtual InferRequestInternal::Ptr CreateInferRequestImpl(InputsDataMap networkInputs,
OutputsDataMap networkOutputs) = 0;
ITaskExecutor::Ptr _taskExecutor = nullptr; //!< Holds a task executor
ITaskExecutor::Ptr _callbackExecutor = nullptr; //!< Holds a callback executor
};
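// Illustration only (all names below are hypothetical): with the simplified base class above,
// a plugin supplies a synchronous request type plus one factory override; the asynchronous
// wrapper and its exception_ptr callback plumbing come from CreateAsyncInferRequestFromSync().
class MyTrivialInferRequest : public InferenceEngine::IInferRequestInternal {
public:
    using InferenceEngine::IInferRequestInternal::IInferRequestInternal;
    void InferImpl() override {
        // device-specific execution would go here
    }
};

class MyExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafeDefault {
protected:
    InferenceEngine::IInferRequestInternal::Ptr CreateInferRequestImpl(
            InferenceEngine::InputsDataMap networkInputs,
            InferenceEngine::OutputsDataMap networkOutputs) override {
        return std::make_shared<MyTrivialInferRequest>(networkInputs, networkOutputs);
    }
};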

View File

@ -1,82 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include <memory>
#include <string>
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
namespace InferenceEngine {
#if defined(_MSC_VER)
#pragma warning(disable : 4250)
#endif
/**
* @brief minimum API to be implemented by plugin, which is used in InferRequestBase forwarding mechanism
* @ingroup ie_dev_api_async_infer_request_api
*/
class AsyncInferRequestInternal : public IAsyncInferRequestInternal, public InferRequestInternal {
public:
/**
* @brief A shared pointer to a AsyncInferRequestInternal implementation
*/
typedef std::shared_ptr<AsyncInferRequestInternal> Ptr;
/**
* @brief Constructs a new instance.
* @param[in] networkInputs The network inputs info
* @param[in] networkOutputs The network outputs data
*/
AsyncInferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs)
: InferRequestInternal(networkInputs, networkOutputs), _callback(nullptr), _userData(nullptr) {}
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
_callback = callback;
}
void GetUserData(void** data) override {
if (data == nullptr) IE_THROW(NotAllocated);
*data = _userData;
}
void SetUserData(void* data) override {
_userData = data;
}
/**
 * @brief Set weak pointer to the corresponding public interface: IInferRequest. This allows passing it to
* IInferRequest::CompletionCallback
* @param ptr A weak pointer to InferRequestBase
*/
void SetPointerToPublicInterface(IInferRequest::Ptr ptr) {
_publicInterface = ptr;
}
void StartAsync() override {
checkBlobs();
StartAsyncImpl();
};
protected:
/**
* @brief The minimal asynchronous inference function to be implemented by plugins.
* It starts inference of specified input(s) in asynchronous mode
* @note
 * * The method is used in AsyncInferRequestInternal::StartAsync which performs common steps first and
* calls plugin dependent implementation of this method after.
* * It returns immediately. Inference starts also immediately.
*/
virtual void StartAsyncImpl() = 0;
IInferRequest::WeakPtr _publicInterface; //!< A weak pointer to a IInferRequest interface for callback calling
InferenceEngine::IInferRequest::CompletionCallback _callback; //!< A callback
void* _userData; //!< A callback user data
};
} // namespace InferenceEngine

View File

@ -8,10 +8,10 @@
#include <threading/ie_itask_executor.hpp>
#include <threading/ie_istreams_executor.hpp>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include <cpp_interfaces/exception2status.hpp>
#include <ie_system_conf.h>
#include <ie_iinfer_request.hpp>
#include <exception>
#include <future>
@ -40,12 +40,12 @@ namespace InferenceEngine {
*
* @snippet example_async_infer_request.cpp async_infer_request:define_pipeline
*/
class AsyncInferRequestThreadSafeDefault : public IAsyncInferRequestInternal {
class AsyncInferRequestThreadSafeDefault : public IInferRequestInternal {
enum InferState {Idle, Busy, Canceled, Stop};
using Futures = std::vector<std::shared_future<void>>;
using Promise = std::shared_ptr<std::promise<void>>;
enum Stage_e : std::uint8_t { executor, task };
InferRequestInternal::Ptr _syncRequest;
IInferRequestInternal::Ptr _syncRequest;
friend struct DisableCallbackGuard;
struct DisableCallbackGuard {
@ -59,7 +59,7 @@ class AsyncInferRequestThreadSafeDefault : public IAsyncInferRequestInternal {
_this->_callback = _callback;
}
AsyncInferRequestThreadSafeDefault* _this = nullptr;
IInferRequest::CompletionCallback _callback = nullptr;
Callback _callback;
};
struct ImmediateStreamsExecutor : public InferenceEngine::ITaskExecutor {
@ -132,15 +132,15 @@ public:
using Ptr = std::shared_ptr<AsyncInferRequestThreadSafeDefault>;
/**
* @brief Wraps a InferRequestInternal::Ptr implementation and constructs a
* AsyncInferRequestThreadSafeDefault::_pipeline where `taskExecutor` is used to run InferRequestInternal::Infer
* @brief Wraps a IInferRequestInternal::Ptr implementation and constructs a
* AsyncInferRequestThreadSafeDefault::_pipeline where `taskExecutor` is used to run IInferRequestInternal::Infer
* asynchronously.
*
* @param[in] request The synchronous request
* @param[in] taskExecutor The task executor
* @param[in] callbackExecutor The callback executor
*/
AsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr& request,
AsyncInferRequestThreadSafeDefault(const IInferRequestInternal::Ptr& request,
const ITaskExecutor::Ptr& taskExecutor,
const ITaskExecutor::Ptr& callbackExecutor) :
_syncRequest {request},
@ -245,32 +245,13 @@ public:
_syncRequest->SetBatch(batch);
};
void GetUserData(void** data) override {
void SetCallback(Callback callback) override {
CheckState();
if (data == nullptr) IE_THROW(NotAllocated);
*data = _userData;
_callback = std::move(callback);
}
void SetUserData(void* data) override {
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override {
CheckState();
_userData = data;
}
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
CheckState();
_callback = callback;
}
/**
* @brief Sets the pointer to public interface.
* @note Needed to correctly handle ownership between objects
* @param[in] ptr A shared pointer to a public IInferRequest interface.
*/
void SetPointerToPublicInterface(InferenceEngine::IInferRequest::Ptr ptr) {
_publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](IInferRequest*) {});
}
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
return _syncRequest->QueryState();
}
@ -319,13 +300,13 @@ protected:
* pipeline tasks
*/
void StopAndWait() {
_callback = nullptr;
Futures futures;
InferState state = InferState::Idle;
{
std::lock_guard<std::mutex> lock{_mutex};
state = _state;
if (state != InferState::Stop) {
_callback = {};
_state = InferState::Stop;
futures = std::move(_futures);
}
@ -385,51 +366,44 @@ private:
Task MakeNextStageTask(const Pipeline::iterator itStage, const Pipeline::iterator itEndStage,
const ITaskExecutor::Ptr callbackExecutor) {
return std::bind([this, itStage, itEndStage](ITaskExecutor::Ptr& callbackExecutor) mutable {
StatusCode requestStatus = StatusCode::OK;
std::exception_ptr localCurrentException = nullptr;
std::exception_ptr currentException = nullptr;
auto& thisStage = *itStage;
auto itNextStage = itStage + 1;
try {
auto& stageTask = std::get<Stage_e::task>(thisStage);
IE_ASSERT(nullptr != stageTask);
stageTask();
if (itEndStage != itNextStage) {
if (itEndStage != itNextStage) {
auto& nextStage = *itNextStage;
auto& nextStageExecutor = std::get<Stage_e::executor>(nextStage);
IE_ASSERT(nullptr != nextStageExecutor);
nextStageExecutor->run(MakeNextStageTask(itNextStage, itEndStage, std::move(callbackExecutor)));
}
} catch (InferenceEngine::Exception& ie_ex) {
requestStatus = ExceptionToStatus(ie_ex);
localCurrentException = std::current_exception();
} catch (...) {
requestStatus = StatusCode::GENERAL_ERROR;
localCurrentException = std::current_exception();
currentException = std::current_exception();
}
if ((itEndStage == itNextStage) || (nullptr != localCurrentException)) {
auto lastStageTask = [this, requestStatus, localCurrentException]() mutable {
if ((itEndStage == itNextStage) || (nullptr != currentException)) {
auto lastStageTask = [this, currentException]() mutable {
auto promise = std::move(_promise);
IInferRequest::CompletionCallback callback = nullptr;
Callback callback;
{
std::lock_guard<std::mutex> lock{_mutex};
_state = InferState::Idle;
callback = _callback;
}
if (nullptr != callback) {
InferenceEngine::CurrentException() = localCurrentException;
if (callback) {
try {
callback(_publicInterface, requestStatus);
auto local_callback = std::move(callback);
local_callback(currentException);
} catch (...) {
localCurrentException = std::current_exception();
currentException = std::current_exception();
}
InferenceEngine::CurrentException() = nullptr;
}
if (nullptr == localCurrentException) {
if (nullptr == currentException) {
promise.set_value();
} else {
promise.set_exception(localCurrentException);
promise.set_exception(currentException);
}
};
@ -442,9 +416,6 @@ private:
}, std::move(callbackExecutor));
}
void* _userData = nullptr;
IInferRequest::CompletionCallback _callback = nullptr;
IInferRequest::Ptr _publicInterface;
std::promise<void> _promise;
mutable std::mutex _mutex;
Futures _futures;
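// A sketch of the simplified construction path: wrapping an existing synchronous request into
// the default asynchronous one and driving it through the merged API. `MakeDefaultAsyncRequest`
// is illustrative only; the executors are passed in by the caller to keep assumptions minimal.
inline InferenceEngine::IInferRequestInternal::Ptr MakeDefaultAsyncRequest(
        const InferenceEngine::IInferRequestInternal::Ptr& syncRequestImpl,
        const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
        const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor) {
    auto asyncRequest = std::make_shared<InferenceEngine::AsyncInferRequestThreadSafeDefault>(
        syncRequestImpl, taskExecutor, callbackExecutor);
    asyncRequest->SetCallback([](std::exception_ptr exceptionPtr) {
        // exceptionPtr is null on success; otherwise it carries the failure thrown by a pipeline stage
    });
    asyncRequest->StartAsync();
    asyncRequest->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
    return asyncRequest;
}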

View File

@ -1,422 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ie_icnn_network.hpp>
#include <ie_input_info.hpp>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "cpp_interfaces/exception2status.hpp"
#include "cpp_interfaces/plugin_itt.hpp"
#include "cpp_interfaces/interface/ie_iinfer_request_internal.hpp"
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
#include "debug.h"
#include "ie_compound_blob.h"
#include "ie_memcpy.h"
#include "ie_preprocess_data.hpp"
namespace InferenceEngine {
class IExecutableNetworkInternal;
/**
* @brief An optimal implementation of IInferRequestInternal interface to avoid duplication in all plugins
* This base class is recommended to be used as a base class for plugin synchronous inference request implementation.
* @ingroup ie_dev_api_infer_request_api
*/
class InferRequestInternal : virtual public IInferRequestInternal {
public:
/**
* @brief A shared pointer to a InferRequestInternal implementation.
*/
typedef std::shared_ptr<InferRequestInternal> Ptr;
/**
* @brief Constructs a new instance.
* @param[in] networkInputs The network inputs info
* @param[in] networkOutputs The network outputs data
*/
InferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs): m_curBatch(-1) {
// // We should copy maps since they can be overridden in SetBlob with preprocess
copyInputOutputInfo(networkInputs, networkOutputs, _networkInputs, _networkOutputs);
}
/**
* @brief The minimal infer function to be implemented by plugins. It infers specified input(s) in synchronous mode
* @note
* * This method is used in InferRequestInternal::Infer, which calls the common code first and after uses this
* plugin dependent implementation.
* * Blocks all method of IInferRequest while request is ongoing (running or waiting in queue)
*/
virtual void InferImpl() = 0;
/**
* @brief Default common implementation for all plugins with checking input and output blobs before inference
*/
void Infer() override {
checkBlobs();
InferImpl();
}
/**
* @brief Default common implementation for all plugins
*/
void Cancel() override {
IE_THROW(NotImplemented);
}
/**
* @brief Given optional implementation of setting blob to avoid need for it to be implemented by plugin
* @param name - a name of input or output blob.
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input
* precision and size.
*/
void SetBlob(const std::string& name, const Blob::Ptr& userBlob) override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "SetBlob");
if (name.empty()) {
IE_THROW(NotFound) << "Failed to set blob with empty name";
}
if (!userBlob) IE_THROW(NotAllocated) << "Failed to set empty blob with name: \'" << name << "\'";
const bool compoundBlobPassed = userBlob->is<CompoundBlob>();
const bool remoteBlobPassed = userBlob->is<RemoteBlob>();
if (!compoundBlobPassed && !remoteBlobPassed && userBlob->buffer() == nullptr)
IE_THROW(NotAllocated) << "Input data was not allocated. Input name: \'" << name << "\'";
if (userBlob->size() == 0) {
IE_THROW() << "Input data is empty. Input name: \'" << name << "\'";
}
InputInfo::Ptr foundInput;
DataPtr foundOutput;
size_t dataSize = userBlob->size();
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
if (foundInput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
IE_THROW(ParameterMismatch)
<< "Failed to set Blob with precision not corresponding to user input precision";
}
auto& devBlob = _deviceInputs[name];
const bool preProcRequired = preProcessingRequired(foundInput, userBlob, devBlob);
if (compoundBlobPassed && !preProcRequired) {
IE_THROW(NotImplemented)
<< "cannot set compound blob: supported only for input pre-processing";
}
if (preProcRequired) {
addInputPreProcessingFor(name, userBlob, devBlob ? devBlob : _inputs[name]);
} else {
size_t inputSize = foundInput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
? InferenceEngine::details::product(foundInput->getTensorDesc().getDims())
: 1;
if (dataSize != inputSize) {
IE_THROW() << "Input blob size is not equal network input size (" << dataSize
<< "!=" << inputSize << ").";
}
_inputs[name] = userBlob;
devBlob = userBlob;
}
} else {
if (compoundBlobPassed) {
IE_THROW(NotImplemented)
<< "cannot set compound blob: supported only for input pre-processing";
}
size_t outputSize = foundOutput->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR
? details::product(foundOutput->getTensorDesc().getDims()) :
1;
if (dataSize != outputSize) {
IE_THROW() << "Output blob size is not equal network output size (" << dataSize
<< "!=" << outputSize << ").";
}
if (foundOutput->getPrecision() != userBlob->getTensorDesc().getPrecision()) {
IE_THROW(ParameterMismatch)
<< "Failed to set Blob with precision not corresponding to user output precision";
}
_outputs[name] = userBlob;
}
}
/**
* @brief Given optional implementation of getting blob to avoid need for it to be implemented by plugin
* @param name - a name of input or output blob.
* @return Returns input or output blob. The type of Blob must correspond to the network input
* precision and size.
* @note if ROI blob was previously set it is returned (without dimensions checks) instead of default blob.
*/
Blob::Ptr GetBlob(const std::string& name) override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "GetBlob");
Blob::Ptr data;
InputInfo::Ptr foundInput;
DataPtr foundOutput;
const SizeVector oneVector = { 1 };
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
// ROI blob is returned only if it was set previously. Otherwise default blob is returned.
auto it = _preProcData.find(name);
if (it != _preProcData.end()) {
data = it->second->getRoiBlob();
} else {
data = _inputs[name];
checkBlob(data, name, true,
foundInput->getTensorDesc().getLayout() != SCALAR
? foundInput->getTensorDesc().getDims()
: oneVector);
auto& devBlob = _deviceInputs[name];
if (preProcessingRequired(foundInput, data, devBlob)) {
// if no devBlob, performs inplace
addInputPreProcessingFor(name, data, devBlob ? devBlob : _inputs[name]);
}
}
} else {
data = _outputs[name];
checkBlob(data, name, false,
foundOutput->getTensorDesc().getLayout() != SCALAR
? foundOutput->getTensorDesc().getDims()
: oneVector);
}
return data;
}
/**
* @brief Sets pre-process for input data
* @param name Name of input blob.
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input precision and size.
* @param info Preprocess info for blob.
*/
void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) override {
InputInfo::Ptr foundInput;
DataPtr foundOutput;
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
copyPreProcess(info, foundInput->getPreProcess());
} else {
IE_THROW() << "Pre-process can't be set to output blob";
}
SetBlob(name, data);
}
/**
* @brief Gets pre-process for input data
* @param name Name of input blob.
* @return Returns constant reference to PreProcessInfo structure
*/
const PreProcessInfo& GetPreProcess(const std::string& name) const override {
InputInfo::Ptr foundInput;
DataPtr foundOutput;
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
return foundInput->getPreProcess();
} else {
IE_THROW() << "Output blob can't have pre-processing";
}
}
void SetBatch(int batch) override {
(void)batch;
IE_THROW() << "Dynamic batch is not supported";
};
/**
* @brief Sets the pointer to executable network internal.
* @note Needed to correctly handle ownership between objects.
* @param[in] exeNetwork The executable network
*/
void setPointerToExecutableNetworkInternal(std::shared_ptr<IExecutableNetworkInternal> exeNetwork) {
_exeNetwork = exeNetwork;
}
/**
 * @brief Checks that both input and output blobs are valid. Throws an exception if they are not.
*/
virtual void checkBlobs() {
for (auto const& input : _inputs) {
checkBlob(input.second, input.first, true);
}
for (auto const& output : _outputs) {
checkBlob(output.second, output.first, false);
}
}
std::vector<IVariableStateInternal::Ptr> QueryState() override {
// meaning base plugin reports as no state available - plugin owners need to create proper override of this
IE_THROW() << "Plugin doesn't override QueryState";
return {};
}
protected:
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data
InferenceEngine::BlobMap _inputs; //!< A map of user passed blobs for network inputs
InferenceEngine::BlobMap _deviceInputs; //!< A map of actual network inputs, in plugin specific format
InferenceEngine::BlobMap _outputs; //!< A map of user passed blobs for network outputs
std::map<std::string, PreProcessDataPtr> _preProcData; //!< A map of pre-process data per input
int m_curBatch; //!< Current batch value used in dynamic batching
/**
* @brief A shared pointer to ExecutableNetworkInternal interface
* @note Needed to correctly handle ownership between objects.
*/
std::shared_ptr<IExecutableNetworkInternal> _exeNetwork;
/**
* @brief Checks and executes input data pre-processing if needed.
* @param inputs Inputs blobs to perform preprocessing on
* @param serial Whether to use multiple threads to execute the step
*/
void execDataPreprocessing(InferenceEngine::BlobMap& preprocessedBlobs, bool serial = false) {
for (auto& input : preprocessedBlobs) {
// If there is a pre-process entry for an input then it must be pre-processed
// using preconfigured resize algorithm.
auto it = _preProcData.find(input.first);
if (it != _preProcData.end()) {
_preProcData[input.first]->execute(input.second, _networkInputs[input.first]->getPreProcess(), serial,
m_curBatch);
}
}
}
/**
* @brief Helper function to find input or output blob by name
* @param name A name of input or output blob.
* @param foundInput A pointer to input information if found.
* @param foundOutput A pointer to output DataPtr if found.
* @return `True` - if loaded network has input with provided name,
* `false` - if loaded network has output with provided name
* @throws [parameter_mismatch] exception if input and output has the same name
* @throws [not_found] exception if there is no input and output layers with given name
*/
bool findInputAndOutputBlobByName(const std::string& name, InputInfo::Ptr& foundInput, DataPtr& foundOutput) const {
foundInput = nullptr;
foundOutput = nullptr;
if (_networkOutputs.empty()) {
IE_THROW() << "Internal error: network outputs is not set";
}
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
return pair.first == name;
});
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
[&](const std::pair<std::string, DataPtr>& pair) {
return pair.first == name;
});
if (foundOutputPair == std::end(_networkOutputs) && (foundInputPair == std::end(_networkInputs))) {
IE_THROW(NotFound) << "Failed to find input or output with name: \'" << name << "\'";
}
if (foundInputPair != std::end(_networkInputs)) {
foundInput = foundInputPair->second;
return true;
} else {
foundOutput = foundOutputPair->second;
return false;
}
}
/**
* @brief Check that @p blob is valid. Throws an exception if it's not.
*
* @param[in] blob The blob to check
* @param[in] name The name of input or output depending of if the @p blob is input or output
* @param[in] isInput Indicates if @p is input
* @param[in] refDims The reference dims, empty if not specified
*/
void checkBlob(const Blob::Ptr& blob, const std::string& name, bool isInput, const SizeVector& refDims = {}) const {
std::string bType = isInput ? "Input" : "Output";
std::string sType = isInput ? "input" : "output";
std::string strNotAllocated(bType + " data was not allocated.");
std::string strNotMatched("The " + sType + " blob size is not equal to the network " + sType + " size");
if (!blob) {
IE_THROW() << strNotAllocated;
}
size_t refSize;
if (refDims.empty()) {
SizeVector dims;
if (isInput) {
auto foundInputPair = std::find_if(std::begin(_networkInputs), std::end(_networkInputs),
[&](const std::pair<std::string, InputInfo::Ptr>& pair) {
return pair.first == name;
});
if (foundInputPair == std::end(_networkInputs)) {
IE_THROW(NotFound) << "Failed to find input with name: \'" << name << "\'";
}
dims = foundInputPair->second->getTensorDesc().getDims();
refSize = foundInputPair->second->getTensorDesc().getLayout() != SCALAR
? details::product(dims)
: 1;
} else {
auto foundOutputPair = std::find_if(std::begin(_networkOutputs), std::end(_networkOutputs),
[&](const std::pair<std::string, DataPtr>& pair) {
return pair.first == name;
});
if (foundOutputPair == std::end(_networkOutputs)) {
IE_THROW(NotFound) << "Failed to find output with name: \'" << name << "\'";
}
dims = foundOutputPair->second->getTensorDesc().getDims();
refSize = foundOutputPair->second->getTensorDesc().getLayout() != SCALAR
? details::product(dims)
: 1;
}
} else {
refSize = details::product(refDims);
}
if (refSize != blob->size()) {
IE_THROW() << strNotMatched + ": got " << blob->size() << " expecting " << refSize;
}
const bool remoteBlobPassed = blob->is<RemoteBlob>();
if (!remoteBlobPassed && blob->buffer() == nullptr) IE_THROW() << strNotAllocated;
}
/**
* @brief Checks whether pre-processing step is required for a given input
* @param info InputInfo corresponding to input blob
* @param userBlob Input Blob object corresponding to input info
* @param deviceBlob Blob object in plugin's desired format
* @return `True` if pre-processing is required, `false` otherwise
*/
bool preProcessingRequired(const InputInfo::Ptr& info, const Blob::Ptr& userBlob, const Blob::Ptr& deviceBlob = nullptr) {
// pre-processing is required if:
// 1. resize algorithm is specified (resize required)
// 2. color format specified:
// 2.a. color format is not equal to network's expected (color conversion required)
// 2.b. network's layout != blob's layout (reorder required)
// 3. precision conversion is required
const auto& preProcessInfo = info->getPreProcess();
const auto inputColorFormat = preProcessInfo.getColorFormat();
// FIXME: support other network's input formats once the API is ready. Assuming input is in
// the BGR format by default
const auto networkColorFormat = ColorFormat::BGR;
const bool colorFormatSpecified = inputColorFormat != ColorFormat::RAW;
auto blob_layout = [](const Blob::Ptr& b) { return b->getTensorDesc().getLayout(); };
auto blob_prec = [](const Blob::Ptr& b) { return b->getTensorDesc().getPrecision();};
auto dst_layout = deviceBlob ? blob_layout(deviceBlob) : info->getLayout();
auto dst_prec = deviceBlob ? blob_prec(deviceBlob) : info->getPrecision();
//FIXME: remove the first part to allow any needed conversion?
const bool need_layout_conv = (colorFormatSpecified || deviceBlob) &&
(blob_layout(userBlob) != dst_layout);
return preProcessInfo.getResizeAlgorithm() != ResizeAlgorithm::NO_RESIZE ||
(colorFormatSpecified && inputColorFormat != networkColorFormat) ||
need_layout_conv ||
(blob_prec(userBlob) != dst_prec);
}
void addInputPreProcessingFor(const std::string& name, Blob::Ptr const& from, const Blob::Ptr& to) {
auto ppDataIt = _preProcData.find(name);
if (ppDataIt == _preProcData.end()) {
ppDataIt = (_preProcData.emplace(name, CreatePreprocDataHelper())).first;
}
auto& preproc_ptr = ppDataIt->second;
preproc_ptr->isApplicable(from, to);
// Stores the given blob as ROI blob. It will be used to fill in network input
// during pre-processing
preproc_ptr->setRoiBlob(from);
}
};
} // namespace InferenceEngine

View File

@ -4,6 +4,7 @@
#pragma once
#include <ie_iexecutable_network.hpp>
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
#include <ie_iinfer_request.hpp>
#include <ie_parameter.hpp>
@ -11,9 +12,10 @@
#include <memory>
#include <string>
#include <vector>
#include <cpp/ie_cnn_network.h>
namespace InferenceEngine {
class IInferRequestInternal;
/**
* @interface IExecutableNetworkInternal
* @brief An internal API of executable network to be implemented by plugin,
@ -52,7 +54,7 @@ public:
* Note: the returned request will have allocated input and output blobs (that can be changed later)
* @return shared_ptr for the created request
*/
virtual IInferRequest::Ptr CreateInferRequest() = 0;
virtual std::shared_ptr<IInferRequestInternal> CreateInferRequest() = 0;
/**
* @deprecated Use IExecutableNetworkInternal::Export(std::ostream& networkModel)

View File

@ -1,66 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ie_iinfer_request.hpp>
#include <map>
#include <memory>
#include <string>
#include "ie_iinfer_request_internal.hpp"
namespace InferenceEngine {
/**
* @interface IAsyncInferRequestInternal
* @ingroup ie_dev_api_async_infer_request_api
* @brief An internal API of asynchronous inference request to be implemented by plugin,
* which is used in InferRequestBase forwarding mechanism
*/
class IAsyncInferRequestInternal : virtual public IInferRequestInternal {
public:
/**
* @brief A shared pointer to IAsyncInferRequestInternal interface
*/
typedef std::shared_ptr<IAsyncInferRequestInternal> Ptr;
/**
* @brief Start inference of specified input(s) in asynchronous mode
* @note The method returns immediately. Inference starts also immediately.
*/
virtual void StartAsync() = 0;
/**
* @brief Waits for the result to become available. Blocks until specified millis_timeout has elapsed or the result
* becomes available, whichever comes first.
* @param millis_timeout - maximum duration in milliseconds to block for
 * @note There are special cases when millis_timeout is equal to some value of WaitMode enum:
* * STATUS_ONLY - immediately returns request status (IInferRequest::RequestStatus). It doesn't block or interrupt
* current thread.
* * RESULT_READY - waits until inference result becomes available
* @return A status code
*/
virtual StatusCode Wait(int64_t millis_timeout) = 0;
/**
* @brief Get arbitrary data for the request
* @param data A pointer to a pointer to arbitrary data
*/
virtual void GetUserData(void** data) = 0;
/**
* @brief Set arbitrary data for the request
* @param data A pointer to a pointer to arbitrary data
*/
virtual void SetUserData(void* data) = 0;
/**
* @brief Set callback function which will be called on success or failure of asynchronous request
* @param callback - function to be called with the following description:
*/
virtual void SetCompletionCallback(IInferRequest::CompletionCallback callback) = 0;
};
} // namespace InferenceEngine

View File

@ -4,52 +4,67 @@
#pragma once
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
#include <ie_blob.h>
#include <ie_common.h>
#include <ie_preprocess.hpp>
#include <ie_preprocess_data.hpp>
#include <ie_input_info.hpp>
#include <ie_icnn_network.hpp>
#include <map>
#include <memory>
#include <string>
namespace InferenceEngine {
class IExecutableNetworkInternal;
class IVariableStateInternal;
/**
* @interface IInferRequestInternal
* @brief An internal API of synchronous inference request to be implemented by plugin,
* which is used in InferRequestBase forwarding mechanism
* @ingroup ie_dev_api_infer_request_api
*/
class IInferRequestInternal {
class INFERENCE_ENGINE_API_CLASS(IInferRequestInternal) : public std::enable_shared_from_this<IInferRequestInternal> {
public:
/**
* @brief A shared pointer to a IInferRequestInternal interface
*/
typedef std::shared_ptr<IInferRequestInternal> Ptr;
using Ptr = std::shared_ptr<IInferRequestInternal>;
IInferRequestInternal() = default;
/**
* @brief Destroys the object.
* @brief Constructs a new instance.
* @param[in] networkInputs The network inputs info
* @param[in] networkOutputs The network outputs data
*/
virtual ~IInferRequestInternal() = default;
IInferRequestInternal(const InputsDataMap& networkInputs, const OutputsDataMap& networkOutputs);
/**
* @brief Infers specified input(s) in synchronous mode
* @note blocks all method of IInferRequest while request is ongoing (running or waiting in queue)
*/
virtual void Infer() = 0;
virtual void Infer();
/**
* @brief The minimal infer function to be implemented by plugins. It infers specified input(s) in synchronous mode
* @note
* * This method is used in IInferRequestInternal::Infer, which calls the common code first and after uses this
* plugin dependent implementation.
* * Blocks all method of IInferRequest while request is ongoing (running or waiting in queue)
*/
virtual void InferImpl();
/**
* @brief Cancel current inference request execution
*/
virtual void Cancel() = 0;
virtual void Cancel();
/**
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
* Note: not all plugins may provide meaningful data
* @return Returns a map of layer names to profiling information for that layer.
* @return - a map of layer names to profiling information for that layer.
*/
virtual std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const = 0;
virtual std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const;
/**
* @brief Set input/output data to infer
@ -58,16 +73,16 @@ public:
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input
* precision and size.
*/
virtual void SetBlob(const std::string& name, const Blob::Ptr& data) = 0;
virtual void SetBlob(const std::string& name, const Blob::Ptr& data);
/**
* @brief Get input/output data to infer
* @note Memory allocation doesn't happen
* @param name - a name of input or output blob.
* @return Returns input or output blob. The type of Blob must correspond to the network input
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input
* precision and size.
*/
virtual Blob::Ptr GetBlob(const std::string& name) = 0;
virtual Blob::Ptr GetBlob(const std::string& name);
/**
* @brief Sets pre-process for input data
@ -75,26 +90,138 @@ public:
* @param data - a reference to input or output blob. The type of Blob must correspond to the network input precision and size.
* @param info Preprocess info for blob.
*/
virtual void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) = 0;
virtual void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info);
/**
* @brief Gets pre-process for input data
* @param name Name of input blob.
* @return Returns constant reference to PreProcessInfo structure
* @param info pointer to a pointer to PreProcessInfo structure
*/
virtual const PreProcessInfo& GetPreProcess(const std::string& name) const = 0;
virtual const PreProcessInfo& GetPreProcess(const std::string& name) const;
/**
* @brief Sets new batch size when dynamic batching is enabled in executable network that created this request.
* @param batch - new batch size to be used by all the following inference calls for this request.
*/
virtual void SetBatch(int batch) = 0;
virtual void SetBatch(int batch);
/**
* @brief Queries memory states.
* @return Returns memory states
*/
virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
virtual std::vector<std::shared_ptr<IVariableStateInternal>> QueryState();
/**
* @brief Start inference of specified input(s) in asynchronous mode
* @note The method returns immediately. Inference starts also immediately.
*/
virtual void StartAsync();
/**
* @brief The minimal asynchronous inference function to be implemented by plugins.
* It starts inference of specified input(s) in asynchronous mode
* @note
 * * The method is used in AsyncInferRequestInternal::StartAsync which performs common steps first and
* calls plugin dependent implementation of this method after.
* * It returns immediately. Inference starts also immediately.
*/
virtual void StartAsyncImpl();
/**
* @brief Waits for the result to become available. Blocks until specified millis_timeout has elapsed or the result
* becomes available, whichever comes first.
* @param millis_timeout - maximum duration in milliseconds to block for
 * @note There are special cases when millis_timeout is equal to some value of WaitMode enum:
* * STATUS_ONLY - immediately returns request status (IInferRequest::RequestStatus). It doesn't block or interrupt
* current thread.
* * RESULT_READY - waits until inference result becomes available
* @return A status code
*/
virtual StatusCode Wait(int64_t millis_timeout);
/**
* @brief Alias for callback type
*/
using Callback = std::function<void(std::exception_ptr)>;
/**
* @brief Set callback function which will be called on success or failure of asynchronous request
* @param callback - function to be called with the following description:
*/
virtual void SetCallback(Callback callback);
/**
* @brief Check that @p blob is valid. Throws an exception if it's not.
*
* @param[in] blob The blob to check
* @param[in] name The name of input or output depending of if the @p blob is input or output
* @param[in] isInput Indicates if @p is input
* @param[in] refDims The reference dims, empty if not specified
*/
void checkBlob(const Blob::Ptr& blob, const std::string& name, bool isInput, const SizeVector& refDims = {}) const;
/**
 * @brief Check that all of the blobs are valid. Throws an exception if they are not.
*/
virtual void checkBlobs();
/**
* @brief Sets the pointer to executable network internal.
* @note Needed to correctly handle ownership between objects.
* @param[in] exeNetwork The executable network
*/
void setPointerToExecutableNetworkInternal(const std::shared_ptr<IExecutableNetworkInternal>& exeNetwork);
protected:
/**
* @brief Checks and executes input data pre-processing if needed.
* @param preprocessedBlobs Input blobs to perform preprocessing on
* @param serial Controls whether the step is executed serially or using multiple threads
*/
void execDataPreprocessing(InferenceEngine::BlobMap& preprocessedBlobs, bool serial = false);
/**
* @brief Helper function to find input or output blob by name
* @param name A name of input or output blob.
* @param foundInput A pointer to input information if found.
* @param foundOutput A pointer to output DataPtr if found.
* @return `true` if the loaded network has an input with the provided name,
* `false` if the loaded network has an output with the provided name
* @throws [parameter_mismatch] exception if an input and an output have the same name
* @throws [not_found] exception if there are no input or output layers with the given name
*/
bool findInputAndOutputBlobByName(const std::string& name, InputInfo::Ptr& foundInput, DataPtr& foundOutput) const;
/**
* @brief Checks whether pre-processing step is required for a given input
* @param info InputInfo corresponding to input blob
* @param userBlob Input Blob object corresponding to input info
* @param deviceBlob Blob object in plugin's desired format
* @return `True` if pre-processing is required, `false` otherwise
*/
bool preProcessingRequired(const InputInfo::Ptr& info, const Blob::Ptr& userBlob, const Blob::Ptr& deviceBlob = nullptr);
void addInputPreProcessingFor(const std::string& name, Blob::Ptr const& from, const Blob::Ptr& to);
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data
InferenceEngine::BlobMap _inputs; //!< A map of user passed blobs for network inputs
InferenceEngine::BlobMap _deviceInputs; //!< A map of actual network inputs, in plugin specific format
InferenceEngine::BlobMap _outputs; //!< A map of user passed blobs for network outputs
std::map<std::string, PreProcessDataPtr> _preProcData; //!< A map of pre-process data per input
int m_curBatch = -1; //!< Current batch value used in dynamic batching
/**
* @brief A shared pointer to ExecutableNetworkInternal interface
* @note Needed to correctly handle ownership between objects.
*/
std::shared_ptr<IExecutableNetworkInternal> _exeNetwork;
Callback _callback; //!< A callback
/**
* @brief Destroys the object.
*/
~IInferRequestInternal();
};
} // namespace InferenceEngine
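For reference, below is a minimal caller-side sketch of the merged interface (the helper name runAsyncOnce and the exact header paths are assumptions, not part of this change). It illustrates the new Callback contract, which passes a std::exception_ptr instead of an IInferRequest::Ptr plus StatusCode, together with StartAsync and Wait:
#include <exception>
#include <iostream>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>  // assumed location of the merged interface
#include <ie_iinfer_request.hpp>                                    // for IInferRequest::WaitMode
// Sketch only: run one asynchronous inference through the merged internal interface.
void runAsyncOnce(const InferenceEngine::IInferRequestInternal::Ptr& request) {
    request->SetCallback([](std::exception_ptr eptr) {
        if (!eptr) {
            return;  // success: output blobs are ready to be read by the caller
        }
        try {
            std::rethrow_exception(eptr);  // unwrap the failure reported by the pipeline
        } catch (const std::exception& e) {
            std::cerr << "Async inference failed: " << e.what() << std::endl;
        }
    });
    request->StartAsync();  // returns immediately, inference starts right away
    // blocks until the result is ready; rethrows the stored exception if the pipeline failed
    request->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
}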

View File

@ -17,7 +17,7 @@ public:
const InferenceEngine::ITaskExecutor::Ptr &callbackExecutor,
const InferenceEngine::ITaskExecutor::Ptr &taskExecutorGetResult);
~MyriadAsyncInferRequest() override;
~MyriadAsyncInferRequest();
private:
MyriadInferRequest::Ptr _request;
InferenceEngine::ITaskExecutor::Ptr _taskExecutorGetResult;

View File

@ -61,7 +61,7 @@ public:
}
}
ie::InferRequestInternal::Ptr CreateInferRequestImpl(ie::InputsDataMap networkInputs,
ie::IInferRequestInternal::Ptr CreateInferRequestImpl(ie::InputsDataMap networkInputs,
ie::OutputsDataMap networkOutputs) override {
if (_device == nullptr || !_device->isBooted()) {
IE_THROW() << "Can not create infer request: there is no available devices with platform "
@ -73,7 +73,7 @@ public:
_graphMetaData.stagesMeta, _config, _log, _executor);
}
ie::IInferRequest::Ptr CreateInferRequest() override {
ie::IInferRequestInternal::Ptr CreateInferRequest() override {
ie::IInferRequest::Ptr asyncRequest;
if (_device == nullptr || !_device->isBooted()) {
IE_THROW() << "Can not create infer request: there is no available devices with platform "
@ -86,11 +86,8 @@ public:
_executor);
syncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
auto taskExecutorGetResult = getNextTaskExecutor();
auto asyncThreadSafeImpl = std::make_shared<MyriadAsyncInferRequest>(
return std::make_shared<MyriadAsyncInferRequest>(
syncRequestImpl, _taskExecutor, _callbackExecutor, taskExecutorGetResult);
asyncRequest.reset(new ie::InferRequestBase(asyncThreadSafeImpl));
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
}
void Export(std::ostream& model) override {

View File

@ -33,7 +33,7 @@ MyriadInferRequest::MyriadInferRequest(GraphDesc &graphDesc,
const MyriadConfig& myriadConfig,
const Logger::Ptr &log,
const MyriadExecutorPtr &executor) :
InferRequestInternal(networkInputs, networkOutputs), _executor(executor),
IInferRequestInternal(networkInputs, networkOutputs), _executor(executor),
_log(log), _stagesMetaData(blobMetaData), _config(myriadConfig),
_inputInfo(compilerInputsInfo), _outputInfo(compilerOutputsInfo),
_graphDesc(graphDesc) {

View File

@ -10,7 +10,6 @@
#include <memory>
#include <ie_common.h>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
#include <vpu/utils/logger.hpp>
@ -22,7 +21,7 @@
namespace vpu {
namespace MyriadPlugin {
class MyriadInferRequest : public InferenceEngine::InferRequestInternal {
class MyriadInferRequest : public InferenceEngine::IInferRequestInternal {
MyriadExecutorPtr _executor;
Logger::Ptr _log;
std::vector<StageMetaInfo> _stagesMetaData;

View File

@ -13,11 +13,6 @@ using namespace InferenceEngine;
using namespace InferenceEngine::details;
TEST(InferRequestCPPTests, throwsOnInitWithNull) {
IInferRequest::Ptr nlptr = nullptr;
ASSERT_THROW(InferRequest req(nlptr), InferenceEngine::Exception);
}
TEST(InferRequestCPPTests, throwsOnUninitializedSetBlob) {
InferRequest req;
ASSERT_THROW(req.SetBlob({}, {}), InferenceEngine::Exception);
@ -81,7 +76,7 @@ TEST(InferRequestCPPTests, throwsOnUninitializedSetCompletionCallback) {
TEST(InferRequestCPPTests, throwsOnUninitializedCast) {
InferRequest req;
ASSERT_THROW((void)static_cast<IInferRequest::Ptr &>(req), InferenceEngine::Exception);
ASSERT_THROW((void)static_cast<IInferRequest::Ptr>(req), InferenceEngine::Exception);
}
TEST(InferRequestCPPTests, throwsOnUninitializedQueryState) {

View File

@ -98,11 +98,12 @@ class MockExecutableNetwork : public ExecutableNetworkInternal {
public:
MockExecutableNetwork() {}
MOCK_METHOD1(ExportImpl, void(std::ostream& networkModel));
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr());
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr());
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap());
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap());
MOCK_CONST_METHOD1(GetConfig, Parameter(const std::string& name));
MOCK_CONST_METHOD1(GetMetric, Parameter(const std::string& name));
MOCK_METHOD2(CreateInferRequestImpl, IInferRequestInternal::Ptr(InputsDataMap, OutputsDataMap));
};
//------------------------------------------------------
@ -251,9 +252,8 @@ public:
EXPECT_CALL(*mock, GetOutputsInfo()).Times(AnyNumber()).WillRepeatedly(Return(ConstOutputsDataMap{}));
EXPECT_CALL(*mock, GetConfig(PluginConfigParams::KEY_PERF_COUNT)).Times(AnyNumber()).WillRepeatedly(Return(Parameter{PluginConfigParams::NO}));
EXPECT_CALL(*mock, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS))).Times(AnyNumber()).WillRepeatedly(Return(Parameter{1u}));
auto ptr = std::make_shared<MockIInferRequest>();
EXPECT_CALL(*ptr, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*ptr, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
auto ptr = std::make_shared<MockIInferRequestInternal>();
EXPECT_CALL(*ptr, SetCallback(_)).Times(AnyNumber());
EXPECT_CALL(*mock, CreateInferRequest()).Times(AnyNumber()).WillRepeatedly(Return(ptr));
return mock;
}
@ -345,10 +345,8 @@ private:
}));
EXPECT_CALL(*net, CreateInferRequest()).Times(AnyNumber())
.WillRepeatedly(Invoke([&]() {
std::vector<std::string> res;
auto inferReq = std::make_shared<MockIInferRequest>();
EXPECT_CALL(*inferReq, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*inferReq, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
auto inferReq = std::make_shared<MockIInferRequestInternal>();
EXPECT_CALL(*inferReq, SetCallback(_)).Times(AnyNumber());
return inferReq;
}));
}

View File

@ -122,20 +122,12 @@ TEST_P(CallbackTests, returnGeneralErrorIfCallbackThrowException) {
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest
InferenceEngine::IInferRequest::Ptr req = static_cast<InferenceEngine::IInferRequest::Ptr &>(execNet.CreateInferRequest());
req->SetCompletionCallback(
[](InferenceEngine::IInferRequest::Ptr, InferenceEngine::StatusCode status) {
IE_THROW() << "returnGeneralErrorIfCallbackThrowException";
});
auto req = execNet.CreateInferRequest();
req.SetCompletionCallback([] {
IE_THROW(GeneralError);
});
InferenceEngine::ResponseDesc resp;
req->StartAsync(&resp);
InferenceEngine::StatusCode waitStatus = InferenceEngine::StatusCode::INFER_NOT_STARTED;
while (InferenceEngine::StatusCode::RESULT_NOT_READY == waitStatus ||
InferenceEngine::StatusCode::INFER_NOT_STARTED == waitStatus) {
waitStatus = req->Wait(InferenceEngine::IInferRequest::WaitMode::STATUS_ONLY, &resp);
}
ASSERT_EQ(InferenceEngine::StatusCode::GENERAL_ERROR, waitStatus);
ASSERT_NE(std::string(resp.msg).find("returnGeneralErrorIfCallbackThrowException"), std::string::npos);
ASSERT_NO_THROW(req.StartAsync());
ASSERT_THROW(req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY), InferenceEngine::GeneralError);
}
} // namespace BehaviorTestsDefinitions
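In a distilled, self-contained form, the pattern exercised by the updated test reads roughly as follows (the helper name startWithCallback is hypothetical; the no-argument callback overload and the wait mode are the ones used above):
#include <cpp/ie_infer_request.hpp>
#include <ie_iinfer_request.hpp>  // for IInferRequest::WaitMode
// Sketch: set a simple completion callback on the public InferRequest wrapper and wait.
// If the callback throws, the subsequent Wait() call rethrows that error, which is what
// the test above asserts with ASSERT_THROW.
void startWithCallback(InferenceEngine::InferRequest& request) {
    request.SetCompletionCallback([] {
        // invoked once the asynchronous request completes
    });
    request.StartAsync();
    request.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
}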

View File

@ -131,11 +131,9 @@ TEST_P(InferRequestOutputTests, canStartAsyncInferWithGetInOut) {
InferenceEngine::InferRequest req;
ASSERT_NO_THROW(req = execNet.CreateInferRequest());
InferenceEngine::Blob::Ptr inputBlob = req.GetBlob(cnnNet.getInputsInfo().begin()->first);
InferenceEngine::StatusCode sts;
ASSERT_NO_THROW(req.Infer());
ASSERT_NO_THROW(req.StartAsync());
sts = req.Wait(500);
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts);
ASSERT_NO_THROW(req.Wait());
InferenceEngine::Blob::Ptr outputBlob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first);
}
} // namespace BehaviorTestsDefinitions

View File

@ -12,14 +12,10 @@
#include "unit_test_utils/mocks/cpp_interfaces/mock_task_executor.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_network_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_async_only.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_default.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_inference_plugin_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_ivariable_state_internal.hpp"

View File

@ -12,13 +12,11 @@
#include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp>
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
using namespace InferenceEngine;
class MockAsyncInferRequestDefault : public AsyncInferRequestThreadSafeDefault {
public:
MockAsyncInferRequestDefault(InferRequestInternal::Ptr request,
MockAsyncInferRequestDefault(IInferRequestInternal::Ptr request,
const ITaskExecutor::Ptr &taskExecutor,
const ITaskExecutor::Ptr &callbackExecutor)
: AsyncInferRequestThreadSafeDefault(request, taskExecutor, callbackExecutor) {}

View File

@ -1,37 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <vector>
#include <map>
#include <gmock/gmock.h>
#include <ie_iinfer_request.hpp>
#include <cpp_interfaces/impl/ie_infer_async_request_internal.hpp>
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
using namespace InferenceEngine;
class MockAsyncInferRequestInternal : public AsyncInferRequestInternal {
public:
using AsyncInferRequestInternal::SetBlob;
MockAsyncInferRequestInternal(InputsDataMap networkInputs, OutputsDataMap networkOutputs)
: AsyncInferRequestInternal(networkInputs, networkOutputs) {}
MOCK_METHOD0(StartAsyncImpl, void());
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
MOCK_METHOD1(GetUserData, void(void **));
MOCK_METHOD1(SetUserData, void(void *));
MOCK_METHOD0(InferImpl, void());
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
MOCK_METHOD1(setNetworkInputs, void(InputsDataMap));
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&));
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
MOCK_METHOD0(Cancel, void());
MOCK_METHOD0(Cancel_ThreadUnsafe, void());
};

View File

@ -13,12 +13,10 @@
#include "ie_iexecutable_network.hpp"
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <gmock/gmock.h>
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
using namespace InferenceEngine;
@ -26,7 +24,7 @@ class MockExecutableNetworkInternal : public ExecutableNetworkInternal {
public:
MOCK_METHOD1(setNetworkInputs, void(InputsDataMap));
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void));
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
MOCK_METHOD1(Export, void(const std::string &));
MOCK_METHOD0(GetExecGraphInfo, CNNNetwork(void));
void WrapOstreamExport(std::ostream& networkModel) {
@ -36,4 +34,5 @@ public:
void ExportImpl(std::ostream& networkModel) override {
networkModel << exportString << std::endl;
}
MOCK_METHOD2(CreateInferRequestImpl, IInferRequestInternal::Ptr(InputsDataMap, OutputsDataMap));
};

View File

@ -1,23 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include "ie_iexecutable_network.hpp"
#include <gmock/gmock.h>
#include <string>
#include <vector>
#include <cpp_interfaces/impl/ie_executable_network_thread_safe_async_only.hpp>
using namespace InferenceEngine;
class MockExecutableNetworkThreadSafeAsyncOnly : public ExecutableNetworkThreadSafeAsyncOnly {
public:
MOCK_METHOD2(CreateAsyncInferRequestImpl,
AsyncInferRequestInternal::Ptr(InputsDataMap networkInputs, OutputsDataMap networkOutputs));
MOCK_METHOD1(Export, void(const std::string &));
void Export(std::ostream&) override {}
};

View File

@ -15,7 +15,7 @@ using namespace InferenceEngine;
class MockExecutableNetworkThreadSafe : public ExecutableNetworkThreadSafeDefault {
public:
MOCK_METHOD2(CreateInferRequestImpl,
std::shared_ptr<InferRequestInternal>(InputsDataMap networkInputs, OutputsDataMap networkOutputs));
std::shared_ptr<IInferRequestInternal>(InputsDataMap networkInputs, OutputsDataMap networkOutputs));
MOCK_METHOD1(Export, void(const std::string &));
void Export(std::ostream &) override {}
};

View File

@ -1,27 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include "ie_iexecutable_network.hpp"
#include <gmock/gmock.h>
#include <string>
#include <vector>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
using namespace InferenceEngine;
class MockInferRequestInternal : public InferRequestInternal {
public:
MockInferRequestInternal(InputsDataMap networkInputs, OutputsDataMap networkOutputs)
: InferRequestInternal(networkInputs, networkOutputs) {}
using InferRequestInternal::SetBlob;
using InferRequestInternal::GetBlob;
MOCK_METHOD0(InferImpl, void());
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
MOCK_METHOD0(checkBlobs, void());
MOCK_METHOD0(Cancel, void());
};

View File

@ -1,32 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <gmock/gmock.h>
#include <map>
#include <string>
#include <vector>
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
class MockIAsyncInferRequestInternal : public InferenceEngine::IAsyncInferRequestInternal {
public:
MOCK_METHOD0(StartAsync, void());
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
MOCK_METHOD1(GetUserData, void(void **));
MOCK_METHOD1(SetUserData, void(void *));
MOCK_METHOD0(Infer, void());
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngine::InferenceEngineProfileInfo>());
MOCK_METHOD2(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &));
MOCK_METHOD1(GetBlob, InferenceEngine::Blob::Ptr(const std::string&));
MOCK_METHOD3(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &, const InferenceEngine::PreProcessInfo&));
MOCK_CONST_METHOD1(GetPreProcess, const InferenceEngine::PreProcessInfo&(const std::string&));
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
MOCK_METHOD0(Cancel, void());
};

View File

@ -14,10 +14,8 @@
#include "ie_icnn_network.hpp"
#include "ie_iexecutable_network.hpp"
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
using namespace InferenceEngine;
@ -25,7 +23,7 @@ class MockIExecutableNetworkInternal : public IExecutableNetworkInternal {
public:
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap(void));
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap(void));
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void));
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
MOCK_METHOD1(Export, void(const std::string &));
void Export(std::ostream &) override {};
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));

View File

@ -10,16 +10,24 @@
#include <string>
#include <vector>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/impl/ie_variable_state_internal.hpp>
class MockIInferRequestInternal : public InferenceEngine::IInferRequestInternal {
public:
using InferenceEngine::IInferRequestInternal::IInferRequestInternal;
MOCK_METHOD0(StartAsync, void());
MOCK_METHOD1(Wait, InferenceEngine::StatusCode(int64_t));
MOCK_METHOD0(Infer, void());
MOCK_CONST_METHOD1(GetPerformanceCounts, void(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &));
MOCK_METHOD2(SetBlob, void(const char *name, const InferenceEngine::Blob::Ptr &));
MOCK_METHOD2(GetBlob, void(const char *name, InferenceEngine::Blob::Ptr &));
MOCK_METHOD3(SetBlob, void(const char*, const InferenceEngine::Blob::Ptr&, const InferenceEngine::PreProcessInfo&));
MOCK_METHOD2(GetPreProcess, void(const char*, const InferenceEngine::PreProcessInfo**));
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngine::InferenceEngineProfileInfo>());
MOCK_METHOD2(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &));
MOCK_METHOD1(GetBlob, InferenceEngine::Blob::Ptr(const std::string&));
MOCK_METHOD3(SetBlob, void(const std::string&, const InferenceEngine::Blob::Ptr &, const InferenceEngine::PreProcessInfo&));
MOCK_CONST_METHOD1(GetPreProcess, const InferenceEngine::PreProcessInfo&(const std::string&));
MOCK_METHOD1(SetCallback, void(std::function<void(std::exception_ptr)>));
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD0(QueryState, std::vector<InferenceEngine::IVariableStateInternal::Ptr>());
MOCK_METHOD0(Cancel, void());
MOCK_METHOD0(StartAsyncImpl, void());
MOCK_METHOD0(InferImpl, void());
MOCK_METHOD0(checkBlobs, void());
};

View File

@ -14,7 +14,7 @@ class MockIInferencePlugin : public InferenceEngine::IInferencePlugin {
public:
MOCK_METHOD1(AddExtension, void(InferenceEngine::IExtensionPtr));
MOCK_METHOD2(LoadNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
const CNNNetwork&, const std::map<std::string, std::string>&));
const InferenceEngine::CNNNetwork&, const std::map<std::string, std::string>&));
MOCK_METHOD2(ImportNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
const std::string&, const std::map<std::string, std::string>&));
MOCK_METHOD1(SetConfig, void(const std::map<std::string, std::string> &));
@ -38,4 +38,7 @@ public:
MOCK_METHOD3(ImportNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
std::istream&, const InferenceEngine::RemoteContext::Ptr&,
const std::map<std::string, std::string>&));
MOCK_QUALIFIED_METHOD2(QueryNetwork, const,
InferenceEngine::QueryNetworkResult(const InferenceEngine::CNNNetwork&,
const std::map<std::string, std::string>&));
};

View File

@ -7,9 +7,8 @@
#include <cpp_interfaces/base/ie_executable_network_base.hpp>
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_async_only.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_executable_thread_safe_default.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
using namespace ::testing;
using namespace std;
@ -18,85 +17,11 @@ using namespace InferenceEngine::details;
IE_SUPPRESS_DEPRECATED_START
class ExecutableNetworkThreadSafeAsyncOnlyTests : public ::testing::Test {
protected:
shared_ptr<MockExecutableNetworkThreadSafeAsyncOnly> mockExeNetwork;
shared_ptr<MockAsyncInferRequestInternal> mockAsyncInferRequestInternal;
shared_ptr<IExecutableNetwork> exeNetwork;
ResponseDesc dsc;
StatusCode sts;
virtual void TearDown() {
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockAsyncInferRequestInternal.get()));
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetwork.get()));
}
virtual void SetUp() {
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
exeNetwork = std::make_shared<ExecutableNetworkBase>(mockExeNetwork);
InputsDataMap networkInputs;
OutputsDataMap networkOutputs;
mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
}
};
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, createAsyncInferRequestCallsThreadSafeImplAndSetNetworkIO) {
IInferRequest::Ptr req;
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
Return(mockAsyncInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
ASSERT_NE(threadSafeReq, nullptr);
}
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, returnErrorIfInferThrowsException) {
IInferRequest::Ptr req;
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
Return(mockAsyncInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::runtime_error("")));
EXPECT_NO_THROW(sts = req->Infer(&dsc));
ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
}
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, returnErrorIfStartAsyncThrowsException) {
IInferRequest::Ptr req;
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
Return(mockAsyncInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).WillOnce(Throw(std::runtime_error("")));
EXPECT_NO_THROW(sts = req->StartAsync(&dsc));
ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
}
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, canForwardStartAsyncAndInfer) {
IInferRequest::Ptr req;
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
Return(mockAsyncInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).Times(1);
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).Times(1);
EXPECT_NO_THROW(req->StartAsync(&dsc)) << dsc.msg;
EXPECT_NO_THROW(req->Infer(&dsc)) << dsc.msg;
}
TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, canForwardInferAndStartAsync) {
IInferRequest::Ptr req;
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
Return(mockAsyncInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).Times(1);
EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).Times(1);
EXPECT_NO_THROW(req->Infer(&dsc)) << dsc.msg;
EXPECT_NO_THROW(req->StartAsync(&dsc)) << dsc.msg;
}
class ExecutableNetworkThreadSafeTests : public ::testing::Test {
protected:
shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetwork;
shared_ptr<IExecutableNetwork> exeNetwork;
shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
ResponseDesc dsc;
StatusCode sts;
@ -110,7 +35,7 @@ protected:
exeNetwork = std::make_shared<ExecutableNetworkBase>(mockExeNetwork);
InputsDataMap networkInputs;
OutputsDataMap networkOutputs;
mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(networkInputs, networkOutputs);
}
};

View File

@ -7,13 +7,16 @@
#include <gmock/gmock-generated-actions.h>
#include <cpp/ie_infer_request.hpp>
#include <cpp/ie_executable_network.hpp>
#include <ie_plugin_cpp.hpp>
#include <cpp_interfaces/exception2status.hpp>
#include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinference_plugin.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
#include "unit_test_utils/mocks/mock_iinfer_request.hpp"
#include "unit_test_utils/mocks/mock_not_empty_icnn_network.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
using namespace ::testing;
using namespace std;
@ -25,7 +28,7 @@ constexpr const char* MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME;
class InferRequestBaseTests : public ::testing::Test {
protected:
std::shared_ptr<MockIAsyncInferRequestInternal> mock_impl;
std::shared_ptr<MockIInferRequestInternal> mock_impl;
shared_ptr<IInferRequest> request;
ResponseDesc dsc;
@ -33,7 +36,7 @@ protected:
}
virtual void SetUp() {
mock_impl.reset(new MockIAsyncInferRequestInternal());
mock_impl.reset(new MockIInferRequestInternal());
request = std::make_shared<InferRequestBase>(mock_impl);
}
};
@ -55,42 +58,6 @@ TEST_F(InferRequestBaseTests, canCatchUnknownErrorInStartAsync) {
ASSERT_EQ(UNEXPECTED, request->StartAsync(nullptr));
}
// GetUserData
TEST_F(InferRequestBaseTests, canForwardGetUserData) {
void **data = nullptr;
EXPECT_CALL(*mock_impl.get(), GetUserData(data)).Times(1);
ASSERT_EQ(OK, request->GetUserData(data, &dsc));
}
TEST_F(InferRequestBaseTests, canReportErrorInGetUserData) {
EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
ASSERT_NE(request->GetUserData(nullptr, &dsc), OK);
ASSERT_STREQ(dsc.msg, "compare");
}
TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetUserData) {
EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(5));
ASSERT_EQ(UNEXPECTED, request->GetUserData(nullptr, nullptr));
}
// SetUserData
TEST_F(InferRequestBaseTests, canForwardSetUserData) {
void *data = nullptr;
EXPECT_CALL(*mock_impl.get(), SetUserData(data)).Times(1);
ASSERT_EQ(OK, request->SetUserData(data, &dsc));
}
TEST_F(InferRequestBaseTests, canReportErrorInSetUserData) {
EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
ASSERT_NE(request->SetUserData(nullptr, &dsc), OK);
ASSERT_STREQ(dsc.msg, "compare");
}
TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetUserData) {
EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(5));
ASSERT_EQ(UNEXPECTED, request->SetUserData(nullptr, nullptr));
}
// Wait
TEST_F(InferRequestBaseTests, canForwardWait) {
int64_t ms = 0;
@ -192,23 +159,27 @@ TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetBlob) {
// SetCompletionCallback
TEST_F(InferRequestBaseTests, canForwardSetCompletionCallback) {
InferenceEngine::IInferRequest::CompletionCallback callback = nullptr;
EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(callback)).Times(1);
EXPECT_CALL(*mock_impl.get(), SetCallback(_)).Times(1);
ASSERT_NO_THROW(request->SetCompletionCallback(callback));
}
TEST_F(InferRequestBaseTests, canReportErrorInSetCompletionCallback) {
EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(_)).WillOnce(Throw(std::runtime_error("compare")));
ASSERT_NO_THROW(request->SetCompletionCallback(nullptr));
EXPECT_CALL(*mock_impl.get(), SetCallback(_)).WillOnce(Throw(std::runtime_error("compare")));
ASSERT_NE(request->SetCompletionCallback(nullptr), OK);
}
class InferRequestTests : public ::testing::Test {
protected:
std::shared_ptr<MockIInferRequest> mock_request;
InferRequest::Ptr requestWrapper;
std::shared_ptr<MockIInferRequestInternal> mock_request;
InferRequest request;
ResponseDesc dsc;
shared_ptr<MockAsyncInferRequestInternal> mockInferRequestInternal;
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
InferenceEngine::ExecutableNetwork exeNetwork;
MockIInferencePlugin* mockIPlugin;
InferencePlugin plugin;
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
MockNotEmptyICNNNetwork mockNotEmptyNet;
std::string _incorrectName;
std::string _inputName;
@ -216,15 +187,21 @@ protected:
std::string _inputDataNotAllocatedError;
std::string _inputDataIsEmptyError;
virtual void TearDown() {
void TearDown() override {
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
EXPECT_TRUE(Mock::VerifyAndClearExpectations(mock_request.get()));
EXPECT_TRUE(Mock::VerifyAndClearExpectations(requestWrapper.get()));
request = {};
}
virtual void SetUp() {
mock_request = make_shared<MockIInferRequest>();
requestWrapper = make_shared<InferRequest>(mock_request);
void SetUp() override {
mock_request = make_shared<MockIInferRequestInternal>();
mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
ON_CALL(*mockIExeNet, CreateInferRequest()).WillByDefault(Return(mock_request));
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
exeNetwork = plugin.LoadNetwork({}, {});
request = exeNetwork.CreateInferRequest();
_incorrectName = "incorrect_name";
_inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
_failedToFindInOutError =
@ -235,15 +212,20 @@ protected:
+ _inputName + "\'";
}
InferRequest::Ptr getInferRequestWithMockImplInside() {
InferRequest getInferRequestWithMockImplInside() {
IInferRequest::Ptr inferRequest;
InputsDataMap inputsInfo;
mockNotEmptyNet.getInputsInfo(inputsInfo);
OutputsDataMap outputsInfo;
mockNotEmptyNet.getOutputsInfo(outputsInfo);
mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
inferRequest = std::make_shared<InferRequestBase>(mockInferRequestInternal);
return make_shared<InferRequest>(inferRequest);
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(inputsInfo, outputsInfo);
auto mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
ON_CALL(*mockIExeNet, CreateInferRequest()).WillByDefault(Return(mockInferRequestInternal));
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
auto plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
auto exeNetwork = plugin.LoadNetwork({}, {});
return exeNetwork.CreateInferRequest();
}
std::string getExceptionMessage(std::function<void()> function) {
@ -274,60 +256,52 @@ protected:
}
};
// constructor tests
TEST_F(InferRequestTests, constructorsTests) {
// construction from the non-null should not throw
ASSERT_NO_THROW(InferRequest req(mock_request));
IInferRequest::Ptr tmp;
// InferRequest's "actual" is nullptr, let's check it throws on construction
ASSERT_THROW(InferRequest req(tmp), Exception);
}
// StartAsync
TEST_F(InferRequestTests, canForwardStartAsync) {
EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(OK));
ASSERT_NO_THROW(requestWrapper->StartAsync());
EXPECT_CALL(*mock_request.get(), StartAsync());
ASSERT_NO_THROW(request.StartAsync());
}
TEST_F(InferRequestTests, throwsIfStartAsyncReturnNotOK) {
EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(GENERAL_ERROR));
ASSERT_THROW(requestWrapper->StartAsync(), Exception);
EXPECT_CALL(*mock_request.get(), StartAsync()).WillOnce(Throw(GeneralError{""}));
ASSERT_THROW(request.StartAsync(), Exception);
}
// Wait
TEST_F(InferRequestTests, canForwardWait) {
int64_t ms = 0;
EXPECT_CALL(*mock_request.get(), Wait(ms, _)).WillOnce(Return(OK));
ASSERT_TRUE(OK == requestWrapper->Wait(ms));
EXPECT_CALL(*mock_request.get(), Wait(_)).WillOnce(Return(OK));
ASSERT_TRUE(OK == request.Wait(ms));
}
TEST_F(InferRequestTests, canForwardStatusFromWait) {
EXPECT_CALL(*mock_request.get(), Wait(_, _)).WillOnce(Return(RESULT_NOT_READY));
ASSERT_EQ(requestWrapper->Wait(0), RESULT_NOT_READY);
EXPECT_CALL(*mock_request.get(), Wait(_)).WillOnce(Return(RESULT_NOT_READY));
ASSERT_EQ(request.Wait(0), RESULT_NOT_READY);
}
// Infer
TEST_F(InferRequestTests, canForwardInfer) {
EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(OK));
ASSERT_NO_THROW(requestWrapper->Infer());
EXPECT_CALL(*mock_request.get(), Infer());
ASSERT_NO_THROW(request.Infer());
}
TEST_F(InferRequestTests, throwsIfInferReturnNotOK) {
EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(GENERAL_ERROR));
ASSERT_THROW(requestWrapper->Infer(), Exception);
EXPECT_CALL(*mock_request.get(), Infer()).WillOnce(Throw(GeneralError{""}));
ASSERT_THROW(request.Infer(), Exception);
}
// GetPerformanceCounts
TEST_F(InferRequestTests, canForwardGetPerformanceCounts) {
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(OK));
ASSERT_NO_THROW(info = requestWrapper->GetPerformanceCounts());
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts()).WillOnce(Return(info));
ASSERT_NO_THROW(info = request.GetPerformanceCounts());
}
TEST_F(InferRequestTests, throwsIfGetPerformanceCountsReturnNotOK) {
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(GENERAL_ERROR));
ASSERT_THROW(info = requestWrapper->GetPerformanceCounts(), Exception);
EXPECT_CALL(*mock_request.get(), GetPerformanceCounts()).WillOnce(Throw(GeneralError{""}));
ASSERT_THROW(info = request.GetPerformanceCounts(), Exception);
}
MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
@ -343,15 +317,15 @@ TEST_F(InferRequestTests, getInputCallsSetBlob) {
BlobMap blobMap{{blobName1, inblob},
{blobName2, inblob}};
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
ASSERT_NO_THROW(requestWrapper->SetInput(blobMap));
EXPECT_CALL(*mock_request.get(), SetBlob(blobName1, inblob));
EXPECT_CALL(*mock_request.get(), SetBlob(blobName2, inblob));
ASSERT_NO_THROW(request.SetInput(blobMap));
}
TEST_F(InferRequestTests, throwsIfSetInputReturnNotOK) {
EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
BlobMap blobMap{{{}, {}}};
ASSERT_THROW(requestWrapper->SetInput(blobMap), Exception);
ASSERT_THROW(request.SetInput(blobMap), Exception);
}
// SetOutput
@ -362,9 +336,9 @@ TEST_F(InferRequestTests, getOutputCallsSetBlob) {
BlobMap blobMap{{blobName1, inblob},
{blobName2, inblob}};
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
ASSERT_NO_THROW(requestWrapper->SetOutput(blobMap));
EXPECT_CALL(*mock_request.get(), SetBlob(blobName1, inblob));
EXPECT_CALL(*mock_request.get(), SetBlob(blobName2, inblob));
ASSERT_NO_THROW(request.SetOutput(blobMap));
}
// GetBlob
@ -373,16 +347,16 @@ TEST_F(InferRequestTests, canForwardGetBlob) {
blob->allocate();
std::string name = "blob1";
EXPECT_CALL(*mock_request.get(), GetBlob(StrEq(name.c_str()), _, _)).WillOnce(DoAll(SetArgReferee<1>(blob), Return(OK)));
ASSERT_NO_THROW(requestWrapper->GetBlob(name));
EXPECT_CALL(*mock_request.get(), GetBlob(_)).WillOnce(Return(blob));
ASSERT_NO_THROW(request.GetBlob(name));
}
TEST_F(InferRequestTests, throwsIfGetBlobReturnNotOK) {
Blob::Ptr blob;
std::string name = "blob1";
EXPECT_CALL(*mock_request.get(), GetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
ASSERT_THROW(blob = requestWrapper->GetBlob(name), Exception);
EXPECT_CALL(*mock_request.get(), GetBlob(_)).WillOnce(Throw(GeneralError{""}));
ASSERT_THROW(blob = request.GetBlob(name), Exception);
}
// SetBlob
@ -390,79 +364,49 @@ TEST_F(InferRequestTests, canForwardSetBlob) {
Blob::Ptr blob;
std::string name = "blob1";
EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(name.c_str()), blob, _)).WillOnce(Return(OK));
ASSERT_NO_THROW(requestWrapper->SetBlob(name, blob));
EXPECT_CALL(*mock_request.get(), SetBlob(name, blob));
ASSERT_NO_THROW(request.SetBlob(name, blob));
}
TEST_F(InferRequestTests, throwsIfSetBlobReturnNotOK) {
Blob::Ptr blob;
std::string name = "blob1";
EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
ASSERT_THROW(requestWrapper->SetBlob(name, blob), Exception);
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
ASSERT_THROW(request.SetBlob(name, blob), Exception);
}
TEST_F(InferRequestTests, throwsIfSetOutputReturnNotOK) {
EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
BlobMap blobMap{{{}, {}}};
ASSERT_THROW(requestWrapper->SetOutput(blobMap), Exception);
}
// SetCompletionCallback API
void callme(InferenceEngine::IInferRequest::Ptr p, InferenceEngine::StatusCode) {
void *data = nullptr;
p->GetUserData(&data, nullptr);
ASSERT_NE(nullptr, data);
}
TEST_F(InferRequestTests, canForwardCompletionCallback) {
void *data = nullptr;
EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
*pData = data;
}), Return(OK)));
EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
ASSERT_NO_THROW(requestWrapper->SetCompletionCallback(&callme));
ASSERT_THROW(request.SetOutput(blobMap), Exception);
}
TEST_F(InferRequestTests, canForwardAnyCallback) {
void *data = nullptr;
EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
*pData = data;
}), Return(OK)));
EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
ASSERT_NO_THROW(requestWrapper->SetCompletionCallback([&]() {
// data used to store callback pointer
ASSERT_NE(data, nullptr);
}));
EXPECT_CALL(*mock_request.get(), SetCallback(_));
ASSERT_NO_THROW(request.SetCompletionCallback([] {}));
}
TEST_F(InferRequestTests, failToSetInputWithInCorrectName) {
auto InferRequest = getInferRequestWithMockImplInside();
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(NotFound{""}));
auto blobMap = getBlobMapWithIncorrectName();
ASSERT_THROW(InferRequest->SetInput(blobMap), NotFound);
ASSERT_THROW(request.SetInput(blobMap), NotFound);
}
TEST_F(InferRequestTests, failToSetOutputWithInCorrectName) {
auto InferRequest = getInferRequestWithMockImplInside();
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(NotFound{""}));
auto blobMap = getBlobMapWithIncorrectName();
ASSERT_THROW(InferRequest->SetOutput(blobMap), NotFound);
ASSERT_THROW(request.SetOutput(blobMap), NotFound);
}
TEST_F(InferRequestTests, failToSetInputWithNotAllocatedInput) {
auto InferRequest = getInferRequestWithMockImplInside();
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(NotAllocated{""}));
auto blobMap = getBlobMapWithNotAllocatedInput();
ASSERT_THROW(InferRequest->SetInput(blobMap), NotAllocated);
ASSERT_THROW(request.SetInput(blobMap), NotAllocated);
}
TEST_F(InferRequestTests, failToSetInputWithEmptyDimensions) {
auto InferRequest = getInferRequestWithMockImplInside();
EXPECT_CALL(*mock_request.get(), SetBlob(_, _)).WillOnce(Throw(GeneralError{""}));
auto blobMap = getBlobMapWithEmptyDimensions();
ASSERT_THROW(InferRequest->SetInput(blobMap), GeneralError);
ASSERT_THROW(request.SetInput(blobMap), GeneralError);
}

View File

@ -13,7 +13,7 @@
#include <threading/ie_cpu_streams_executor.hpp>
#include "unit_test_utils/mocks/cpp_interfaces/mock_task_executor.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_default.hpp"
using namespace ::testing;
@ -52,7 +52,7 @@ protected:
shared_ptr<AsyncInferRequestThreadSafeDefault> testRequest;
ResponseDesc dsc;
shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
MockTaskExecutor::Ptr mockTaskExecutor;
@ -63,7 +63,7 @@ protected:
InputsDataMap inputsInfo;
OutputsDataMap outputsInfo;
mockTaskExecutor = make_shared<MockTaskExecutor>();
mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(inputsInfo, outputsInfo);
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, mockTaskExecutor, mockTaskExecutor);
}
};
@ -90,26 +90,6 @@ TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfStartAsyncFails)
taskExecutor->executeAll();
}
// GetUserData
TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetUserData) {
auto taskExecutor = std::make_shared<DeferedExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(1).WillOnce(Return());
ASSERT_NO_THROW(testRequest->StartAsync());
ASSERT_THROW(testRequest->GetUserData(nullptr), RequestBusy);
taskExecutor->executeAll();
}
// SetUserData
TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetUserData) {
auto taskExecutor = std::make_shared<DeferedExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(1).WillOnce(Return());
ASSERT_NO_THROW(testRequest->StartAsync());
ASSERT_THROW(testRequest->SetUserData(nullptr), RequestBusy);
taskExecutor->executeAll();
}
// Wait
TEST_F(InferRequestThreadSafeDefaultTests, returnInferNotStartedOnWait) {
int64_t ms = 0;
@ -175,58 +155,42 @@ TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetCompletionCallb
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(1).WillOnce(Return());
ASSERT_NO_THROW(testRequest->StartAsync());
ASSERT_THROW(testRequest->SetCompletionCallback({}), RequestBusy);
ASSERT_THROW(testRequest->SetCallback({}), RequestBusy);
taskExecutor->executeAll();
}
TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
auto taskExecutor = std::make_shared<DeferedExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase(testRequest));
testRequest->SetPointerToPublicInterface(asyncRequest);
testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
ASSERT_EQ((int) StatusCode::OK, status);
std::exception_ptr exceptionPtr;
testRequest->SetCallback([&](std::exception_ptr exceptionPtr_) {
exceptionPtr = exceptionPtr_;
});
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).Times(1);
testRequest->StartAsync();
taskExecutor->executeAll();
testRequest->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
ASSERT_EQ(nullptr, exceptionPtr);
}
TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed) {
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
auto taskExecutor = std::make_shared<DeferedExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase(testRequest));
testRequest->SetPointerToPublicInterface(asyncRequest);
bool wasCalled = false;
InferRequest cppRequest(asyncRequest);
std::function<void(InferRequest, StatusCode)> callback =
[&](InferRequest request, StatusCode status) {
wasCalled = true;
ASSERT_EQ(StatusCode::GENERAL_ERROR, status);
};
cppRequest.SetCompletionCallback(callback);
std::exception_ptr exceptionPtr;
testRequest->SetCallback([&](std::exception_ptr exceptionPtr_) {
exceptionPtr = exceptionPtr_;
});
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
testRequest->StartAsync();
taskExecutor->executeAll();
EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
ASSERT_TRUE(wasCalled);
ASSERT_NE(nullptr, exceptionPtr);
}
TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailedAndNoCallback) {
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase(testRequest));
testRequest->SetPointerToPublicInterface(asyncRequest);
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
testRequest->StartAsync();
EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
}

View File

@ -11,7 +11,6 @@
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_ivariable_state_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinference_plugin.hpp"
#include "ie_plugin_cpp.hpp"
@ -20,43 +19,32 @@ using namespace std;
using namespace InferenceEngine;
using namespace InferenceEngine::details;
template <class T>
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
typename InferRequestBase::Ptr req(new InferRequestBase(impl));
return InferenceEngine::InferRequest(req);
}
class VariableStateTests : public ::testing::Test {
protected:
shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
shared_ptr<MockIAsyncInferRequestInternal> mockInferRequestInternal;
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
shared_ptr<MockIVariableStateInternal> mockVariableStateInternal;
struct TestPluginInternal : public MockIInferencePlugin {
TestPluginInternal(const std::shared_ptr<MockIExecutableNetworkInternal>& mockIExeNet_) : mockIExeNet{mockIExeNet_} {}
std::shared_ptr<IExecutableNetworkInternal> LoadNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) override {
return mockIExeNet;
}
QueryNetworkResult QueryNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) const override {return {};}
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
};
struct TestPlugin : public InferenceEngine::InferencePlugin {
TestPlugin(std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet) :
InferenceEngine::InferencePlugin(InferenceEngine::details::SOPointer<TestPluginInternal>{
new TestPluginInternal{mockIExeNet}}) {}
};
MockIInferencePlugin* mockIPlugin;
InferencePlugin plugin;
ExecutableNetwork net;
InferRequest req;
virtual void SetUp() {
mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
mockInferRequestInternal = make_shared<MockIAsyncInferRequestInternal>();
mockInferRequestInternal = make_shared<MockIInferRequestInternal>();
mockVariableStateInternal = make_shared<MockIVariableStateInternal>();
ON_CALL(*mockExeNetworkInternal, CreateInferRequest()).WillByDefault(Return(mockInferRequestInternal));
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockExeNetworkInternal));
plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
net = plugin.LoadNetwork({}, {});
req = net.CreateInferRequest();
}
};
TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn(1);
toReturn[0] = mockVariableStateInternal;
@ -69,7 +57,6 @@ TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToA
TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
@ -81,7 +68,6 @@ TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppTo
TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
toReturn.push_back(mockVariableStateInternal);
@ -95,7 +81,6 @@ TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAP
TEST_F(VariableStateTests, VariableStatePropagatesReset) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -109,7 +94,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesReset) {
TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -123,7 +107,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) {
TEST_F(VariableStateTests, VariableStatePropagatesGetName) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -137,7 +120,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetName) {
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -152,7 +134,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -168,7 +149,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -184,7 +164,6 @@ TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
TEST_F(VariableStateTests, VariableStateCanPropagateSetState) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
Blob::Ptr saver;
toReturn.push_back(mockVariableStateInternal);
@ -204,7 +183,6 @@ TEST_F(VariableStateTests, VariableStateCanPropagateSetState) {
TEST_F(VariableStateTests, VariableStateCanPropagateGetLastState) {
IE_SUPPRESS_DEPRECATED_START
auto net = TestPlugin{mockExeNetworkInternal}.LoadNetwork({}, {});
std::vector<IVariableStateInternal::Ptr> toReturn;
float data[] = {123, 124, 125};
@ -269,18 +247,16 @@ TEST_F(VariableStateTests, VariableStateInternalCanSaveStateByReference) {
// Tests for InferRequest::QueryState
TEST_F(VariableStateTests, InferRequestCanConvertOneVariableStateFromCppToAPI) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn(1);
toReturn[0] = mockVariableStateInternal;
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
auto state = req.QueryState();
ASSERT_EQ(state.size(), 1);
}
TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillOnce(Return(toReturn));
@ -290,23 +266,21 @@ TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI)
}
TEST_F(VariableStateTests, InferRequestCanConvert2VariableStatesFromCPPtoAPI) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
auto state = req.QueryState();
ASSERT_EQ(state.size(), 2);
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
auto state = req.QueryState();
@ -314,11 +288,10 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) {
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
auto state = req.QueryState();
@ -326,11 +299,10 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) {
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetName) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
auto state = req.QueryState();
@ -339,7 +311,6 @@ auto req = make_infer_request(mockInferRequestInternal);
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) {
IE_SUPPRESS_DEPRECATED_START
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -356,7 +327,6 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) {
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) {
IE_SUPPRESS_DEPRECATED_START
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -374,7 +344,6 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) {
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) {
IE_SUPPRESS_DEPRECATED_START
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
@ -391,7 +360,6 @@ TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) {
}
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
Blob::Ptr saver;
toReturn.push_back(mockVariableStateInternal);
@ -409,7 +377,6 @@ TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) {
}
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateGetLastState) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
float data[] = {123, 124, 125};

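A minimal usage sketch (not taken from this patch; the model path and device name are placeholders) of the behaviour the updated expectations above encode: the public InferRequest wrapper now issues a single QueryState() call into IInferRequestInternal and wraps each returned IVariableStateInternal::Ptr into a VariableState handle.

#include <inference_engine.hpp>

int main() {
    InferenceEngine::Core core;
    auto network = core.ReadNetwork("sample.xml");        // placeholder model
    auto exeNetwork = core.LoadNetwork(network, "CPU");   // placeholder device
    auto request = exeNetwork.CreateInferRequest();

    // One QueryState() call reaches the internal request; each internal state
    // becomes a VariableState handle whose methods forward to the internals.
    for (auto&& state : request.QueryState()) {
        state.Reset();                 // forwards to IVariableStateInternal::Reset()
        auto name = state.GetName();   // forwards to IVariableStateInternal::GetName()
        (void)name;
    }
    return 0;
}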
View File

@ -27,7 +27,7 @@ protected:
shared_ptr<MockInferencePluginInternal> mock_plugin_impl;
shared_ptr<MockExecutableNetworkInternal> mockExeNetworkInternal;
shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetworkTS;
shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
shared_ptr<MockIInferRequestInternal> mockInferRequestInternal;
std::shared_ptr<MockNotEmptyICNNNetwork> mockNotEmptyNet = std::make_shared<MockNotEmptyICNNNetwork>();
std::string pluginId;
@ -50,13 +50,13 @@ protected:
mockExeNetworkInternal->SetPointerToPlugin(mock_plugin_impl);
}
void getInferRequestWithMockImplInside(IInferRequest::Ptr &request) {
void getInferRequestWithMockImplInside(IInferRequestInternal::Ptr &request) {
IExecutableNetworkInternal::Ptr exeNetwork;
InputsDataMap inputsInfo;
mockNotEmptyNet->getInputsInfo(inputsInfo);
OutputsDataMap outputsInfo;
mockNotEmptyNet->getOutputsInfo(outputsInfo);
mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
mockInferRequestInternal = make_shared<MockIInferRequestInternal>(inputsInfo, outputsInfo);
mockExeNetworkTS = make_shared<MockExecutableNetworkThreadSafe>();
EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(_, _)).WillOnce(Return(mockExeNetworkTS));
EXPECT_CALL(*mockExeNetworkTS.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
@ -74,14 +74,15 @@ TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithInCorrectName) {
inBlob->allocate();
string inputName = "not_input";
std::string refError = "[ NOT_FOUND ] Failed to find input or output with name: \'" + inputName + "\'";
IInferRequest::Ptr inferRequest;
IInferRequestInternal::Ptr inferRequest;
getInferRequestWithMockImplInside(inferRequest);
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
ASSERT_EQ(StatusCode::NOT_FOUND, sts);
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << dsc.msg;
try {
inferRequest->SetBlob(inputName, inBlob);
} catch(InferenceEngine::NotFound& ex) {
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << ex.what();
}
}
TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithEmptyName) {
@ -89,56 +90,60 @@ TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithEmptyName) {
inBlob->allocate();
string inputName = "not_input";
std::string refError = "[ NOT_FOUND ] Failed to set blob with empty name";
IInferRequest::Ptr inferRequest;
IInferRequestInternal::Ptr inferRequest;
getInferRequestWithMockImplInside(inferRequest);
ASSERT_NO_THROW(sts = inferRequest->SetBlob("", inBlob, &dsc));
ASSERT_EQ(StatusCode::NOT_FOUND, sts);
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << dsc.msg;
try {
inferRequest->SetBlob("", inBlob);
} catch(InferenceEngine::NotFound& ex) {
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << ex.what();
}
}
TEST_F(InferenceEnginePluginInternalTest, failToSetNullPtr) {
string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
std::string refError = "[ NOT_ALLOCATED ] Failed to set empty blob with name: \'" + inputName + "\'";
IInferRequest::Ptr inferRequest;
IInferRequestInternal::Ptr inferRequest;
getInferRequestWithMockImplInside(inferRequest);
Blob::Ptr inBlob = nullptr;
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
ASSERT_EQ(StatusCode::NOT_ALLOCATED, sts);
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << dsc.msg;
try {
inferRequest->SetBlob(inputName, inBlob);
} catch(InferenceEngine::NotAllocated& ex) {
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << ex.what();
}
}
TEST_F(InferenceEnginePluginInternalTest, failToSetEmptyBlob) {
Blob::Ptr inBlob;
string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
std::string refError = "[ NOT_ALLOCATED ] Failed to set empty blob with name: \'" + inputName + "\'";
IInferRequest::Ptr inferRequest;
IInferRequestInternal::Ptr inferRequest;
getInferRequestWithMockImplInside(inferRequest);
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
ASSERT_EQ(StatusCode::NOT_ALLOCATED, sts);
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << dsc.msg;
try {
inferRequest->SetBlob(inputName, inBlob);
} catch(InferenceEngine::NotAllocated& ex) {
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << ex.what();
}
}
TEST_F(InferenceEnginePluginInternalTest, failToSetNotAllocatedBlob) {
string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
std::string refError = "[ NOT_ALLOCATED ] Input data was not allocated. Input name: \'" + inputName + "\'";
IInferRequest::Ptr inferRequest;
IInferRequestInternal::Ptr inferRequest;
getInferRequestWithMockImplInside(inferRequest);
Blob::Ptr blob = make_shared_blob<float>({ Precision::FP32, {}, NCHW });
ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), blob, &dsc));
ASSERT_EQ(StatusCode::NOT_ALLOCATED, sts);
ASSERT_TRUE(std::string{dsc.msg}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << dsc.msg;
try {
inferRequest->SetBlob(inputName, blob);
} catch(InferenceEngine::NotAllocated& ex) {
ASSERT_TRUE(std::string{ex.what()}.find(refError) != std::string::npos)
<< "\tExpected: " << refError
<< "\n\tActual: " << ex.what();
}
}
TEST_F(InferenceEnginePluginInternalTest, executableNetworkInternalExportsMagicAndName) {

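A small sketch (the include path and the helper name trySetBlob are illustrative assumptions, not part of this commit) of the error-handling contract these tests migrate to: IInferRequestInternal::SetBlob signals failure by throwing typed Inference Engine exceptions instead of filling a ResponseDesc and returning a StatusCode.

#include <iostream>
#include <string>

#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>  // assumed plugin API include path

// Illustrative helper: catch the typed exceptions the tests above expect.
void trySetBlob(const InferenceEngine::IInferRequestInternal::Ptr& request,
                const std::string& name,
                const InferenceEngine::Blob::Ptr& blob) {
    try {
        request->SetBlob(name, blob);
    } catch (const InferenceEngine::NotFound& ex) {
        // unknown (or empty) input/output name
        std::cerr << "NOT_FOUND: " << ex.what() << std::endl;
    } catch (const InferenceEngine::NotAllocated& ex) {
        // null, empty or unallocated blob
        std::cerr << "NOT_ALLOCATED: " << ex.what() << std::endl;
    }
}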
View File

@ -58,15 +58,6 @@ TEST(ExceptionTests, ExceptionCanBeCaughtAsStandard) {
ASSERT_THROW(IE_THROW(), std::exception);
}
TEST(ExceptionTests, CanThrowStatusCode) {
try {
IE_THROW(InferNotStarted);
}
catch (const InferenceEngine::InferNotStarted& iex) {
ASSERT_EQ(InferenceEngine::ExceptionToStatus(iex), InferenceEngine::StatusCode::INFER_NOT_STARTED);
}
}
#ifdef NDEBUG // disabled for debug as macros calls assert()
TEST(ExceptionTests, ExceptionWithAssertThrowsNothingIfTrue) {
ASSERT_NO_THROW(IE_ASSERT(true) << "shouldn't assert if true");

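A brief sketch (illustrative only) of the pattern that remains after the removed conversion test: typed exceptions raised through IE_THROW are caught by their concrete class or as std::exception, without going through ExceptionToStatus.

#include <iostream>
#include <ie_common.h>  // IE_THROW and the typed exception classes

int main() {
    try {
        IE_THROW(InferNotStarted) << "inference was not started";
    } catch (const InferenceEngine::InferNotStarted& ex) {
        std::cerr << ex.what() << std::endl;   // concrete type first...
    } catch (const std::exception& ex) {
        std::cerr << ex.what() << std::endl;   // ...or the standard base as a fallback
    }
    return 0;
}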
View File

@ -39,31 +39,22 @@ class ExecutableNetworkTests : public ::testing::Test {
protected:
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
InferenceEngine::ExecutableNetwork exeNetwork;
MockIInferencePlugin* mockIPlugin;
InferencePlugin plugin;
struct TestPluginInternal : public MockIInferencePlugin {
TestPluginInternal(const std::shared_ptr<MockIExecutableNetworkInternal>& mockIExeNet_) : mockIExeNet{mockIExeNet_} {}
std::shared_ptr<IExecutableNetworkInternal> LoadNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) override {
return mockIExeNet;
}
QueryNetworkResult QueryNetwork(const CNNNetwork&, const std::map<std::string, std::string>&) const override {
IE_THROW(NotImplemented);
}
std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet;
};
struct TestPlugin : public InferenceEngine::InferencePlugin {
TestPlugin(std::shared_ptr<MockIExecutableNetworkInternal> mockIExeNet) :
InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<TestPluginInternal>{
new TestPluginInternal{mockIExeNet}}} {}
};
virtual void TearDown() {
mockIExeNet.reset();
exeNetwork = {};
plugin = {};
}
virtual void SetUp() {
mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
exeNetwork = TestPlugin{mockIExeNet}.LoadNetwork({}, {});
std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
exeNetwork = plugin.LoadNetwork({}, {});
}
};
@ -126,7 +117,7 @@ IE_SUPPRESS_DEPRECATED_END
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
protected:
std::shared_ptr<MockIInferRequest> mockIInferReq_p;
std::shared_ptr<MockIInferRequestInternal> mockIInferReq_p;
virtual void TearDown() {
ExecutableNetworkTests::TearDown();
@ -135,7 +126,7 @@ protected:
virtual void SetUp() {
ExecutableNetworkTests::SetUp();
mockIInferReq_p = std::make_shared<MockIInferRequest>();
mockIInferReq_p = std::make_shared<MockIInferRequestInternal>();
}
};
@ -143,7 +134,6 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CanCreateInferRequest) {
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
InferRequest actualInferReq;
ASSERT_NO_THROW(actualInferReq = exeNetwork.CreateInferRequest());
ASSERT_EQ(mockIInferReq_p, static_cast<IInferRequest::Ptr &>(actualInferReq));
}
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfReturnNotOK) {
@ -153,16 +143,14 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfReturnNotOK) {
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfSetRequestToNullptr) {
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest())
.WillOnce(Return(std::shared_ptr<MockIInferRequest>{}));
.WillOnce(Return(std::shared_ptr<MockIInferRequestInternal>{}));
ASSERT_THROW(exeNetwork.CreateInferRequest(), InferenceEngine::Exception);
}
// CreateInferRequestPtr
TEST_F(ExecutableNetworkWithIInferReqTests, CanCreateInferRequestPtr) {
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
InferRequest::Ptr actualInferReq;
ASSERT_NO_THROW(actualInferReq = exeNetwork.CreateInferRequestPtr());
ASSERT_EQ(mockIInferReq_p, static_cast<IInferRequest::Ptr &>(*actualInferReq.get()));
ASSERT_NO_THROW(exeNetwork.CreateInferRequest());
}
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestPtrThrowsIfReturnNotOK) {
@ -171,7 +159,7 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestPtrThrowsIfReturnNotOK) {
}
TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestPtrThrowsIfSetRequestToNullptr) {
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(std::shared_ptr<MockIInferRequest>{}));
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(std::shared_ptr<MockIInferRequestInternal>{}));
ASSERT_THROW(exeNetwork.CreateInferRequestPtr(), InferenceEngine::Exception);
}
@ -194,9 +182,10 @@ protected:
// CreateInferRequest
TEST_F(ExecutableNetworkBaseTests, canForwardCreateInferRequest) {
auto inferReqInternal = std::make_shared<MockIInferRequestInternal>();
EXPECT_CALL(*mock_impl.get(), CreateInferRequest()).Times(1).WillRepeatedly(Return(inferReqInternal));
IInferRequest::Ptr req;
EXPECT_CALL(*mock_impl.get(), CreateInferRequest()).Times(1).WillRepeatedly(Return(req));
ASSERT_EQ(OK, exeNetwork->CreateInferRequest(req, &dsc));
ASSERT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
}
TEST_F(ExecutableNetworkBaseTests, canReportErrorInCreateInferRequest) {

View File

@ -18,6 +18,7 @@
#include "matchers/fill_with_data.hpp"
#include "matchers/weights_matcher.hpp"
#include <gmock/gmock-generated-actions.h>
#include <debug.h>
#include <gmock/gmock-more-actions.h>
#include "gmock/gmock.h"

View File

@ -1202,7 +1202,7 @@ TEST_F(MKLDNNGraphStructureTests, TestOutputAfterInplacePlusConcat) {
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
execNetwork->setNetworkInputs(_networkInputs);
execNetwork->setNetworkOutputs(_networkOutputs);
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32, {1, 3, 2, 2}, InferenceEngine::NCHW);
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>(desc);
@ -1211,8 +1211,7 @@ TEST_F(MKLDNNGraphStructureTests, TestOutputAfterInplacePlusConcat) {
InferenceEngine::ResponseDesc resp;
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data", src, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob("data", src));
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
@ -1222,11 +1221,8 @@ TEST_F(MKLDNNGraphStructureTests, TestOutputAfterInplacePlusConcat) {
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
output->allocate();
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
sts = inferRequest->Infer(&resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob(item.first, output));
ASSERT_NO_THROW(inferRequest->Infer());
compare(*output, *src);
}
@ -1717,7 +1713,7 @@ TEST_F(MKLDNNGraphStructureTests, TestResnetPart) {
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
execNetwork->setNetworkInputs(_networkInputs);
execNetwork->setNetworkOutputs(_networkOutputs);
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32, {1, 3, 224, 224}, InferenceEngine::NCHW);
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>(desc);
@ -1726,8 +1722,7 @@ TEST_F(MKLDNNGraphStructureTests, TestResnetPart) {
InferenceEngine::ResponseDesc resp;
InferenceEngine::StatusCode sts = inferRequest->SetBlob("input", src, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob("input", src));
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
@ -1737,18 +1732,16 @@ TEST_F(MKLDNNGraphStructureTests, TestResnetPart) {
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
output->allocate();
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob(item.first.c_str(), output));
sts = inferRequest->Infer(&resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->Infer());
}
TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
std::string model = R"V0G0N(
<net batch="1" name="model" version="2">
<layers>
<layer id="0" name="data" precision="FP32" type="Input">
<layer id="0" name="data1" precision="FP32" type="Input">
<output>
<port id="0">
<dim>1</dim>
@ -1866,7 +1859,7 @@ TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
execNetwork->setNetworkInputs(_networkInputs);
execNetwork->setNetworkOutputs(_networkOutputs);
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 20, 20}, InferenceEngine::NCHW);
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
@ -1885,10 +1878,9 @@ TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
InferenceEngine::ResponseDesc resp;
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data1", src1, &resp);
sts = inferRequest->SetBlob("data2", src2, &resp);
sts = inferRequest->SetBlob("data3", src3, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob("data1", src1));
ASSERT_NO_THROW(inferRequest->SetBlob("data2", src2));
ASSERT_NO_THROW(inferRequest->SetBlob("data3", src3));
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
@ -1898,11 +1890,9 @@ TEST_F(MKLDNNGraphStructureTests, TestConcatAfterConcat) {
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
output->allocate();
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob(item.first, output));
sts = inferRequest->Infer(&resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->Infer());
// compare(*output, *src);
}
@ -2046,7 +2036,7 @@ TEST_F(MKLDNNGraphStructureTests, Test2ConcatFromConcat) {
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
execNetwork->setNetworkInputs(_networkInputs);
execNetwork->setNetworkOutputs(_networkOutputs);
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 2, 2}, InferenceEngine::NCHW);
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
@ -2070,14 +2060,10 @@ TEST_F(MKLDNNGraphStructureTests, Test2ConcatFromConcat) {
InferenceEngine::ResponseDesc resp;
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data1", src1, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
sts = inferRequest->SetBlob("data2", src2, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
sts = inferRequest->SetBlob("data3", src3, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
sts = inferRequest->SetBlob("data4", src4, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob("data1", src1));
ASSERT_NO_THROW(inferRequest->SetBlob("data2", src2));
ASSERT_NO_THROW(inferRequest->SetBlob("data3", src3));
ASSERT_NO_THROW(inferRequest->SetBlob("data4", src4));
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
@ -2127,12 +2113,10 @@ TEST_F(MKLDNNGraphStructureTests, Test2ConcatFromConcat) {
}
refOutputs.push_back(refOutput);
sts = inferRequest->SetBlob(it.first.c_str(), output, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob(it.first, output));
}
sts = inferRequest->Infer(&resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->Infer());
for (size_t i = 0; i < outputs.size(); i++) {
compare(*outputs[i], *refOutputs[i]);
@ -2376,7 +2360,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithConstLayer) {
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
execNetwork->setNetworkInputs(_networkInputs);
execNetwork->setNetworkOutputs(_networkOutputs);
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 20, 20}, InferenceEngine::NCHW);
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
@ -2385,8 +2369,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithConstLayer) {
InferenceEngine::ResponseDesc resp;
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data", src1, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob("data", src1));
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
@ -2396,11 +2379,9 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithConstLayer) {
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
output->allocate();
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob(item.first.c_str(), output));
sts = inferRequest->Infer(&resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->Infer());
}
TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
@ -2523,7 +2504,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
InferenceEngine::OutputsDataMap _networkOutputs = network.getOutputsInfo();
execNetwork->setNetworkInputs(_networkInputs);
execNetwork->setNetworkOutputs(_networkOutputs);
InferenceEngine::IInferRequest::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
InferenceEngine::TensorDesc desc1(InferenceEngine::Precision::FP32, {1, 3, 20, 20}, InferenceEngine::NCHW);
InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float>(desc1);
@ -2535,8 +2516,7 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
InferenceEngine::ResponseDesc resp;
InferenceEngine::StatusCode sts = inferRequest->SetBlob("data", src1, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob("data", src1));
InferenceEngine::OutputsDataMap out = network.getOutputsInfo();
@ -2546,11 +2526,9 @@ TEST_F(MKLDNNGraphStructureTests, TestLoadTopologyWithEltwiseBeforeConcat) {
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
output->allocate();
sts = inferRequest->SetBlob(item.first.c_str(), output, &resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->SetBlob(item.first.c_str(), output));
sts = inferRequest->Infer(&resp);
ASSERT_EQ(InferenceEngine::OK, sts) << resp.msg;
ASSERT_NO_THROW(inferRequest->Infer());
auto *res_ptr = output->buffer().as<float*>();
size_t res_size = output->size();
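For orientation, a minimal sketch (the helper name runOnce, the include paths and the output-name parameter are illustrative assumptions) of the calling convention the MKLDNN graph tests above migrate to: blobs are set and inference is run directly on the internal request, with errors surfacing as exceptions rather than StatusCode/ResponseDesc pairs.

#include <string>

#include <cpp_interfaces/interface/ie_iexecutable_network_internal.hpp>  // assumed include paths
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>

// Illustrative sketch of the migrated test flow: no ResponseDesc, no StatusCode checks.
void runOnce(const InferenceEngine::IExecutableNetworkInternal::Ptr& execNetwork,
             const InferenceEngine::Blob::Ptr& src,
             const std::string& outputName,
             const InferenceEngine::Blob::Ptr& output) {
    InferenceEngine::IInferRequestInternal::Ptr inferRequest = execNetwork->CreateInferRequest();
    inferRequest->SetBlob("data", src);         // throws on an unknown name or a bad blob
    inferRequest->SetBlob(outputName, output);  // the output blob is allocated by the caller
    inferRequest->Infer();                      // synchronous; failures propagate as exceptions
}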