Move QueryState from ExecutableNetwork to InferRequest (#2818)
* QueryState moved to InferRequest * deprecate ExecutableNetwork::QueryState,chaged tests (without any check yet) * fix build * review fixes + build fix * build fix + review changes * remove blank line * style fixes * test build fixes * style fix * style fix * fixed build of tests * fix * mac build fix * hddl plugin build fix * clean up unneeded implementation for method * fixed tests build * add implementation for getstate, correct getName for MklDNN * fixed description of state API in comments * lint fixes * Rename MemoryState to VariableState * added tests for cpu for VariableStates, several small fixes in tests and code * merge fix * lint fix * remove whitespaces * spaces fix * fix in test to make it workable for all plugins * fix typo * fix test for gna * remove extra comment * fix test for gna
This commit is contained in:
parent
809c504d0a
commit
7bd76dc12b
@ -175,11 +175,13 @@ public:
|
||||
* Wraps IExecutableNetwork::QueryState
|
||||
* @return A vector of Memory State objects
|
||||
*/
|
||||
std::vector<MemoryState> QueryState() {
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
std::vector<VariableState> QueryState() {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized.";
|
||||
IMemoryState::Ptr pState = nullptr;
|
||||
IVariableState::Ptr pState = nullptr;
|
||||
auto res = OK;
|
||||
std::vector<MemoryState> controller;
|
||||
std::vector<VariableState> controller;
|
||||
for (size_t idx = 0; res == OK; ++idx) {
|
||||
ResponseDesc resp;
|
||||
res = actual->QueryState(pState, idx, &resp);
|
||||
@ -187,10 +189,11 @@ public:
|
||||
THROW_IE_EXCEPTION << resp.msg;
|
||||
}
|
||||
if (res != OUT_OF_BOUNDS) {
|
||||
controller.push_back(MemoryState(pState));
|
||||
controller.push_back(VariableState(pState, plg));
|
||||
}
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
return controller;
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "cpp/ie_memory_state.hpp"
|
||||
#include "ie_iinfer_request.hpp"
|
||||
#include "details/ie_exception_conversion.hpp"
|
||||
#include "details/ie_so_loader.h"
|
||||
@ -250,6 +251,31 @@ public:
|
||||
actual->SetCompletionCallback(callWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* @copybrief IExecutableNetwork::QueryState
|
||||
*
|
||||
* Wraps IExecutableNetwork::QueryState
|
||||
* @return A vector of Memory State objects
|
||||
*/
|
||||
std::vector<VariableState> QueryState() {
|
||||
if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized.";
|
||||
IVariableState::Ptr pState = nullptr;
|
||||
auto res = OK;
|
||||
std::vector<VariableState> controller;
|
||||
for (size_t idx = 0; res == OK; ++idx) {
|
||||
ResponseDesc resp;
|
||||
res = actual->QueryState(pState, idx, &resp);
|
||||
if (res != OK && res != OUT_OF_BOUNDS) {
|
||||
THROW_IE_EXCEPTION << resp.msg;
|
||||
}
|
||||
if (res != OUT_OF_BOUNDS) {
|
||||
controller.push_back(VariableState(pState, plg));
|
||||
}
|
||||
}
|
||||
|
||||
return controller;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief IInferRequest pointer to be used directly in CreateInferRequest functions
|
||||
* @return A shared pointer to underlying IInferRequest interface
|
||||
|
@ -11,39 +11,42 @@
|
||||
#include <string>
|
||||
|
||||
#include "ie_imemory_state.hpp"
|
||||
#include "details/ie_exception_conversion.hpp"
|
||||
#include "details/ie_so_loader.h"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
/**
|
||||
* @brief C++ exception based error reporting wrapper of API class IMemoryState
|
||||
* @brief C++ exception based error reporting wrapper of API class IVariableState
|
||||
*/
|
||||
class MemoryState {
|
||||
IMemoryState::Ptr actual = nullptr;
|
||||
class VariableState {
|
||||
IVariableState::Ptr actual = nullptr;
|
||||
details::SharedObjectLoader::Ptr plugin = {};
|
||||
|
||||
public:
|
||||
/**
|
||||
* constructs MemoryState from the initialized shared_pointer
|
||||
* constructs VariableState from the initialized shared_pointer
|
||||
* @param pState Initialized shared pointer
|
||||
*/
|
||||
explicit MemoryState(IMemoryState::Ptr pState): actual(pState) {
|
||||
explicit VariableState(IVariableState::Ptr pState, details::SharedObjectLoader::Ptr plg = {}) : actual(pState), plugin(plg) {
|
||||
if (actual == nullptr) {
|
||||
THROW_IE_EXCEPTION << "MemoryState wrapper was not initialized.";
|
||||
THROW_IE_EXCEPTION << "VariableState wrapper was not initialized.";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @copybrief IMemoryState::Reset
|
||||
* @copybrief IVariableState::Reset
|
||||
*
|
||||
* Wraps IMemoryState::Reset
|
||||
* Wraps IVariableState::Reset
|
||||
*/
|
||||
void Reset() {
|
||||
CALL_STATUS_FNC_NO_ARGS(Reset);
|
||||
}
|
||||
|
||||
/**
|
||||
* @copybrief IMemoryState::GetName
|
||||
* @copybrief IVariableState::GetName
|
||||
*
|
||||
* Wraps IMemoryState::GetName
|
||||
* Wraps IVariableState::GetName
|
||||
* @return A string representing a state name
|
||||
*/
|
||||
std::string GetName() const {
|
||||
@ -53,21 +56,26 @@ public:
|
||||
}
|
||||
|
||||
/**
|
||||
* @copybrief IMemoryState::GetLastState
|
||||
* @copybrief IVariableState::GetState
|
||||
*
|
||||
* Wraps IMemoryState::GetLastState
|
||||
* Wraps IVariableState::GetState
|
||||
* @return A blob representing a last state
|
||||
*/
|
||||
Blob::CPtr GetLastState() const {
|
||||
Blob::CPtr GetState() const {
|
||||
Blob::CPtr stateBlob;
|
||||
CALL_STATUS_FNC(GetLastState, stateBlob);
|
||||
CALL_STATUS_FNC(GetState, stateBlob);
|
||||
return stateBlob;
|
||||
}
|
||||
|
||||
INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
|
||||
Blob::CPtr GetLastState() const {
|
||||
return GetState();
|
||||
}
|
||||
|
||||
/**
|
||||
* @copybrief IMemoryState::SetState
|
||||
* @copybrief IVariableState::SetState
|
||||
*
|
||||
* Wraps IMemoryState::SetState
|
||||
* Wraps IVariableState::SetState
|
||||
* @param state The current state to set
|
||||
*/
|
||||
void SetState(Blob::Ptr state) {
|
||||
@ -75,4 +83,8 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
/*
|
||||
* @brief For compatibility reasons.
|
||||
*/
|
||||
using MemoryState = VariableState;
|
||||
} // namespace InferenceEngine
|
||||
|
@ -118,7 +118,7 @@ public:
|
||||
* @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for
|
||||
* given index
|
||||
*/
|
||||
virtual StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
|
||||
virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets configuration for current executable network
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "ie_blob.h"
|
||||
#include "ie_common.h"
|
||||
#include "ie_preprocess.hpp"
|
||||
#include "ie_imemory_state.hpp"
|
||||
#include "details/ie_irelease.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
@ -177,6 +178,18 @@ public:
|
||||
* @return Enumeration of the resulted action: InferenceEngine::OK (0) for success
|
||||
*/
|
||||
virtual InferenceEngine::StatusCode SetBatch(int batch_size, ResponseDesc* resp) noexcept = 0;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
/**
|
||||
* @brief Gets state control interface for given infer request.
|
||||
*
|
||||
* State control essential for recurrent networks
|
||||
*
|
||||
* @param pState reference to a pointer that receives internal states
|
||||
* @param idx requested index for receiving memory state
|
||||
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
|
||||
* @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for
|
||||
* given index
|
||||
*/
|
||||
virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
|
||||
};
|
||||
} // namespace InferenceEngine
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief a header file for IMemoryState interface
|
||||
* @brief a header file for IVariableState interface
|
||||
*
|
||||
* @file ie_imemory_state.hpp
|
||||
*/
|
||||
@ -19,19 +19,19 @@
|
||||
namespace InferenceEngine {
|
||||
|
||||
/**
|
||||
* @interface IMemoryState
|
||||
* @interface IVariableState
|
||||
* @brief manages data for reset operations
|
||||
*/
|
||||
class IMemoryState : public details::no_copy {
|
||||
class IVariableState : public details::no_copy {
|
||||
public:
|
||||
/**
|
||||
* @brief A shared pointer to the IMemoryState interface
|
||||
* @brief A shared pointer to the IVariableState interface
|
||||
*/
|
||||
using Ptr = std::shared_ptr<IMemoryState>;
|
||||
using Ptr = std::shared_ptr<IVariableState>;
|
||||
|
||||
/**
|
||||
* @brief Gets name of current memory state, if length of array is not enough name is truncated by len, null
|
||||
* terminator is inserted as well.
|
||||
* terminator is inserted as well. As memory state name variable_id from according ReadValue used.
|
||||
*
|
||||
* @param name preallocated buffer for receiving name
|
||||
* @param len Length of the buffer
|
||||
@ -41,7 +41,7 @@ public:
|
||||
virtual StatusCode GetName(char* name, size_t len, ResponseDesc* resp) const noexcept = 0;
|
||||
|
||||
/**
|
||||
* @brief reset internal memory state for relevant iexecutable network, to a value specified in SetState
|
||||
* @brief Reset internal memory state for relevant infer request, to a value specified as default for according ReadValue node
|
||||
*
|
||||
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
|
||||
* @return Status code of the operation: InferenceEngine::OK (0) for success*
|
||||
@ -49,25 +49,30 @@ public:
|
||||
virtual StatusCode Reset(ResponseDesc* resp) noexcept = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets the new state that is used for all future Reset() operations as a base.
|
||||
* @brief Sets the new state for the next inference.
|
||||
*
|
||||
* This method can fail if Blob size does not match the internal state size or precision
|
||||
*
|
||||
* @param newState is the data to use as base state
|
||||
* @param newState is the data to use as new state
|
||||
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
|
||||
* @return Status code of the operation: InferenceEngine::OK (0) for success
|
||||
*/
|
||||
virtual StatusCode SetState(Blob::Ptr newState, ResponseDesc* resp) noexcept = 0;
|
||||
|
||||
/**
|
||||
* @brief returns the value of the last memory state.
|
||||
* @brief Returns the value of the memory state.
|
||||
*
|
||||
* @details Since we roll memory after each infer, we can query the input state always and still get the last state.
|
||||
* @param lastState
|
||||
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
|
||||
* @return Status code of the operation: InferenceEngine::OK (0) for success
|
||||
* */
|
||||
virtual StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept = 0;
|
||||
INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
|
||||
virtual StatusCode GetLastState(Blob::CPtr& state, ResponseDesc* resp) const noexcept {return GetState(state, resp);}
|
||||
virtual StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept = 0;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief For compatibility reasons.
|
||||
*/
|
||||
using IMemoryState = IVariableState;
|
||||
} // namespace InferenceEngine
|
@ -845,7 +845,7 @@ int main(int argc, char *argv[]) {
|
||||
ptrUtterances.resize(inputArkFiles.size());
|
||||
|
||||
// initialize memory state before starting
|
||||
for (auto &&state : executableNet.QueryState()) {
|
||||
for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
|
||||
state.Reset();
|
||||
}
|
||||
|
||||
@ -1080,7 +1080,7 @@ int main(int argc, char *argv[]) {
|
||||
totalTime += d.count();
|
||||
|
||||
// resetting state between utterances
|
||||
for (auto &&state : executableNet.QueryState()) {
|
||||
for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
|
||||
state.Reset();
|
||||
}
|
||||
|
||||
|
@ -59,12 +59,13 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
|
||||
return std::make_shared<GNAInferRequest>(plg, networkInputs, networkOutputs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> QueryState() override {
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto pluginStates = plg->QueryState();
|
||||
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
|
||||
return plg->QueryState();
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void Export(const std::string &modelFileName) override {
|
||||
|
@ -111,5 +111,13 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
|
||||
}
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
||||
auto pluginStates = plg->QueryState();
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
|
||||
return plg->QueryState();
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
};
|
||||
} // namespace GNAPluginNS
|
||||
|
@ -1186,11 +1186,11 @@ Blob::Ptr GNAPlugin::GetInputBlob(const std::string& name, InferenceEngine::Prec
|
||||
return inputBlob;
|
||||
}
|
||||
|
||||
std::vector<InferenceEngine::MemoryStateInternal::Ptr> GNAPlugin::QueryState() {
|
||||
std::vector<InferenceEngine::VariableStateInternal::Ptr> GNAPlugin::QueryState() {
|
||||
if (memoryStates.size() != graphCompiler.memory_connection.size()) {
|
||||
memoryStates.clear();
|
||||
for (auto& connection : graphCompiler.memory_connection) {
|
||||
auto state = std::make_shared<memory::GNAMemoryState>(connection.first, std::make_shared <GNAMemoryLayer>(connection.second));
|
||||
auto state = std::make_shared<memory::GNAVariableState>(connection.first, std::make_shared <GNAMemoryLayer>(connection.second));
|
||||
memoryStates.emplace_back(state);
|
||||
}
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
|
||||
|
||||
InferenceEngine::InputsDataMap inputsDataMap;
|
||||
InferenceEngine::OutputsDataMap outputsDataMap;
|
||||
std::vector<InferenceEngine::MemoryStateInternal::Ptr> memoryStates;
|
||||
std::vector<InferenceEngine::VariableStateInternal::Ptr> memoryStates;
|
||||
|
||||
public:
|
||||
explicit GNAPlugin(const std::map<std::string, std::string>& configMap);
|
||||
@ -159,7 +159,8 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
|
||||
* QueryState API
|
||||
* @return
|
||||
*/
|
||||
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> QueryState();
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState();
|
||||
|
||||
/**
|
||||
* test-wise API
|
||||
|
@ -12,15 +12,15 @@ namespace GNAPluginNS {
|
||||
|
||||
namespace memory {
|
||||
|
||||
std::string GNAMemoryState::GetName() const {
|
||||
std::string GNAVariableState::GetName() const {
|
||||
return name;
|
||||
}
|
||||
|
||||
void GNAMemoryState::Reset() {
|
||||
void GNAVariableState::Reset() {
|
||||
state->Reset();
|
||||
}
|
||||
|
||||
InferenceEngine::Precision GNAMemoryState::getPrecision() const {
|
||||
InferenceEngine::Precision GNAVariableState::getPrecision() const {
|
||||
InferenceEngine::Precision state_precision;
|
||||
|
||||
if (state->getInput()) {
|
||||
@ -36,14 +36,14 @@ namespace memory {
|
||||
break;
|
||||
default:
|
||||
THROW_GNA_EXCEPTION << "Incorrect state element size " << element_size <<
|
||||
" to determine precision for MemoryState " << name;
|
||||
" to determine precision for VariableState " << name;
|
||||
}
|
||||
}
|
||||
|
||||
return state_precision;
|
||||
}
|
||||
|
||||
void GNAMemoryState::SetState(InferenceEngine::Blob::Ptr newState) {
|
||||
void GNAVariableState::SetState(InferenceEngine::Blob::Ptr newState) {
|
||||
IE_ASSERT(newState != nullptr);
|
||||
|
||||
auto data_ptr = newState->cbuffer().as<void*>();
|
||||
@ -78,20 +78,20 @@ namespace memory {
|
||||
data_elements,
|
||||
scale_factor);
|
||||
} else {
|
||||
THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name
|
||||
THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name
|
||||
<< ". If old state precision is I16 only I16 and FP32 are allowed as new state precisions."
|
||||
<< " Old state: " << state_precision << " New state: " << new_state_precision;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name
|
||||
THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name
|
||||
<< ". Incorrect new/old precision pair"
|
||||
<< " Old state: " << state_precision << " New state: " << new_state_precision;
|
||||
}
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::CPtr GNAMemoryState::GetLastState() const {
|
||||
InferenceEngine::Blob::CPtr GNAVariableState::GetState() const {
|
||||
auto elements = state->reserved_size / state->elementSizeBytes();
|
||||
InferenceEngine::Precision state_precision = getPrecision();
|
||||
|
||||
|
@ -11,14 +11,14 @@
|
||||
|
||||
namespace GNAPluginNS {
|
||||
namespace memory {
|
||||
class GNAMemoryState : public InferenceEngine::IMemoryStateInternal {
|
||||
class GNAVariableState : public InferenceEngine::IVariableStateInternal {
|
||||
public:
|
||||
GNAMemoryState(std::string name, std::shared_ptr<GNAMemoryLayer> state)
|
||||
GNAVariableState(std::string name, std::shared_ptr<GNAMemoryLayer> state)
|
||||
: name(name), state(state) { IE_ASSERT(state != nullptr); }
|
||||
|
||||
void Reset() override;
|
||||
void SetState(InferenceEngine::Blob::Ptr newState) override;
|
||||
InferenceEngine::Blob::CPtr GetLastState() const override;
|
||||
InferenceEngine::Blob::CPtr GetState() const override;
|
||||
std::string GetName() const override;
|
||||
|
||||
private:
|
||||
|
@ -183,14 +183,14 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::ICNNNetwork &network
|
||||
if (node->getType() == MemoryInput) {
|
||||
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
|
||||
auto state_store = memoryNode->getStore();
|
||||
auto state_name = node->getName();
|
||||
auto state_name = memoryNode->getId();
|
||||
|
||||
// Remove suffix with pair ID. Internal information.
|
||||
auto suffix_idx = state_name.find("/id=");
|
||||
if (suffix_idx != std::string::npos)
|
||||
state_name = state_name.substr(0, suffix_idx);
|
||||
|
||||
memoryStates.emplace_back(new MKLDNNMemoryState(state_name, state_store));
|
||||
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -314,6 +314,8 @@ bool MKLDNNExecNetwork::CanProcessDynBatch(const InferenceEngine::ICNNNetwork &n
|
||||
return check_result;
|
||||
}
|
||||
|
||||
std::vector<IMemoryStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
|
||||
return memoryStates;
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
@ -42,14 +42,15 @@ public:
|
||||
|
||||
InferenceEngine::CNNNetwork GetExecGraphInfo() override;
|
||||
|
||||
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> QueryState() override;
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
|
||||
|
||||
InferenceEngine::ThreadLocal<MKLDNNGraph::Ptr> _graphs;
|
||||
|
||||
protected:
|
||||
friend class MKLDNNInferRequest;
|
||||
MKLDNNExtensionManager::Ptr extensionManager;
|
||||
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> memoryStates;
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
|
||||
InferenceEngine::details::CNNNetworkImplPtr _clonedNetwork;
|
||||
std::mutex _cfgMutex;
|
||||
Config _cfg;
|
||||
|
@ -14,6 +14,8 @@
|
||||
#include "mkldnn_exec_network.h"
|
||||
#include "mkldnn_itt.h"
|
||||
#include "nodes/common/cpu_convert.h"
|
||||
#include "mkldnn_memory_state.h"
|
||||
#include "nodes/mkldnn_memory_node.hpp"
|
||||
|
||||
MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs,
|
||||
InferenceEngine::OutputsDataMap networkOutputs,
|
||||
@ -35,6 +37,30 @@ MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsData
|
||||
InferenceEngine::Blob::Ptr blob;
|
||||
MKLDNNInferRequest::GetBlob(it.first.c_str(), blob);
|
||||
}
|
||||
|
||||
// Save all MemoryLayer data tensors. Will use insight about mechanics
|
||||
// of MemoryLayer implementation. It uses output edge of MemoryLayer
|
||||
// producer as storage for tensor to keep it between infer calls.
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
if (execNetwork->QueryState().size() == 0) {
|
||||
for (auto &node : graph->GetNodes()) {
|
||||
if (node->getType() == MemoryInput) {
|
||||
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
|
||||
auto state_store = memoryNode->getStore();
|
||||
auto state_name = memoryNode->getId();
|
||||
|
||||
// Remove suffix with pair ID. Internal information.
|
||||
auto suffix_idx = state_name.find("/id=");
|
||||
if (suffix_idx != std::string::npos)
|
||||
state_name = state_name.substr(0, suffix_idx);
|
||||
|
||||
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
memoryStates = execNetwork->QueryState();
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() {
|
||||
@ -390,3 +416,7 @@ void MKLDNNPlugin::MKLDNNInferRequest::SetBatch(int new_batch) {
|
||||
|
||||
m_curBatch = new_batch;
|
||||
}
|
||||
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MKLDNNPlugin::MKLDNNInferRequest::QueryState() {
|
||||
return memoryStates;
|
||||
}
|
||||
|
@ -43,6 +43,8 @@ public:
|
||||
|
||||
void SetBatch(int batch = -1) override;
|
||||
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
|
||||
|
||||
private:
|
||||
void PushInputData();
|
||||
|
||||
@ -53,5 +55,6 @@ private:
|
||||
MKLDNNGraph* graph = nullptr;
|
||||
std::map<std::string, void*> externalPtr;
|
||||
openvino::itt::handle_t profilingTask;
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
|
||||
};
|
||||
} // namespace MKLDNNPlugin
|
||||
|
@ -4,20 +4,21 @@
|
||||
|
||||
#include "mkldnn_memory_state.h"
|
||||
#include "mkldnn_extension_utils.h"
|
||||
#include "blob_factory.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace MKLDNNPlugin {
|
||||
|
||||
std::string MKLDNNMemoryState::GetName() const {
|
||||
std::string MKLDNNVariableState::GetName() const {
|
||||
return name;
|
||||
}
|
||||
|
||||
void MKLDNNMemoryState::Reset() {
|
||||
void MKLDNNVariableState::Reset() {
|
||||
storage->FillZero();
|
||||
}
|
||||
|
||||
void MKLDNNMemoryState::SetState(Blob::Ptr newState) {
|
||||
void MKLDNNVariableState::SetState(Blob::Ptr newState) {
|
||||
auto prec = newState->getTensorDesc().getPrecision();
|
||||
auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(prec);
|
||||
auto data_layout = MKLDNNMemory::Convert(newState->getTensorDesc().getLayout());
|
||||
@ -27,9 +28,11 @@ void MKLDNNMemoryState::SetState(Blob::Ptr newState) {
|
||||
storage->SetData(data_type, data_layout, data_ptr, data_size);
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::CPtr MKLDNNMemoryState::GetLastState() const {
|
||||
THROW_IE_EXCEPTION << "GetLastState method is not implemented for MemoryState";
|
||||
return nullptr;
|
||||
InferenceEngine::Blob::CPtr MKLDNNVariableState::GetState() const {
|
||||
auto result_blob = make_blob_with_precision(MKLDNNMemoryDesc(storage->GetDescriptor()));
|
||||
result_blob->allocate();
|
||||
std::memcpy(result_blob->buffer(), storage->GetData(), storage->GetSize());
|
||||
return result_blob;
|
||||
}
|
||||
|
||||
} // namespace MKLDNNPlugin
|
||||
} // namespace MKLDNNPlugin
|
||||
|
@ -11,15 +11,15 @@
|
||||
|
||||
namespace MKLDNNPlugin {
|
||||
|
||||
class MKLDNNMemoryState : public InferenceEngine::IMemoryStateInternal {
|
||||
class MKLDNNVariableState : public InferenceEngine::IVariableStateInternal {
|
||||
public:
|
||||
MKLDNNMemoryState(std::string name, MKLDNNMemoryPtr storage) :
|
||||
MKLDNNVariableState(std::string name, MKLDNNMemoryPtr storage) :
|
||||
name(name), storage(storage) {}
|
||||
|
||||
std::string GetName() const override;
|
||||
void Reset() override;
|
||||
void SetState(InferenceEngine::Blob::Ptr newState) override;
|
||||
InferenceEngine::Blob::CPtr GetLastState() const override;
|
||||
InferenceEngine::Blob::CPtr GetState() const override;
|
||||
|
||||
private:
|
||||
std::string name;
|
||||
|
@ -66,19 +66,22 @@ public:
|
||||
TO_STATUS(graphPtr = _impl->GetExecGraphInfo());
|
||||
}
|
||||
|
||||
StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
try {
|
||||
auto v = _impl->QueryState();
|
||||
if (idx >= v.size()) {
|
||||
return OUT_OF_BOUNDS;
|
||||
}
|
||||
pState = std::make_shared<MemoryStateBase<IMemoryStateInternal>>(v[idx]);
|
||||
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
|
||||
return OK;
|
||||
} catch (const std::exception& ex) {
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
||||
} catch (...) {
|
||||
return InferenceEngine::DescriptionBuffer(UNEXPECTED);
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void Release() noexcept override {
|
||||
|
@ -10,6 +10,7 @@
|
||||
|
||||
#include "cpp_interfaces/exception2status.hpp"
|
||||
#include "cpp_interfaces/plugin_itt.hpp"
|
||||
#include <cpp_interfaces/base/ie_memory_state_base.hpp>
|
||||
#include "ie_iinfer_request.hpp"
|
||||
#include "ie_preprocess.hpp"
|
||||
#include "ie_profiling.hpp"
|
||||
@ -88,6 +89,21 @@ public:
|
||||
TO_STATUS(_impl->SetBatch(batch_size));
|
||||
}
|
||||
|
||||
StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
|
||||
try {
|
||||
auto v = _impl->QueryState();
|
||||
if (idx >= v.size()) {
|
||||
return OUT_OF_BOUNDS;
|
||||
}
|
||||
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
|
||||
return OK;
|
||||
} catch (const std::exception& ex) {
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
||||
} catch (...) {
|
||||
return InferenceEngine::DescriptionBuffer(UNEXPECTED);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
~InferRequestBase() = default;
|
||||
};
|
||||
|
@ -7,23 +7,24 @@
|
||||
#include <memory>
|
||||
|
||||
#include "cpp_interfaces/exception2status.hpp"
|
||||
#include "cpp_interfaces/impl/ie_memory_state_internal.hpp"
|
||||
#include "ie_imemory_state.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
/**
|
||||
* @brief default implementation for IMemoryState
|
||||
* @brief default implementation for IVariableState
|
||||
* @ingroup ie_dev_api_mem_state_api
|
||||
*/
|
||||
template <class T>
|
||||
class MemoryStateBase : public IMemoryState {
|
||||
class VariableStateBase : public IVariableState {
|
||||
protected:
|
||||
std::shared_ptr<T> impl;
|
||||
|
||||
public:
|
||||
explicit MemoryStateBase(std::shared_ptr<T> impl): impl(impl) {
|
||||
explicit VariableStateBase(std::shared_ptr<T> impl): impl(impl) {
|
||||
if (impl == nullptr) {
|
||||
THROW_IE_EXCEPTION << "MemoryStateBase implementation not defined";
|
||||
THROW_IE_EXCEPTION << "VariableStateBase implementation not defined";
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,9 +45,9 @@ public:
|
||||
TO_STATUS(impl->SetState(newState));
|
||||
}
|
||||
|
||||
StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept override {
|
||||
TO_STATUS(lastState = impl->GetLastState());
|
||||
StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept override {
|
||||
TO_STATUS(state = impl->GetState());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
} // namespace InferenceEngine
|
||||
|
@ -88,7 +88,7 @@ public:
|
||||
_plugin = plugin;
|
||||
}
|
||||
|
||||
std::vector<IMemoryStateInternal::Ptr> QueryState() override {
|
||||
std::vector<IVariableStateInternal::Ptr> QueryState() override {
|
||||
THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
|
||||
}
|
||||
|
||||
|
@ -152,6 +152,10 @@ public:
|
||||
_publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](IInferRequest*) {});
|
||||
}
|
||||
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
||||
return _syncRequest->QueryState();
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation
|
||||
|
@ -223,6 +223,12 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<IVariableStateInternal::Ptr> QueryState() override {
|
||||
// meaning base plugin reports as no state available - plugin owners need to create proper override of this
|
||||
THROW_IE_EXCEPTION << "Plugin doesn't override QueryState";
|
||||
return {};
|
||||
}
|
||||
|
||||
protected:
|
||||
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
|
||||
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data
|
||||
|
@ -13,21 +13,25 @@ namespace InferenceEngine {
|
||||
* @brief minimal interface for memory state implementation
|
||||
* @ingroup ie_dev_api_mem_state_api
|
||||
*/
|
||||
class MemoryStateInternal : public IMemoryStateInternal {
|
||||
class VariableStateInternal : public IVariableStateInternal {
|
||||
std::string name;
|
||||
Blob::Ptr state;
|
||||
|
||||
public:
|
||||
explicit MemoryStateInternal(std::string name): name(name) {}
|
||||
explicit VariableStateInternal(std::string name): name(name) {}
|
||||
std::string GetName() const override {
|
||||
return name;
|
||||
}
|
||||
void SetState(Blob::Ptr newState) override {
|
||||
state = newState;
|
||||
}
|
||||
Blob::CPtr GetLastState() const override {
|
||||
Blob::CPtr GetState() const override {
|
||||
return state;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
/*
|
||||
* @brief For compatibility reasons.
|
||||
*/
|
||||
using MemoryStateInternal = VariableStateInternal;
|
||||
} // namespace InferenceEngine
|
||||
|
@ -79,7 +79,7 @@ public:
|
||||
* @brief Queries memory states.
|
||||
* @return Returns memory states
|
||||
*/
|
||||
virtual std::vector<IMemoryStateInternal::Ptr> QueryState() = 0;
|
||||
virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
|
||||
|
||||
/**
|
||||
* @brief Sets configuration for current executable network
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
|
||||
#include <ie_blob.h>
|
||||
#include <ie_common.h>
|
||||
#include <ie_preprocess.hpp>
|
||||
@ -83,6 +84,12 @@ public:
|
||||
* @param batch - new batch size to be used by all the following inference calls for this request.
|
||||
*/
|
||||
virtual void SetBatch(int batch) = 0;
|
||||
|
||||
/**
|
||||
* @brief Queries memory states.
|
||||
* @return Returns memory states
|
||||
*/
|
||||
virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
@ -11,19 +11,25 @@
|
||||
|
||||
namespace InferenceEngine {
|
||||
/**
|
||||
* @interface IMemoryStateInternal
|
||||
* @interface IVariableStateInternal
|
||||
* @brief minimal interface for memory state implementation
|
||||
* @ingroup ie_dev_api_mem_state_api
|
||||
*/
|
||||
class IMemoryStateInternal {
|
||||
class IVariableStateInternal {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<IMemoryStateInternal>;
|
||||
using Ptr = std::shared_ptr<IVariableStateInternal>;
|
||||
|
||||
virtual ~IMemoryStateInternal() = default;
|
||||
virtual ~IVariableStateInternal() = default;
|
||||
virtual std::string GetName() const = 0;
|
||||
virtual void Reset() = 0;
|
||||
virtual void SetState(Blob::Ptr newState) = 0;
|
||||
virtual Blob::CPtr GetLastState() const = 0;
|
||||
virtual Blob::CPtr GetState() const = 0;
|
||||
INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
|
||||
virtual Blob::CPtr GetLastState() const {return GetState();}
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief For compatibility reasons.
|
||||
*/
|
||||
using IMemoryStateInternal = IVariableStateInternal;
|
||||
} // namespace InferenceEngine
|
||||
|
@ -83,3 +83,8 @@ TEST(InferRequestCPPTests, throwsOnUninitializedCast) {
|
||||
InferRequest req;
|
||||
ASSERT_THROW(auto &ireq = static_cast<IInferRequest::Ptr &>(req), InferenceEngine::details::InferenceEngineException);
|
||||
}
|
||||
|
||||
TEST(InferRequestCPPTests, throwsOnUninitializedQueryState) {
|
||||
InferRequest req;
|
||||
ASSERT_THROW(req.QueryState(), InferenceEngine::details::InferenceEngineException);
|
||||
}
|
||||
|
@ -46,8 +46,10 @@ TEST(ExecutableNetworkTests, throwsOnUninitializedGetExecGraphInfo) {
|
||||
}
|
||||
|
||||
TEST(ExecutableNetworkTests, throwsOnUninitializedQueryState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
ExecutableNetwork exec;
|
||||
ASSERT_THROW(exec.QueryState(), InferenceEngine::details::InferenceEngineException);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) {
|
||||
|
@ -10,12 +10,15 @@ namespace {
|
||||
// 0 - plugin
|
||||
// 1 - executable_network
|
||||
// 2 - infer_request
|
||||
{0, 1, 2},
|
||||
{0, 2, 1},
|
||||
{1, 0, 2},
|
||||
{1, 2, 0},
|
||||
{2, 0, 1},
|
||||
{2, 1, 0}
|
||||
// 3 - variable state
|
||||
{3, 0, 1, 2},
|
||||
{3, 0, 2, 1},
|
||||
{3, 1, 0, 2},
|
||||
{3, 1, 2, 0},
|
||||
{3, 2, 0, 1},
|
||||
{3, 2, 1, 0},
|
||||
{0, 3, 1, 2},
|
||||
{0, 1, 3, 2}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest,
|
||||
@ -24,4 +27,4 @@ namespace {
|
||||
::testing::ValuesIn(orders)),
|
||||
HoldersTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
@ -0,0 +1,22 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <common_test_utils/test_constants.hpp>
|
||||
#include "behavior/memory_states.hpp"
|
||||
#include "functional_test_utils/test_model/test_model.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
|
||||
InferenceEngine::CNNNetwork getNetwork() {
|
||||
auto model = FuncTestUtils::TestModel::getModelWithMultipleMemoryConnections(InferenceEngine::Precision::FP32);
|
||||
auto ie = PluginCache::get().ie();
|
||||
return ie->ReadNetwork(model.model_xml_str, model.weights_blob);
|
||||
}
|
||||
std::vector<memoryStateParams> memoryStateTestCases = {
|
||||
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_CPU)
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest,
|
||||
::testing::ValuesIn(memoryStateTestCases),
|
||||
VariableStateTest::getTestCaseName);
|
@ -10,12 +10,15 @@ namespace {
|
||||
// 0 - plugin
|
||||
// 1 - executable_network
|
||||
// 2 - infer_request
|
||||
{0, 1, 2},
|
||||
{0, 2, 1},
|
||||
{1, 0, 2},
|
||||
{1, 2, 0},
|
||||
{2, 0, 1},
|
||||
{2, 1, 0}
|
||||
// 3 - variable state
|
||||
{3, 0, 1, 2},
|
||||
{3, 0, 2, 1},
|
||||
{3, 1, 0, 2},
|
||||
{3, 1, 2, 0},
|
||||
{3, 2, 0, 1},
|
||||
{3, 2, 1, 0},
|
||||
{0, 3, 1, 2},
|
||||
{0, 1, 3, 2}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest,
|
||||
|
@ -17,6 +17,6 @@ std::vector<memoryStateParams> memoryStateTestCases = {
|
||||
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_GNA)
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_MemoryStateBasic, MemoryStateTest,
|
||||
INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest,
|
||||
::testing::ValuesIn(memoryStateTestCases),
|
||||
MemoryStateTest::getTestCaseName);
|
||||
VariableStateTest::getTestCaseName);
|
||||
|
@ -14,7 +14,7 @@ typedef std::tuple<
|
||||
std::string> // Target device name
|
||||
memoryStateParams;
|
||||
|
||||
class MemoryStateTest : public CommonTestUtils::TestsCommon,
|
||||
class VariableStateTest : public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<memoryStateParams> {
|
||||
protected:
|
||||
InferenceEngine::CNNNetwork net;
|
||||
|
@ -25,7 +25,11 @@ namespace BehaviorTestsDefinitions {
|
||||
if (deathTestStyle == "fast") {
|
||||
::testing::GTEST_FLAG(death_test_style) = "threadsafe";
|
||||
}
|
||||
function = ngraph::builder::subgraph::makeConvPoolRelu();
|
||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
||||
function = ngraph::builder::subgraph::makeReadConcatSplitAssign();
|
||||
} else {
|
||||
function = ngraph::builder::subgraph::makeConvPoolRelu();
|
||||
}
|
||||
}
|
||||
|
||||
void HoldersTest::TearDown() {
|
||||
@ -42,6 +46,12 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
|
||||
InferenceEngine::Core core;
|
||||
auto exe_net = core.LoadNetwork(cnnNet, deviceName);
|
||||
auto request = exe_net.CreateInferRequest();
|
||||
std::vector<InferenceEngine::VariableState> states = {};
|
||||
try {
|
||||
states = request.QueryState();
|
||||
} catch(...) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
auto release = [&](int i) {
|
||||
switch (i) {
|
||||
@ -54,6 +64,9 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
|
||||
case 2:
|
||||
request = {};
|
||||
break;
|
||||
case 3:
|
||||
states = {};
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -67,4 +80,4 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
|
||||
// Test failed if crash happens
|
||||
EXPECT_NO_CRASH(release_order_test(order, targetDevice, function));
|
||||
}
|
||||
} // namespace BehaviorTestsDefinitions
|
||||
} // namespace BehaviorTestsDefinitions
|
||||
|
@ -7,7 +7,7 @@
|
||||
#include "behavior/memory_states.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
|
||||
std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfo<memoryStateParams> &obj) {
|
||||
std::string VariableStateTest::getTestCaseName(const testing::TestParamInfo<memoryStateParams> &obj) {
|
||||
std::ostringstream result;
|
||||
InferenceEngine::CNNNetwork net;
|
||||
std::string targetDevice;
|
||||
@ -17,22 +17,108 @@ std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfo<memory
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void MemoryStateTest::SetUp() {
|
||||
void VariableStateTest::SetUp() {
|
||||
std::tie(net, statesToQuery, deviceName) = GetParam();
|
||||
}
|
||||
|
||||
InferenceEngine::ExecutableNetwork MemoryStateTest::PrepareNetwork() {
|
||||
InferenceEngine::ExecutableNetwork VariableStateTest::PrepareNetwork() {
|
||||
net.addOutput("Memory_1");
|
||||
net.addOutput("Memory_2");
|
||||
auto ie = PluginCache::get().ie(deviceName);
|
||||
return ie->LoadNetwork(net, deviceName);
|
||||
}
|
||||
|
||||
TEST_P(MemoryStateTest, smoke_MemoryState_QueryState) {
|
||||
TEST_P(VariableStateTest, smoke_VariableState_QueryState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto executableNet = PrepareNetwork();
|
||||
|
||||
auto states = executableNet.QueryState();
|
||||
ASSERT_TRUE(states.size() == 2) << "Incorrect number of MemoryStates";
|
||||
ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
|
||||
|
||||
for (auto&& state : states) {
|
||||
auto name = state.GetName();
|
||||
ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end())
|
||||
<< "State " << name << "expected to be in memory states but it is not!";
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_P(VariableStateTest, smoke_VariableState_SetState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto executableNet = PrepareNetwork();
|
||||
const float new_state_val = 13.0f;
|
||||
for (auto&& state : executableNet.QueryState()) {
|
||||
state.Reset();
|
||||
auto state_val = state.GetState();
|
||||
auto element_count = state_val->size();
|
||||
|
||||
std::vector<float> new_state_data(element_count, new_state_val);
|
||||
auto stateBlob = InferenceEngine::make_shared_blob<float>(
|
||||
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
|
||||
new_state_data.data(), new_state_data.size());
|
||||
|
||||
state.SetState(stateBlob);
|
||||
}
|
||||
|
||||
for (auto&& state : executableNet.QueryState()) {
|
||||
auto lastState = state.GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
auto last_state_data = lastState->cbuffer().as<float*>();
|
||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
||||
|
||||
for (int i = 0; i < last_state_size; i++) {
|
||||
EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_P(VariableStateTest, smoke_VariableState_Reset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto executableNet = PrepareNetwork();
|
||||
const float new_state_val = 13.0f;
|
||||
for (auto&& state : executableNet.QueryState()) {
|
||||
state.Reset();
|
||||
auto state_val = state.GetState();
|
||||
auto element_count = state_val->size();
|
||||
|
||||
std::vector<float> new_state_data(element_count, new_state_val);
|
||||
auto stateBlob = InferenceEngine::make_shared_blob<float>(
|
||||
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
|
||||
new_state_data.data(), new_state_data.size());
|
||||
|
||||
state.SetState(stateBlob);
|
||||
}
|
||||
|
||||
executableNet.QueryState().front().Reset();
|
||||
|
||||
auto states = executableNet.QueryState();
|
||||
for (int i = 0; i < states.size(); ++i) {
|
||||
auto lastState = states[i].GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
auto last_state_data = lastState->cbuffer().as<float*>();
|
||||
|
||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
||||
|
||||
if (i == 0) {
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(0, last_state_data[j], 1e-5);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(13.0f, last_state_data[j], 1e-5);
|
||||
}
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_P(VariableStateTest, inferreq_smoke_VariableState_QueryState) {
|
||||
auto executableNet = PrepareNetwork();
|
||||
auto inferReq = executableNet.CreateInferRequest();
|
||||
|
||||
auto states = inferReq.QueryState();
|
||||
ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
|
||||
|
||||
for (auto&& state : states) {
|
||||
auto name = state.GetName();
|
||||
@ -41,23 +127,26 @@ TEST_P(MemoryStateTest, smoke_MemoryState_QueryState) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MemoryStateTest, smoke_MemoryState_SetState) {
|
||||
TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) {
|
||||
auto executableNet = PrepareNetwork();
|
||||
auto inferReq = executableNet.CreateInferRequest();
|
||||
|
||||
const float new_state_val = 13.0f;
|
||||
for (auto&& state : executableNet.QueryState()) {
|
||||
for (auto&& state : inferReq.QueryState()) {
|
||||
state.Reset();
|
||||
auto element_count = state.GetLastState()->size();
|
||||
auto state_val = state.GetState();
|
||||
auto element_count = state_val->size();
|
||||
|
||||
std::vector<float> new_state_data(element_count, new_state_val);
|
||||
auto stateBlob = InferenceEngine::make_shared_blob<float>(
|
||||
{ InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C },
|
||||
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
|
||||
new_state_data.data(), new_state_data.size());
|
||||
|
||||
state.SetState(stateBlob);
|
||||
}
|
||||
|
||||
for (auto&& state : executableNet.QueryState()) {
|
||||
auto lastState = state.GetLastState();
|
||||
for (auto&& state : inferReq.QueryState()) {
|
||||
auto lastState = state.GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
auto last_state_data = lastState->cbuffer().as<float*>();
|
||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
||||
@ -68,26 +157,29 @@ TEST_P(MemoryStateTest, smoke_MemoryState_SetState) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MemoryStateTest, smoke_MemoryState_Reset) {
|
||||
TEST_P(VariableStateTest, inferreq_smoke_VariableState_Reset) {
|
||||
auto executableNet = PrepareNetwork();
|
||||
auto inferReq = executableNet.CreateInferRequest();
|
||||
|
||||
const float new_state_val = 13.0f;
|
||||
for (auto&& state : executableNet.QueryState()) {
|
||||
for (auto&& state : inferReq.QueryState()) {
|
||||
state.Reset();
|
||||
auto element_count = state.GetLastState()->size();
|
||||
auto state_val = state.GetState();
|
||||
auto element_count = state_val->size();
|
||||
|
||||
std::vector<float> new_state_data(element_count, new_state_val);
|
||||
auto stateBlob = InferenceEngine::make_shared_blob<float>(
|
||||
{ InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C },
|
||||
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
|
||||
new_state_data.data(), new_state_data.size());
|
||||
|
||||
state.SetState(stateBlob);
|
||||
}
|
||||
|
||||
executableNet.QueryState().front().Reset();
|
||||
inferReq.QueryState().front().Reset();
|
||||
|
||||
auto states = executableNet.QueryState();
|
||||
auto states = inferReq.QueryState();
|
||||
for (int i = 0; i < states.size(); ++i) {
|
||||
auto lastState = states[i].GetLastState();
|
||||
auto lastState = states[i].GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
auto last_state_data = lastState->cbuffer().as<float*>();
|
||||
|
||||
|
@ -61,4 +61,5 @@ public:
|
||||
|
||||
MOCK_METHOD1(SetBatch, void(int));
|
||||
MOCK_METHOD1(SetBatch_ThreadUnsafe, void(int));
|
||||
MOCK_METHOD0(QueryState, std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(void));
|
||||
};
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
|
||||
|
||||
class MockIAsyncInferRequestInternal : public InferenceEngine::IAsyncInferRequestInternal {
|
||||
public:
|
||||
@ -26,4 +27,5 @@ public:
|
||||
MOCK_CONST_METHOD2(GetPreProcess, void(const char* name, const InferenceEngine::PreProcessInfo**));
|
||||
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
|
||||
MOCK_METHOD1(SetBatch, void(int));
|
||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
|
||||
};
|
||||
|
@ -28,7 +28,7 @@ public:
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void));
|
||||
MOCK_METHOD1(Export, void(const std::string &));
|
||||
void Export(std::ostream &) override {};
|
||||
MOCK_METHOD0(QueryState, std::vector<IMemoryStateInternal::Ptr>(void));
|
||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));
|
||||
MOCK_METHOD0(GetExecGraphInfo, CNNNetwork(void));
|
||||
|
||||
MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config));
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
|
||||
#include <cpp_interfaces/impl/ie_memory_state_internal.hpp>
|
||||
|
||||
class MockIInferRequestInternal : public InferenceEngine::IInferRequestInternal {
|
||||
public:
|
||||
@ -20,4 +21,5 @@ public:
|
||||
MOCK_METHOD2(GetBlob, void(const char *name, InferenceEngine::Blob::Ptr &));
|
||||
MOCK_METHOD3(SetBlob, void(const char*, const InferenceEngine::Blob::Ptr&, const InferenceEngine::PreProcessInfo&));
|
||||
MOCK_METHOD2(GetPreProcess, void(const char*, const InferenceEngine::PreProcessInfo**));
|
||||
MOCK_METHOD0(QueryState, std::vector<InferenceEngine::IVariableStateInternal::Ptr>());
|
||||
};
|
||||
|
@ -11,10 +11,10 @@
|
||||
|
||||
#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
|
||||
|
||||
class MockIMemoryStateInternal : public InferenceEngine::IMemoryStateInternal {
|
||||
class MockIVariableStateInternal : public InferenceEngine::IVariableStateInternal {
|
||||
public:
|
||||
MOCK_CONST_METHOD0(GetName, std::string());
|
||||
MOCK_METHOD0(Reset, void());
|
||||
MOCK_METHOD1(SetState, void(InferenceEngine::Blob::Ptr));
|
||||
MOCK_CONST_METHOD0(GetLastState, InferenceEngine::Blob::CPtr());
|
||||
MOCK_CONST_METHOD0(GetState, InferenceEngine::Blob::CPtr());
|
||||
};
|
||||
|
@ -12,10 +12,10 @@
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
class MockIMemoryState : public InferenceEngine::IMemoryState {
|
||||
class MockIVariableState : public InferenceEngine::IVariableState {
|
||||
public:
|
||||
MOCK_QUALIFIED_METHOD3(GetName, const noexcept, StatusCode(char * , size_t, ResponseDesc *));
|
||||
MOCK_QUALIFIED_METHOD1(Reset, noexcept, StatusCode(ResponseDesc *));
|
||||
MOCK_QUALIFIED_METHOD2(SetState, noexcept, StatusCode(Blob::Ptr, ResponseDesc *));
|
||||
MOCK_QUALIFIED_METHOD2(GetLastState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *));
|
||||
MOCK_QUALIFIED_METHOD2(GetState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *));
|
||||
};
|
||||
|
@ -30,6 +30,6 @@ public:
|
||||
MOCK_QUALIFIED_METHOD3(GetConfig, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp));
|
||||
MOCK_QUALIFIED_METHOD3(GetMetric, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp));
|
||||
MOCK_QUALIFIED_METHOD2(GetContext, const noexcept, StatusCode(RemoteContext::Ptr &pContext, ResponseDesc *resp));
|
||||
MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IMemoryState::Ptr &, size_t, ResponseDesc *));
|
||||
MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
|
||||
MOCK_QUALIFIED_METHOD0(Release, noexcept, void());
|
||||
};
|
||||
|
@ -34,4 +34,5 @@ public:
|
||||
MOCK_QUALIFIED_METHOD3(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, ResponseDesc*));
|
||||
MOCK_QUALIFIED_METHOD4(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, const PreProcessInfo&, ResponseDesc*));
|
||||
MOCK_QUALIFIED_METHOD2(SetBatch, noexcept, StatusCode(int batch, ResponseDesc*));
|
||||
MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
|
||||
};
|
||||
|
@ -484,6 +484,33 @@ static std::shared_ptr<ngraph::Function> makeConvBias(std::vector<size_t> inputS
|
||||
fn_ptr->set_friendly_name("ConvBias");
|
||||
return fn_ptr;
|
||||
}
|
||||
|
||||
static std::shared_ptr<ngraph::Function> makeReadConcatSplitAssign(std::vector<size_t> inputShape = {1, 1, 2, 4},
|
||||
InferenceEngine::Precision prc = InferenceEngine::Precision::FP32) {
|
||||
ngraph::element::Type type = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(prc);
|
||||
auto parameter = ngraph::builder::makeParams(type, {inputShape});
|
||||
parameter[0]->set_friendly_name("parameter");
|
||||
auto init_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {0, 0, 0, 0});
|
||||
auto read = std::make_shared<ngraph::op::ReadValue>(init_const, "v0");
|
||||
read->set_friendly_name("read");
|
||||
std::vector<std::shared_ptr<ngraph::Node>> args = {parameter[0], read};
|
||||
auto conc = std::make_shared<ngraph::op::Concat>(args, 3);
|
||||
conc->set_friendly_name("concat");
|
||||
auto res = std::make_shared<ngraph::op::Result>(conc);
|
||||
res->set_friendly_name("result");
|
||||
const auto axis = ngraph::op::Constant::create(element::i64, Shape{}, {3});
|
||||
axis->set_friendly_name("axis");
|
||||
auto crop = std::make_shared<ngraph::op::v1::Split>(conc, axis, 3);
|
||||
crop->set_friendly_name("crop");
|
||||
auto assign = std::make_shared<ngraph::op::Assign>(crop, "v0");
|
||||
assign->set_friendly_name("assign");
|
||||
|
||||
std::shared_ptr<ngraph::Function> fn_ptr = std::make_shared<ngraph::Function>(ngraph::ResultVector({res}),
|
||||
ngraph::SinkVector({assign}),
|
||||
ngraph::ParameterVector{parameter});
|
||||
fn_ptr->set_friendly_name("ReadConcatSplitAssign");
|
||||
return fn_ptr;
|
||||
}
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
|
@ -7,152 +7,183 @@
|
||||
#include <cpp/ie_executable_network.hpp>
|
||||
|
||||
#include <cpp_interfaces/base/ie_executable_network_base.hpp>
|
||||
#include <cpp_interfaces/impl/ie_memory_state_internal.hpp>
|
||||
#include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
|
||||
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
|
||||
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
|
||||
|
||||
using namespace ::testing;
|
||||
using namespace std;
|
||||
using namespace InferenceEngine;
|
||||
using namespace InferenceEngine::details;
|
||||
|
||||
class MemoryStateTests : public ::testing::Test {
|
||||
template <class T>
|
||||
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
|
||||
typename InferRequestBase<T>::Ptr req(new InferRequestBase<T>(impl), [](IInferRequest* p) {
|
||||
p->Release();
|
||||
});
|
||||
return InferenceEngine::InferRequest(req);
|
||||
}
|
||||
|
||||
|
||||
class VariableStateTests : public ::testing::Test {
|
||||
protected:
|
||||
shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
|
||||
shared_ptr<MockIMemoryStateInternal> mockMemoryStateInternal;
|
||||
shared_ptr<MockIAsyncInferRequestInternal> mockInferRequestInternal;
|
||||
shared_ptr<MockIVariableStateInternal> mockVariableStateInternal;
|
||||
|
||||
virtual void SetUp() {
|
||||
mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
|
||||
mockMemoryStateInternal = make_shared<MockIMemoryStateInternal>();
|
||||
mockInferRequestInternal = make_shared<MockIAsyncInferRequestInternal>();
|
||||
mockVariableStateInternal = make_shared<MockIVariableStateInternal>();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) {
|
||||
TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn(1);
|
||||
toReturn[0] = mockMemoryStateInternal;
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn(1);
|
||||
toReturn[0] = mockVariableStateInternal;
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = net.QueryState();
|
||||
ASSERT_EQ(state.size(), 1);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) {
|
||||
TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
|
||||
|
||||
auto state = net.QueryState();
|
||||
ASSERT_EQ(state.size(), 0);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) {
|
||||
TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = net.QueryState();
|
||||
ASSERT_EQ(state.size(), 2);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStatePropagatesReset) {
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesReset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1);
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
|
||||
|
||||
auto state = net.QueryState();
|
||||
state.front().Reset();
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) {
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
|
||||
|
||||
auto state = net.QueryState();
|
||||
EXPECT_ANY_THROW(state.front().Reset());
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) {
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetName) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
auto state = net.QueryState();
|
||||
EXPECT_STREQ(state.front().GetName().c_str(), "someName");
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) {
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
IMemoryState::Ptr pState;
|
||||
IVariableState::Ptr pState;
|
||||
|
||||
static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
|
||||
char *name = reinterpret_cast<char *>(1);
|
||||
EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) {
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
IMemoryState::Ptr pState;
|
||||
IVariableState::Ptr pState;
|
||||
|
||||
static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
|
||||
char name[1];
|
||||
EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
|
||||
EXPECT_STREQ(name, "");
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) {
|
||||
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
IMemoryState::Ptr pState;
|
||||
IVariableState::Ptr pState;
|
||||
|
||||
static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
|
||||
char name[2];
|
||||
EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
|
||||
EXPECT_STREQ(name, "s");
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
|
||||
TEST_F(VariableStateTests, VariableStateCanPropagateSetState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
Blob::Ptr saver;
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
@ -161,47 +192,50 @@ TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) {
|
||||
TEST_F(VariableStateTests, VariableStateCanPropagateGetLastState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto net = make_executable_network(mockExeNetworkInternal);
|
||||
std::vector<IMemoryStateInternal::Ptr> toReturn;
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
|
||||
|
||||
toReturn.push_back(mockMemoryStateInternal);
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
|
||||
|
||||
|
||||
auto saver = net.QueryState().front().GetLastState();
|
||||
auto saver = net.QueryState().front().GetState();
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
class MemoryStateInternalMockImpl : public MemoryStateInternal {
|
||||
class VariableStateInternalMockImpl : public VariableStateInternal {
|
||||
public:
|
||||
using MemoryStateInternal::MemoryStateInternal;
|
||||
using VariableStateInternal::VariableStateInternal;
|
||||
MOCK_METHOD0(Reset, void());
|
||||
};
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) {
|
||||
IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
|
||||
TEST_F(VariableStateTests, VariableStateInternalCanSaveName) {
|
||||
IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
|
||||
ASSERT_STREQ(pState->GetName().c_str(), "name");
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
|
||||
IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
|
||||
TEST_F(VariableStateTests, VariableStateInternalCanSaveState) {
|
||||
IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
|
||||
pState->SetState(stateBlob);
|
||||
auto saver = pState->GetLastState();
|
||||
auto saver = pState->GetState();
|
||||
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 123);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 124);
|
||||
@ -209,8 +243,8 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
|
||||
}
|
||||
|
||||
|
||||
TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
|
||||
IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
|
||||
TEST_F(VariableStateTests, VariableStateInternalCanSaveStateByReference) {
|
||||
IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
|
||||
@ -219,9 +253,162 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
|
||||
data[0] = 121;
|
||||
data[1] = 122;
|
||||
data[2] = 123;
|
||||
auto saver = pState->GetLastState();
|
||||
auto saver = pState->GetState();
|
||||
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 121);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 122);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 123);
|
||||
}
|
||||
|
||||
// Tests for InferRequest::QueryState
|
||||
TEST_F(VariableStateTests, InferRequestCanConvertOneVariableStateFromCppToAPI) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn(1);
|
||||
toReturn[0] = mockVariableStateInternal;
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = req.QueryState();
|
||||
ASSERT_EQ(state.size(), 1);
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillOnce(Return(toReturn));
|
||||
|
||||
auto state = req.QueryState();
|
||||
ASSERT_EQ(state.size(), 0);
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InferRequestCanConvert2VariableStatesFromCPPtoAPI) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = req.QueryState();
|
||||
ASSERT_EQ(state.size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
|
||||
|
||||
auto state = req.QueryState();
|
||||
state.front().Reset();
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
|
||||
|
||||
auto state = req.QueryState();
|
||||
EXPECT_ANY_THROW(state.front().Reset());
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetName) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
auto state = req.QueryState();
|
||||
EXPECT_STREQ(state.front().GetName().c_str(), "someName");
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
IVariableState::Ptr pState;
|
||||
|
||||
static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
|
||||
char *name = reinterpret_cast<char *>(1);
|
||||
EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
IVariableState::Ptr pState;
|
||||
|
||||
static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
|
||||
char name[1];
|
||||
EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
|
||||
EXPECT_STREQ(name, "");
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
IVariableState::Ptr pState;
|
||||
|
||||
static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
|
||||
char name[2];
|
||||
EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
|
||||
EXPECT_STREQ(name, "s");
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
Blob::Ptr saver;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
|
||||
EXPECT_NO_THROW(req.QueryState().front().SetState(stateBlob));
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
|
||||
}
|
||||
|
||||
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateGetLastState) {
|
||||
auto req = make_infer_request(mockInferRequestInternal);
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
|
||||
|
||||
auto saver = req.QueryState().front().GetState();
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
|
||||
}
|
||||
|
@ -127,6 +127,7 @@ TEST_F(ExecutableNetworkTests, OperatorAmpersand) {
|
||||
ASSERT_EQ(exeNet_p, mockIExeNet_p);
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
TEST_F(ExecutableNetworkTests, QueryStateThrowsIfReturnErr) {
|
||||
EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
|
||||
.Times(1)
|
||||
@ -138,21 +139,22 @@ TEST_F(ExecutableNetworkTests, QueryStateIfReturnOutOfBounds) {
|
||||
EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
|
||||
.Times(1)
|
||||
.WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS));
|
||||
std::vector<InferenceEngine::MemoryState> MemState_;
|
||||
std::vector<InferenceEngine::VariableState> MemState_;
|
||||
EXPECT_NO_THROW(MemState_ = exeNetwork->QueryState());
|
||||
EXPECT_EQ(MemState_.size(), 0);
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkTests, QueryState) {
|
||||
std::shared_ptr<MockIMemoryState> mockIMemState_p = std::make_shared<MockIMemoryState>();
|
||||
std::shared_ptr<MockIVariableState> mockIMemState_p = std::make_shared<MockIVariableState>();
|
||||
EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
|
||||
.Times(2)
|
||||
.WillOnce(DoAll(SetArgReferee<0>(mockIMemState_p), Return(InferenceEngine::OK)))
|
||||
.WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS));
|
||||
std::vector<InferenceEngine::MemoryState> MemState_v;
|
||||
std::vector<InferenceEngine::VariableState> MemState_v;
|
||||
EXPECT_NO_THROW(MemState_v = exeNetwork->QueryState());
|
||||
EXPECT_EQ(MemState_v.size(), 1);
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
|
||||
protected:
|
||||
|
@ -207,6 +207,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() {
|
||||
}
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
if (fetchResult.reset) {
|
||||
auto states = executableApi.QueryState();
|
||||
ASSERT_FALSE(states.empty());
|
||||
@ -218,6 +219,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() {
|
||||
outputs["reset"] = nullptr;
|
||||
//continue;
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
//FAIL()<<"stop after one frame";
|
||||
|
||||
|
@ -808,6 +808,7 @@ void GNAQueryStateMatcher :: match() {
|
||||
|
||||
EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
|
||||
#endif
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
try {
|
||||
loadNetwork();
|
||||
if (GnaPluginTestEnvironment::kAnyNotNull == _env.numberOfStates) {
|
||||
@ -830,6 +831,7 @@ void GNAQueryStateMatcher :: match() {
|
||||
catch(...) {
|
||||
FAIL() << "unknown exception thrown";
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user