Move QueryState from ExecutableNetwork to InferRequest (#2818)

* QueryState moved to InferRequest

* deprecate ExecutableNetwork::QueryState,chaged tests (without any check yet)

* fix build

* review fixes + build fix

* build fix + review changes

* remove blank line

* style fixes

* test build fixes

* style fix

* style fix

* fixed build of tests

* fix

* mac build fix

* hddl plugin build fix

* clean up unneeded implementation for method

* fixed tests build

* add implementation for getstate, correct getName for MklDNN

* fixed description of state API in comments

* lint fixes

* Rename MemoryState to VariableState

* added tests for cpu for VariableStates, several small fixes in tests and code

* merge fix

* lint fix

* remove whitespaces

* spaces fix

* fix in test to make it workable for all plugins

* fix typo

* fix test for gna

* remove extra comment

* fix test for gna
This commit is contained in:
Svetlana Dolinina 2020-11-12 12:40:43 +03:00 committed by GitHub
parent 809c504d0a
commit 7bd76dc12b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
51 changed files with 717 additions and 196 deletions

View File

@ -175,11 +175,13 @@ public:
* Wraps IExecutableNetwork::QueryState
* @return A vector of Memory State objects
*/
std::vector<MemoryState> QueryState() {
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<VariableState> QueryState() {
IE_SUPPRESS_DEPRECATED_START
if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized.";
IMemoryState::Ptr pState = nullptr;
IVariableState::Ptr pState = nullptr;
auto res = OK;
std::vector<MemoryState> controller;
std::vector<VariableState> controller;
for (size_t idx = 0; res == OK; ++idx) {
ResponseDesc resp;
res = actual->QueryState(pState, idx, &resp);
@ -187,10 +189,11 @@ public:
THROW_IE_EXCEPTION << resp.msg;
}
if (res != OUT_OF_BOUNDS) {
controller.push_back(MemoryState(pState));
controller.push_back(VariableState(pState, plg));
}
}
IE_SUPPRESS_DEPRECATED_END
return controller;
}

View File

@ -13,6 +13,7 @@
#include <memory>
#include <string>
#include "cpp/ie_memory_state.hpp"
#include "ie_iinfer_request.hpp"
#include "details/ie_exception_conversion.hpp"
#include "details/ie_so_loader.h"
@ -250,6 +251,31 @@ public:
actual->SetCompletionCallback(callWrapper);
}
/**
* @copybrief IExecutableNetwork::QueryState
*
* Wraps IExecutableNetwork::QueryState
* @return A vector of Memory State objects
*/
std::vector<VariableState> QueryState() {
if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized.";
IVariableState::Ptr pState = nullptr;
auto res = OK;
std::vector<VariableState> controller;
for (size_t idx = 0; res == OK; ++idx) {
ResponseDesc resp;
res = actual->QueryState(pState, idx, &resp);
if (res != OK && res != OUT_OF_BOUNDS) {
THROW_IE_EXCEPTION << resp.msg;
}
if (res != OUT_OF_BOUNDS) {
controller.push_back(VariableState(pState, plg));
}
}
return controller;
}
/**
* @brief IInferRequest pointer to be used directly in CreateInferRequest functions
* @return A shared pointer to underlying IInferRequest interface

View File

@ -11,39 +11,42 @@
#include <string>
#include "ie_imemory_state.hpp"
#include "details/ie_exception_conversion.hpp"
#include "details/ie_so_loader.h"
namespace InferenceEngine {
/**
* @brief C++ exception based error reporting wrapper of API class IMemoryState
* @brief C++ exception based error reporting wrapper of API class IVariableState
*/
class MemoryState {
IMemoryState::Ptr actual = nullptr;
class VariableState {
IVariableState::Ptr actual = nullptr;
details::SharedObjectLoader::Ptr plugin = {};
public:
/**
* constructs MemoryState from the initialized shared_pointer
* constructs VariableState from the initialized shared_pointer
* @param pState Initialized shared pointer
*/
explicit MemoryState(IMemoryState::Ptr pState): actual(pState) {
explicit VariableState(IVariableState::Ptr pState, details::SharedObjectLoader::Ptr plg = {}) : actual(pState), plugin(plg) {
if (actual == nullptr) {
THROW_IE_EXCEPTION << "MemoryState wrapper was not initialized.";
THROW_IE_EXCEPTION << "VariableState wrapper was not initialized.";
}
}
/**
* @copybrief IMemoryState::Reset
* @copybrief IVariableState::Reset
*
* Wraps IMemoryState::Reset
* Wraps IVariableState::Reset
*/
void Reset() {
CALL_STATUS_FNC_NO_ARGS(Reset);
}
/**
* @copybrief IMemoryState::GetName
* @copybrief IVariableState::GetName
*
* Wraps IMemoryState::GetName
* Wraps IVariableState::GetName
* @return A string representing a state name
*/
std::string GetName() const {
@ -53,21 +56,26 @@ public:
}
/**
* @copybrief IMemoryState::GetLastState
* @copybrief IVariableState::GetState
*
* Wraps IMemoryState::GetLastState
* Wraps IVariableState::GetState
* @return A blob representing a last state
*/
Blob::CPtr GetLastState() const {
Blob::CPtr GetState() const {
Blob::CPtr stateBlob;
CALL_STATUS_FNC(GetLastState, stateBlob);
CALL_STATUS_FNC(GetState, stateBlob);
return stateBlob;
}
INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
Blob::CPtr GetLastState() const {
return GetState();
}
/**
* @copybrief IMemoryState::SetState
* @copybrief IVariableState::SetState
*
* Wraps IMemoryState::SetState
* Wraps IVariableState::SetState
* @param state The current state to set
*/
void SetState(Blob::Ptr state) {
@ -75,4 +83,8 @@ public:
}
};
} // namespace InferenceEngine
/*
* @brief For compatibility reasons.
*/
using MemoryState = VariableState;
} // namespace InferenceEngine

View File

@ -118,7 +118,7 @@ public:
* @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for
* given index
*/
virtual StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
/**
* @brief Sets configuration for current executable network

View File

@ -17,6 +17,7 @@
#include "ie_blob.h"
#include "ie_common.h"
#include "ie_preprocess.hpp"
#include "ie_imemory_state.hpp"
#include "details/ie_irelease.hpp"
namespace InferenceEngine {
@ -177,6 +178,18 @@ public:
* @return Enumeration of the resulted action: InferenceEngine::OK (0) for success
*/
virtual InferenceEngine::StatusCode SetBatch(int batch_size, ResponseDesc* resp) noexcept = 0;
};
} // namespace InferenceEngine
/**
* @brief Gets state control interface for given infer request.
*
* State control essential for recurrent networks
*
* @param pState reference to a pointer that receives internal states
* @param idx requested index for receiving memory state
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
* @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for
* given index
*/
virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
};
} // namespace InferenceEngine

View File

@ -3,7 +3,7 @@
//
/**
* @brief a header file for IMemoryState interface
* @brief a header file for IVariableState interface
*
* @file ie_imemory_state.hpp
*/
@ -19,19 +19,19 @@
namespace InferenceEngine {
/**
* @interface IMemoryState
* @interface IVariableState
* @brief manages data for reset operations
*/
class IMemoryState : public details::no_copy {
class IVariableState : public details::no_copy {
public:
/**
* @brief A shared pointer to the IMemoryState interface
* @brief A shared pointer to the IVariableState interface
*/
using Ptr = std::shared_ptr<IMemoryState>;
using Ptr = std::shared_ptr<IVariableState>;
/**
* @brief Gets name of current memory state, if length of array is not enough name is truncated by len, null
* terminator is inserted as well.
* terminator is inserted as well. As memory state name variable_id from according ReadValue used.
*
* @param name preallocated buffer for receiving name
* @param len Length of the buffer
@ -41,7 +41,7 @@ public:
virtual StatusCode GetName(char* name, size_t len, ResponseDesc* resp) const noexcept = 0;
/**
* @brief reset internal memory state for relevant iexecutable network, to a value specified in SetState
* @brief Reset internal memory state for relevant infer request, to a value specified as default for according ReadValue node
*
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
* @return Status code of the operation: InferenceEngine::OK (0) for success*
@ -49,25 +49,30 @@ public:
virtual StatusCode Reset(ResponseDesc* resp) noexcept = 0;
/**
* @brief Sets the new state that is used for all future Reset() operations as a base.
* @brief Sets the new state for the next inference.
*
* This method can fail if Blob size does not match the internal state size or precision
*
* @param newState is the data to use as base state
* @param newState is the data to use as new state
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
* @return Status code of the operation: InferenceEngine::OK (0) for success
*/
virtual StatusCode SetState(Blob::Ptr newState, ResponseDesc* resp) noexcept = 0;
/**
* @brief returns the value of the last memory state.
* @brief Returns the value of the memory state.
*
* @details Since we roll memory after each infer, we can query the input state always and still get the last state.
* @param lastState
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
* @return Status code of the operation: InferenceEngine::OK (0) for success
* */
virtual StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept = 0;
INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
virtual StatusCode GetLastState(Blob::CPtr& state, ResponseDesc* resp) const noexcept {return GetState(state, resp);}
virtual StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept = 0;
};
/*
* @brief For compatibility reasons.
*/
using IMemoryState = IVariableState;
} // namespace InferenceEngine

View File

@ -845,7 +845,7 @@ int main(int argc, char *argv[]) {
ptrUtterances.resize(inputArkFiles.size());
// initialize memory state before starting
for (auto &&state : executableNet.QueryState()) {
for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
state.Reset();
}
@ -1080,7 +1080,7 @@ int main(int argc, char *argv[]) {
totalTime += d.count();
// resetting state between utterances
for (auto &&state : executableNet.QueryState()) {
for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
state.Reset();
}

View File

@ -59,12 +59,13 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
return std::make_shared<GNAInferRequest>(plg, networkInputs, networkOutputs);
}
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> QueryState() override {
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
IE_SUPPRESS_DEPRECATED_START
auto pluginStates = plg->QueryState();
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
return plg->QueryState();
IE_SUPPRESS_DEPRECATED_END
}
void Export(const std::string &modelFileName) override {

View File

@ -111,5 +111,13 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
}
return InferenceEngine::OK;
}
IE_SUPPRESS_DEPRECATED_START
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
auto pluginStates = plg->QueryState();
std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
return plg->QueryState();
}
IE_SUPPRESS_DEPRECATED_END
};
} // namespace GNAPluginNS

View File

@ -1186,11 +1186,11 @@ Blob::Ptr GNAPlugin::GetInputBlob(const std::string& name, InferenceEngine::Prec
return inputBlob;
}
std::vector<InferenceEngine::MemoryStateInternal::Ptr> GNAPlugin::QueryState() {
std::vector<InferenceEngine::VariableStateInternal::Ptr> GNAPlugin::QueryState() {
if (memoryStates.size() != graphCompiler.memory_connection.size()) {
memoryStates.clear();
for (auto& connection : graphCompiler.memory_connection) {
auto state = std::make_shared<memory::GNAMemoryState>(connection.first, std::make_shared <GNAMemoryLayer>(connection.second));
auto state = std::make_shared<memory::GNAVariableState>(connection.first, std::make_shared <GNAMemoryLayer>(connection.second));
memoryStates.emplace_back(state);
}
}

View File

@ -84,7 +84,7 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
InferenceEngine::InputsDataMap inputsDataMap;
InferenceEngine::OutputsDataMap outputsDataMap;
std::vector<InferenceEngine::MemoryStateInternal::Ptr> memoryStates;
std::vector<InferenceEngine::VariableStateInternal::Ptr> memoryStates;
public:
explicit GNAPlugin(const std::map<std::string, std::string>& configMap);
@ -159,7 +159,8 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
* QueryState API
* @return
*/
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> QueryState();
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState();
/**
* test-wise API

View File

@ -12,15 +12,15 @@ namespace GNAPluginNS {
namespace memory {
std::string GNAMemoryState::GetName() const {
std::string GNAVariableState::GetName() const {
return name;
}
void GNAMemoryState::Reset() {
void GNAVariableState::Reset() {
state->Reset();
}
InferenceEngine::Precision GNAMemoryState::getPrecision() const {
InferenceEngine::Precision GNAVariableState::getPrecision() const {
InferenceEngine::Precision state_precision;
if (state->getInput()) {
@ -36,14 +36,14 @@ namespace memory {
break;
default:
THROW_GNA_EXCEPTION << "Incorrect state element size " << element_size <<
" to determine precision for MemoryState " << name;
" to determine precision for VariableState " << name;
}
}
return state_precision;
}
void GNAMemoryState::SetState(InferenceEngine::Blob::Ptr newState) {
void GNAVariableState::SetState(InferenceEngine::Blob::Ptr newState) {
IE_ASSERT(newState != nullptr);
auto data_ptr = newState->cbuffer().as<void*>();
@ -78,20 +78,20 @@ namespace memory {
data_elements,
scale_factor);
} else {
THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name
THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name
<< ". If old state precision is I16 only I16 and FP32 are allowed as new state precisions."
<< " Old state: " << state_precision << " New state: " << new_state_precision;
}
break;
}
default:
THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name
THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name
<< ". Incorrect new/old precision pair"
<< " Old state: " << state_precision << " New state: " << new_state_precision;
}
}
InferenceEngine::Blob::CPtr GNAMemoryState::GetLastState() const {
InferenceEngine::Blob::CPtr GNAVariableState::GetState() const {
auto elements = state->reserved_size / state->elementSizeBytes();
InferenceEngine::Precision state_precision = getPrecision();

View File

@ -11,14 +11,14 @@
namespace GNAPluginNS {
namespace memory {
class GNAMemoryState : public InferenceEngine::IMemoryStateInternal {
class GNAVariableState : public InferenceEngine::IVariableStateInternal {
public:
GNAMemoryState(std::string name, std::shared_ptr<GNAMemoryLayer> state)
GNAVariableState(std::string name, std::shared_ptr<GNAMemoryLayer> state)
: name(name), state(state) { IE_ASSERT(state != nullptr); }
void Reset() override;
void SetState(InferenceEngine::Blob::Ptr newState) override;
InferenceEngine::Blob::CPtr GetLastState() const override;
InferenceEngine::Blob::CPtr GetState() const override;
std::string GetName() const override;
private:

View File

@ -183,14 +183,14 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::ICNNNetwork &network
if (node->getType() == MemoryInput) {
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
auto state_store = memoryNode->getStore();
auto state_name = node->getName();
auto state_name = memoryNode->getId();
// Remove suffix with pair ID. Internal information.
auto suffix_idx = state_name.find("/id=");
if (suffix_idx != std::string::npos)
state_name = state_name.substr(0, suffix_idx);
memoryStates.emplace_back(new MKLDNNMemoryState(state_name, state_store));
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
}
}
}
@ -314,6 +314,8 @@ bool MKLDNNExecNetwork::CanProcessDynBatch(const InferenceEngine::ICNNNetwork &n
return check_result;
}
std::vector<IMemoryStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
return memoryStates;
}
IE_SUPPRESS_DEPRECATED_END

View File

@ -42,14 +42,15 @@ public:
InferenceEngine::CNNNetwork GetExecGraphInfo() override;
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> QueryState() override;
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
InferenceEngine::ThreadLocal<MKLDNNGraph::Ptr> _graphs;
protected:
friend class MKLDNNInferRequest;
MKLDNNExtensionManager::Ptr extensionManager;
std::vector<InferenceEngine::IMemoryStateInternal::Ptr> memoryStates;
std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
InferenceEngine::details::CNNNetworkImplPtr _clonedNetwork;
std::mutex _cfgMutex;
Config _cfg;

View File

@ -14,6 +14,8 @@
#include "mkldnn_exec_network.h"
#include "mkldnn_itt.h"
#include "nodes/common/cpu_convert.h"
#include "mkldnn_memory_state.h"
#include "nodes/mkldnn_memory_node.hpp"
MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs,
InferenceEngine::OutputsDataMap networkOutputs,
@ -35,6 +37,30 @@ MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsData
InferenceEngine::Blob::Ptr blob;
MKLDNNInferRequest::GetBlob(it.first.c_str(), blob);
}
// Save all MemoryLayer data tensors. Will use insight about mechanics
// of MemoryLayer implementation. It uses output edge of MemoryLayer
// producer as storage for tensor to keep it between infer calls.
IE_SUPPRESS_DEPRECATED_START
if (execNetwork->QueryState().size() == 0) {
for (auto &node : graph->GetNodes()) {
if (node->getType() == MemoryInput) {
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
auto state_store = memoryNode->getStore();
auto state_name = memoryNode->getId();
// Remove suffix with pair ID. Internal information.
auto suffix_idx = state_name.find("/id=");
if (suffix_idx != std::string::npos)
state_name = state_name.substr(0, suffix_idx);
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
}
}
} else {
memoryStates = execNetwork->QueryState();
}
IE_SUPPRESS_DEPRECATED_END
}
MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() {
@ -390,3 +416,7 @@ void MKLDNNPlugin::MKLDNNInferRequest::SetBatch(int new_batch) {
m_curBatch = new_batch;
}
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MKLDNNPlugin::MKLDNNInferRequest::QueryState() {
return memoryStates;
}

View File

@ -43,6 +43,8 @@ public:
void SetBatch(int batch = -1) override;
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
private:
void PushInputData();
@ -53,5 +55,6 @@ private:
MKLDNNGraph* graph = nullptr;
std::map<std::string, void*> externalPtr;
openvino::itt::handle_t profilingTask;
std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
};
} // namespace MKLDNNPlugin

View File

@ -4,20 +4,21 @@
#include "mkldnn_memory_state.h"
#include "mkldnn_extension_utils.h"
#include "blob_factory.hpp"
using namespace InferenceEngine;
namespace MKLDNNPlugin {
std::string MKLDNNMemoryState::GetName() const {
std::string MKLDNNVariableState::GetName() const {
return name;
}
void MKLDNNMemoryState::Reset() {
void MKLDNNVariableState::Reset() {
storage->FillZero();
}
void MKLDNNMemoryState::SetState(Blob::Ptr newState) {
void MKLDNNVariableState::SetState(Blob::Ptr newState) {
auto prec = newState->getTensorDesc().getPrecision();
auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(prec);
auto data_layout = MKLDNNMemory::Convert(newState->getTensorDesc().getLayout());
@ -27,9 +28,11 @@ void MKLDNNMemoryState::SetState(Blob::Ptr newState) {
storage->SetData(data_type, data_layout, data_ptr, data_size);
}
InferenceEngine::Blob::CPtr MKLDNNMemoryState::GetLastState() const {
THROW_IE_EXCEPTION << "GetLastState method is not implemented for MemoryState";
return nullptr;
InferenceEngine::Blob::CPtr MKLDNNVariableState::GetState() const {
auto result_blob = make_blob_with_precision(MKLDNNMemoryDesc(storage->GetDescriptor()));
result_blob->allocate();
std::memcpy(result_blob->buffer(), storage->GetData(), storage->GetSize());
return result_blob;
}
} // namespace MKLDNNPlugin
} // namespace MKLDNNPlugin

View File

@ -11,15 +11,15 @@
namespace MKLDNNPlugin {
class MKLDNNMemoryState : public InferenceEngine::IMemoryStateInternal {
class MKLDNNVariableState : public InferenceEngine::IVariableStateInternal {
public:
MKLDNNMemoryState(std::string name, MKLDNNMemoryPtr storage) :
MKLDNNVariableState(std::string name, MKLDNNMemoryPtr storage) :
name(name), storage(storage) {}
std::string GetName() const override;
void Reset() override;
void SetState(InferenceEngine::Blob::Ptr newState) override;
InferenceEngine::Blob::CPtr GetLastState() const override;
InferenceEngine::Blob::CPtr GetState() const override;
private:
std::string name;

View File

@ -66,19 +66,22 @@ public:
TO_STATUS(graphPtr = _impl->GetExecGraphInfo());
}
StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
IE_SUPPRESS_DEPRECATED_START
try {
auto v = _impl->QueryState();
if (idx >= v.size()) {
return OUT_OF_BOUNDS;
}
pState = std::make_shared<MemoryStateBase<IMemoryStateInternal>>(v[idx]);
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
return OK;
} catch (const std::exception& ex) {
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
} catch (...) {
return InferenceEngine::DescriptionBuffer(UNEXPECTED);
}
IE_SUPPRESS_DEPRECATED_END
}
void Release() noexcept override {

View File

@ -10,6 +10,7 @@
#include "cpp_interfaces/exception2status.hpp"
#include "cpp_interfaces/plugin_itt.hpp"
#include <cpp_interfaces/base/ie_memory_state_base.hpp>
#include "ie_iinfer_request.hpp"
#include "ie_preprocess.hpp"
#include "ie_profiling.hpp"
@ -88,6 +89,21 @@ public:
TO_STATUS(_impl->SetBatch(batch_size));
}
StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
try {
auto v = _impl->QueryState();
if (idx >= v.size()) {
return OUT_OF_BOUNDS;
}
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
return OK;
} catch (const std::exception& ex) {
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
} catch (...) {
return InferenceEngine::DescriptionBuffer(UNEXPECTED);
}
}
private:
~InferRequestBase() = default;
};

View File

@ -7,23 +7,24 @@
#include <memory>
#include "cpp_interfaces/exception2status.hpp"
#include "cpp_interfaces/impl/ie_memory_state_internal.hpp"
#include "ie_imemory_state.hpp"
namespace InferenceEngine {
/**
* @brief default implementation for IMemoryState
* @brief default implementation for IVariableState
* @ingroup ie_dev_api_mem_state_api
*/
template <class T>
class MemoryStateBase : public IMemoryState {
class VariableStateBase : public IVariableState {
protected:
std::shared_ptr<T> impl;
public:
explicit MemoryStateBase(std::shared_ptr<T> impl): impl(impl) {
explicit VariableStateBase(std::shared_ptr<T> impl): impl(impl) {
if (impl == nullptr) {
THROW_IE_EXCEPTION << "MemoryStateBase implementation not defined";
THROW_IE_EXCEPTION << "VariableStateBase implementation not defined";
}
}
@ -44,9 +45,9 @@ public:
TO_STATUS(impl->SetState(newState));
}
StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept override {
TO_STATUS(lastState = impl->GetLastState());
StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept override {
TO_STATUS(state = impl->GetState());
}
};
} // namespace InferenceEngine
} // namespace InferenceEngine

View File

@ -88,7 +88,7 @@ public:
_plugin = plugin;
}
std::vector<IMemoryStateInternal::Ptr> QueryState() override {
std::vector<IVariableStateInternal::Ptr> QueryState() override {
THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
}

View File

@ -152,6 +152,10 @@ public:
_publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](IInferRequest*) {});
}
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
return _syncRequest->QueryState();
}
protected:
/**
* @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation

View File

@ -223,6 +223,12 @@ public:
}
}
std::vector<IVariableStateInternal::Ptr> QueryState() override {
// meaning base plugin reports as no state available - plugin owners need to create proper override of this
THROW_IE_EXCEPTION << "Plugin doesn't override QueryState";
return {};
}
protected:
InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info
InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data

View File

@ -13,21 +13,25 @@ namespace InferenceEngine {
* @brief minimal interface for memory state implementation
* @ingroup ie_dev_api_mem_state_api
*/
class MemoryStateInternal : public IMemoryStateInternal {
class VariableStateInternal : public IVariableStateInternal {
std::string name;
Blob::Ptr state;
public:
explicit MemoryStateInternal(std::string name): name(name) {}
explicit VariableStateInternal(std::string name): name(name) {}
std::string GetName() const override {
return name;
}
void SetState(Blob::Ptr newState) override {
state = newState;
}
Blob::CPtr GetLastState() const override {
Blob::CPtr GetState() const override {
return state;
}
};
} // namespace InferenceEngine
/*
* @brief For compatibility reasons.
*/
using MemoryStateInternal = VariableStateInternal;
} // namespace InferenceEngine

View File

@ -79,7 +79,7 @@ public:
* @brief Queries memory states.
* @return Returns memory states
*/
virtual std::vector<IMemoryStateInternal::Ptr> QueryState() = 0;
virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
/**
* @brief Sets configuration for current executable network

View File

@ -4,6 +4,7 @@
#pragma once
#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
#include <ie_blob.h>
#include <ie_common.h>
#include <ie_preprocess.hpp>
@ -83,6 +84,12 @@ public:
* @param batch - new batch size to be used by all the following inference calls for this request.
*/
virtual void SetBatch(int batch) = 0;
/**
* @brief Queries memory states.
* @return Returns memory states
*/
virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
};
} // namespace InferenceEngine

View File

@ -11,19 +11,25 @@
namespace InferenceEngine {
/**
* @interface IMemoryStateInternal
* @interface IVariableStateInternal
* @brief minimal interface for memory state implementation
* @ingroup ie_dev_api_mem_state_api
*/
class IMemoryStateInternal {
class IVariableStateInternal {
public:
using Ptr = std::shared_ptr<IMemoryStateInternal>;
using Ptr = std::shared_ptr<IVariableStateInternal>;
virtual ~IMemoryStateInternal() = default;
virtual ~IVariableStateInternal() = default;
virtual std::string GetName() const = 0;
virtual void Reset() = 0;
virtual void SetState(Blob::Ptr newState) = 0;
virtual Blob::CPtr GetLastState() const = 0;
virtual Blob::CPtr GetState() const = 0;
INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
virtual Blob::CPtr GetLastState() const {return GetState();}
};
/*
* @brief For compatibility reasons.
*/
using IMemoryStateInternal = IVariableStateInternal;
} // namespace InferenceEngine

View File

@ -83,3 +83,8 @@ TEST(InferRequestCPPTests, throwsOnUninitializedCast) {
InferRequest req;
ASSERT_THROW(auto &ireq = static_cast<IInferRequest::Ptr &>(req), InferenceEngine::details::InferenceEngineException);
}
TEST(InferRequestCPPTests, throwsOnUninitializedQueryState) {
InferRequest req;
ASSERT_THROW(req.QueryState(), InferenceEngine::details::InferenceEngineException);
}

View File

@ -46,8 +46,10 @@ TEST(ExecutableNetworkTests, throwsOnUninitializedGetExecGraphInfo) {
}
TEST(ExecutableNetworkTests, throwsOnUninitializedQueryState) {
IE_SUPPRESS_DEPRECATED_START
ExecutableNetwork exec;
ASSERT_THROW(exec.QueryState(), InferenceEngine::details::InferenceEngineException);
IE_SUPPRESS_DEPRECATED_END
}
TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) {

View File

@ -10,12 +10,15 @@ namespace {
// 0 - plugin
// 1 - executable_network
// 2 - infer_request
{0, 1, 2},
{0, 2, 1},
{1, 0, 2},
{1, 2, 0},
{2, 0, 1},
{2, 1, 0}
// 3 - variable state
{3, 0, 1, 2},
{3, 0, 2, 1},
{3, 1, 0, 2},
{3, 1, 2, 0},
{3, 2, 0, 1},
{3, 2, 1, 0},
{0, 3, 1, 2},
{0, 1, 3, 2}
};
INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest,
@ -24,4 +27,4 @@ namespace {
::testing::ValuesIn(orders)),
HoldersTest::getTestCaseName);
} // namespace
} // namespace

View File

@ -0,0 +1,22 @@
// Copyright (C) 2020 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//
#include <common_test_utils/test_constants.hpp>
#include "behavior/memory_states.hpp"
#include "functional_test_utils/test_model/test_model.hpp"
#include "functional_test_utils/plugin_cache.hpp"
InferenceEngine::CNNNetwork getNetwork() {
auto model = FuncTestUtils::TestModel::getModelWithMultipleMemoryConnections(InferenceEngine::Precision::FP32);
auto ie = PluginCache::get().ie();
return ie->ReadNetwork(model.model_xml_str, model.weights_blob);
}
std::vector<memoryStateParams> memoryStateTestCases = {
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_CPU)
};
INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest,
::testing::ValuesIn(memoryStateTestCases),
VariableStateTest::getTestCaseName);

View File

@ -10,12 +10,15 @@ namespace {
// 0 - plugin
// 1 - executable_network
// 2 - infer_request
{0, 1, 2},
{0, 2, 1},
{1, 0, 2},
{1, 2, 0},
{2, 0, 1},
{2, 1, 0}
// 3 - variable state
{3, 0, 1, 2},
{3, 0, 2, 1},
{3, 1, 0, 2},
{3, 1, 2, 0},
{3, 2, 0, 1},
{3, 2, 1, 0},
{0, 3, 1, 2},
{0, 1, 3, 2}
};
INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest,

View File

@ -17,6 +17,6 @@ std::vector<memoryStateParams> memoryStateTestCases = {
memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_GNA)
};
INSTANTIATE_TEST_CASE_P(smoke_MemoryStateBasic, MemoryStateTest,
INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest,
::testing::ValuesIn(memoryStateTestCases),
MemoryStateTest::getTestCaseName);
VariableStateTest::getTestCaseName);

View File

@ -14,7 +14,7 @@ typedef std::tuple<
std::string> // Target device name
memoryStateParams;
class MemoryStateTest : public CommonTestUtils::TestsCommon,
class VariableStateTest : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<memoryStateParams> {
protected:
InferenceEngine::CNNNetwork net;

View File

@ -25,7 +25,11 @@ namespace BehaviorTestsDefinitions {
if (deathTestStyle == "fast") {
::testing::GTEST_FLAG(death_test_style) = "threadsafe";
}
function = ngraph::builder::subgraph::makeConvPoolRelu();
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
function = ngraph::builder::subgraph::makeReadConcatSplitAssign();
} else {
function = ngraph::builder::subgraph::makeConvPoolRelu();
}
}
void HoldersTest::TearDown() {
@ -42,6 +46,12 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
InferenceEngine::Core core;
auto exe_net = core.LoadNetwork(cnnNet, deviceName);
auto request = exe_net.CreateInferRequest();
std::vector<InferenceEngine::VariableState> states = {};
try {
states = request.QueryState();
} catch(...) {
// do nothing
}
auto release = [&](int i) {
switch (i) {
@ -54,6 +64,9 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
case 2:
request = {};
break;
case 3:
states = {};
break;
default:
break;
}
@ -67,4 +80,4 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
// Test failed if crash happens
EXPECT_NO_CRASH(release_order_test(order, targetDevice, function));
}
} // namespace BehaviorTestsDefinitions
} // namespace BehaviorTestsDefinitions

View File

@ -7,7 +7,7 @@
#include "behavior/memory_states.hpp"
#include "functional_test_utils/plugin_cache.hpp"
std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfo<memoryStateParams> &obj) {
std::string VariableStateTest::getTestCaseName(const testing::TestParamInfo<memoryStateParams> &obj) {
std::ostringstream result;
InferenceEngine::CNNNetwork net;
std::string targetDevice;
@ -17,22 +17,108 @@ std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfo<memory
return result.str();
}
void MemoryStateTest::SetUp() {
void VariableStateTest::SetUp() {
std::tie(net, statesToQuery, deviceName) = GetParam();
}
InferenceEngine::ExecutableNetwork MemoryStateTest::PrepareNetwork() {
InferenceEngine::ExecutableNetwork VariableStateTest::PrepareNetwork() {
net.addOutput("Memory_1");
net.addOutput("Memory_2");
auto ie = PluginCache::get().ie(deviceName);
return ie->LoadNetwork(net, deviceName);
}
TEST_P(MemoryStateTest, smoke_MemoryState_QueryState) {
TEST_P(VariableStateTest, smoke_VariableState_QueryState) {
IE_SUPPRESS_DEPRECATED_START
auto executableNet = PrepareNetwork();
auto states = executableNet.QueryState();
ASSERT_TRUE(states.size() == 2) << "Incorrect number of MemoryStates";
ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
for (auto&& state : states) {
auto name = state.GetName();
ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end())
<< "State " << name << "expected to be in memory states but it is not!";
}
IE_SUPPRESS_DEPRECATED_END
}
TEST_P(VariableStateTest, smoke_VariableState_SetState) {
IE_SUPPRESS_DEPRECATED_START
auto executableNet = PrepareNetwork();
const float new_state_val = 13.0f;
for (auto&& state : executableNet.QueryState()) {
state.Reset();
auto state_val = state.GetState();
auto element_count = state_val->size();
std::vector<float> new_state_data(element_count, new_state_val);
auto stateBlob = InferenceEngine::make_shared_blob<float>(
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
new_state_data.data(), new_state_data.size());
state.SetState(stateBlob);
}
for (auto&& state : executableNet.QueryState()) {
auto lastState = state.GetState();
auto last_state_size = lastState->size();
auto last_state_data = lastState->cbuffer().as<float*>();
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
for (int i = 0; i < last_state_size; i++) {
EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
}
}
IE_SUPPRESS_DEPRECATED_END
}
TEST_P(VariableStateTest, smoke_VariableState_Reset) {
IE_SUPPRESS_DEPRECATED_START
auto executableNet = PrepareNetwork();
const float new_state_val = 13.0f;
for (auto&& state : executableNet.QueryState()) {
state.Reset();
auto state_val = state.GetState();
auto element_count = state_val->size();
std::vector<float> new_state_data(element_count, new_state_val);
auto stateBlob = InferenceEngine::make_shared_blob<float>(
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
new_state_data.data(), new_state_data.size());
state.SetState(stateBlob);
}
executableNet.QueryState().front().Reset();
auto states = executableNet.QueryState();
for (int i = 0; i < states.size(); ++i) {
auto lastState = states[i].GetState();
auto last_state_size = lastState->size();
auto last_state_data = lastState->cbuffer().as<float*>();
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
if (i == 0) {
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(0, last_state_data[j], 1e-5);
}
} else {
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(13.0f, last_state_data[j], 1e-5);
}
}
}
IE_SUPPRESS_DEPRECATED_END
}
TEST_P(VariableStateTest, inferreq_smoke_VariableState_QueryState) {
auto executableNet = PrepareNetwork();
auto inferReq = executableNet.CreateInferRequest();
auto states = inferReq.QueryState();
ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
for (auto&& state : states) {
auto name = state.GetName();
@ -41,23 +127,26 @@ TEST_P(MemoryStateTest, smoke_MemoryState_QueryState) {
}
}
TEST_P(MemoryStateTest, smoke_MemoryState_SetState) {
TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) {
auto executableNet = PrepareNetwork();
auto inferReq = executableNet.CreateInferRequest();
const float new_state_val = 13.0f;
for (auto&& state : executableNet.QueryState()) {
for (auto&& state : inferReq.QueryState()) {
state.Reset();
auto element_count = state.GetLastState()->size();
auto state_val = state.GetState();
auto element_count = state_val->size();
std::vector<float> new_state_data(element_count, new_state_val);
auto stateBlob = InferenceEngine::make_shared_blob<float>(
{ InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C },
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
new_state_data.data(), new_state_data.size());
state.SetState(stateBlob);
}
for (auto&& state : executableNet.QueryState()) {
auto lastState = state.GetLastState();
for (auto&& state : inferReq.QueryState()) {
auto lastState = state.GetState();
auto last_state_size = lastState->size();
auto last_state_data = lastState->cbuffer().as<float*>();
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
@ -68,26 +157,29 @@ TEST_P(MemoryStateTest, smoke_MemoryState_SetState) {
}
}
TEST_P(MemoryStateTest, smoke_MemoryState_Reset) {
TEST_P(VariableStateTest, inferreq_smoke_VariableState_Reset) {
auto executableNet = PrepareNetwork();
auto inferReq = executableNet.CreateInferRequest();
const float new_state_val = 13.0f;
for (auto&& state : executableNet.QueryState()) {
for (auto&& state : inferReq.QueryState()) {
state.Reset();
auto element_count = state.GetLastState()->size();
auto state_val = state.GetState();
auto element_count = state_val->size();
std::vector<float> new_state_data(element_count, new_state_val);
auto stateBlob = InferenceEngine::make_shared_blob<float>(
{ InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C },
{ state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
new_state_data.data(), new_state_data.size());
state.SetState(stateBlob);
}
executableNet.QueryState().front().Reset();
inferReq.QueryState().front().Reset();
auto states = executableNet.QueryState();
auto states = inferReq.QueryState();
for (int i = 0; i < states.size(); ++i) {
auto lastState = states[i].GetLastState();
auto lastState = states[i].GetState();
auto last_state_size = lastState->size();
auto last_state_data = lastState->cbuffer().as<float*>();

View File

@ -61,4 +61,5 @@ public:
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD1(SetBatch_ThreadUnsafe, void(int));
MOCK_METHOD0(QueryState, std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(void));
};

View File

@ -11,6 +11,7 @@
#include <vector>
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
class MockIAsyncInferRequestInternal : public InferenceEngine::IAsyncInferRequestInternal {
public:
@ -26,4 +27,5 @@ public:
MOCK_CONST_METHOD2(GetPreProcess, void(const char* name, const InferenceEngine::PreProcessInfo**));
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
};

View File

@ -28,7 +28,7 @@ public:
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void));
MOCK_METHOD1(Export, void(const std::string &));
void Export(std::ostream &) override {};
MOCK_METHOD0(QueryState, std::vector<IMemoryStateInternal::Ptr>(void));
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));
MOCK_METHOD0(GetExecGraphInfo, CNNNetwork(void));
MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config));

View File

@ -11,6 +11,7 @@
#include <vector>
#include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
#include <cpp_interfaces/impl/ie_memory_state_internal.hpp>
class MockIInferRequestInternal : public InferenceEngine::IInferRequestInternal {
public:
@ -20,4 +21,5 @@ public:
MOCK_METHOD2(GetBlob, void(const char *name, InferenceEngine::Blob::Ptr &));
MOCK_METHOD3(SetBlob, void(const char*, const InferenceEngine::Blob::Ptr&, const InferenceEngine::PreProcessInfo&));
MOCK_METHOD2(GetPreProcess, void(const char*, const InferenceEngine::PreProcessInfo**));
MOCK_METHOD0(QueryState, std::vector<InferenceEngine::IVariableStateInternal::Ptr>());
};

View File

@ -11,10 +11,10 @@
#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
class MockIMemoryStateInternal : public InferenceEngine::IMemoryStateInternal {
class MockIVariableStateInternal : public InferenceEngine::IVariableStateInternal {
public:
MOCK_CONST_METHOD0(GetName, std::string());
MOCK_METHOD0(Reset, void());
MOCK_METHOD1(SetState, void(InferenceEngine::Blob::Ptr));
MOCK_CONST_METHOD0(GetLastState, InferenceEngine::Blob::CPtr());
MOCK_CONST_METHOD0(GetState, InferenceEngine::Blob::CPtr());
};

View File

@ -12,10 +12,10 @@
using namespace InferenceEngine;
class MockIMemoryState : public InferenceEngine::IMemoryState {
class MockIVariableState : public InferenceEngine::IVariableState {
public:
MOCK_QUALIFIED_METHOD3(GetName, const noexcept, StatusCode(char * , size_t, ResponseDesc *));
MOCK_QUALIFIED_METHOD1(Reset, noexcept, StatusCode(ResponseDesc *));
MOCK_QUALIFIED_METHOD2(SetState, noexcept, StatusCode(Blob::Ptr, ResponseDesc *));
MOCK_QUALIFIED_METHOD2(GetLastState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *));
MOCK_QUALIFIED_METHOD2(GetState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *));
};

View File

@ -30,6 +30,6 @@ public:
MOCK_QUALIFIED_METHOD3(GetConfig, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp));
MOCK_QUALIFIED_METHOD3(GetMetric, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp));
MOCK_QUALIFIED_METHOD2(GetContext, const noexcept, StatusCode(RemoteContext::Ptr &pContext, ResponseDesc *resp));
MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IMemoryState::Ptr &, size_t, ResponseDesc *));
MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
MOCK_QUALIFIED_METHOD0(Release, noexcept, void());
};

View File

@ -34,4 +34,5 @@ public:
MOCK_QUALIFIED_METHOD3(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, ResponseDesc*));
MOCK_QUALIFIED_METHOD4(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, const PreProcessInfo&, ResponseDesc*));
MOCK_QUALIFIED_METHOD2(SetBatch, noexcept, StatusCode(int batch, ResponseDesc*));
MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
};

View File

@ -484,6 +484,33 @@ static std::shared_ptr<ngraph::Function> makeConvBias(std::vector<size_t> inputS
fn_ptr->set_friendly_name("ConvBias");
return fn_ptr;
}
static std::shared_ptr<ngraph::Function> makeReadConcatSplitAssign(std::vector<size_t> inputShape = {1, 1, 2, 4},
InferenceEngine::Precision prc = InferenceEngine::Precision::FP32) {
ngraph::element::Type type = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(prc);
auto parameter = ngraph::builder::makeParams(type, {inputShape});
parameter[0]->set_friendly_name("parameter");
auto init_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {0, 0, 0, 0});
auto read = std::make_shared<ngraph::op::ReadValue>(init_const, "v0");
read->set_friendly_name("read");
std::vector<std::shared_ptr<ngraph::Node>> args = {parameter[0], read};
auto conc = std::make_shared<ngraph::op::Concat>(args, 3);
conc->set_friendly_name("concat");
auto res = std::make_shared<ngraph::op::Result>(conc);
res->set_friendly_name("result");
const auto axis = ngraph::op::Constant::create(element::i64, Shape{}, {3});
axis->set_friendly_name("axis");
auto crop = std::make_shared<ngraph::op::v1::Split>(conc, axis, 3);
crop->set_friendly_name("crop");
auto assign = std::make_shared<ngraph::op::Assign>(crop, "v0");
assign->set_friendly_name("assign");
std::shared_ptr<ngraph::Function> fn_ptr = std::make_shared<ngraph::Function>(ngraph::ResultVector({res}),
ngraph::SinkVector({assign}),
ngraph::ParameterVector{parameter});
fn_ptr->set_friendly_name("ReadConcatSplitAssign");
return fn_ptr;
}
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -7,152 +7,183 @@
#include <cpp/ie_executable_network.hpp>
#include <cpp_interfaces/base/ie_executable_network_base.hpp>
#include <cpp_interfaces/impl/ie_memory_state_internal.hpp>
#include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
using namespace ::testing;
using namespace std;
using namespace InferenceEngine;
using namespace InferenceEngine::details;
class MemoryStateTests : public ::testing::Test {
template <class T>
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
typename InferRequestBase<T>::Ptr req(new InferRequestBase<T>(impl), [](IInferRequest* p) {
p->Release();
});
return InferenceEngine::InferRequest(req);
}
class VariableStateTests : public ::testing::Test {
protected:
shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
shared_ptr<MockIMemoryStateInternal> mockMemoryStateInternal;
shared_ptr<MockIAsyncInferRequestInternal> mockInferRequestInternal;
shared_ptr<MockIVariableStateInternal> mockVariableStateInternal;
virtual void SetUp() {
mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
mockMemoryStateInternal = make_shared<MockIMemoryStateInternal>();
mockInferRequestInternal = make_shared<MockIAsyncInferRequestInternal>();
mockVariableStateInternal = make_shared<MockIVariableStateInternal>();
}
};
TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) {
TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn(1);
toReturn[0] = mockMemoryStateInternal;
std::vector<IVariableStateInternal::Ptr> toReturn(1);
toReturn[0] = mockVariableStateInternal;
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
auto state = net.QueryState();
ASSERT_EQ(state.size(), 1);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) {
TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
std::vector<IVariableStateInternal::Ptr> toReturn;
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
auto state = net.QueryState();
ASSERT_EQ(state.size(), 0);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) {
TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
toReturn.push_back(mockMemoryStateInternal);
toReturn.push_back(mockMemoryStateInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
auto state = net.QueryState();
ASSERT_EQ(state.size(), 2);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStatePropagatesReset) {
TEST_F(VariableStateTests, VariableStatePropagatesReset) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
toReturn.push_back(mockMemoryStateInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1);
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
auto state = net.QueryState();
state.front().Reset();
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) {
TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
toReturn.push_back(mockMemoryStateInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
auto state = net.QueryState();
EXPECT_ANY_THROW(state.front().Reset());
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) {
TEST_F(VariableStateTests, VariableStatePropagatesGetName) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
toReturn.push_back(mockMemoryStateInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
auto state = net.QueryState();
EXPECT_STREQ(state.front().GetName().c_str(), "someName");
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) {
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
toReturn.push_back(mockMemoryStateInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
IMemoryState::Ptr pState;
IVariableState::Ptr pState;
static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
char *name = reinterpret_cast<char *>(1);
EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) {
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
toReturn.push_back(mockMemoryStateInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
IMemoryState::Ptr pState;
IVariableState::Ptr pState;
static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
char name[1];
EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
EXPECT_STREQ(name, "");
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) {
TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
toReturn.push_back(mockMemoryStateInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
IMemoryState::Ptr pState;
IVariableState::Ptr pState;
static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
char name[2];
EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
EXPECT_STREQ(name, "s");
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
TEST_F(VariableStateTests, VariableStateCanPropagateSetState) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
std::vector<IVariableStateInternal::Ptr> toReturn;
Blob::Ptr saver;
toReturn.push_back(mockMemoryStateInternal);
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
@ -161,47 +192,50 @@ TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) {
TEST_F(VariableStateTests, VariableStateCanPropagateGetLastState) {
IE_SUPPRESS_DEPRECATED_START
auto net = make_executable_network(mockExeNetworkInternal);
std::vector<IMemoryStateInternal::Ptr> toReturn;
std::vector<IVariableStateInternal::Ptr> toReturn;
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
toReturn.push_back(mockMemoryStateInternal);
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob));
EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
auto saver = net.QueryState().front().GetLastState();
auto saver = net.QueryState().front().GetState();
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
IE_SUPPRESS_DEPRECATED_END
}
class MemoryStateInternalMockImpl : public MemoryStateInternal {
class VariableStateInternalMockImpl : public VariableStateInternal {
public:
using MemoryStateInternal::MemoryStateInternal;
using VariableStateInternal::VariableStateInternal;
MOCK_METHOD0(Reset, void());
};
TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) {
IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
TEST_F(VariableStateTests, VariableStateInternalCanSaveName) {
IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
ASSERT_STREQ(pState->GetName().c_str(), "name");
}
TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
TEST_F(VariableStateTests, VariableStateInternalCanSaveState) {
IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
pState->SetState(stateBlob);
auto saver = pState->GetLastState();
auto saver = pState->GetState();
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 123);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 124);
@ -209,8 +243,8 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
}
TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
TEST_F(VariableStateTests, VariableStateInternalCanSaveStateByReference) {
IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
@ -219,9 +253,162 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
data[0] = 121;
data[1] = 122;
data[2] = 123;
auto saver = pState->GetLastState();
auto saver = pState->GetState();
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 121);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 122);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 123);
}
// Tests for InferRequest::QueryState
TEST_F(VariableStateTests, InferRequestCanConvertOneVariableStateFromCppToAPI) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn(1);
toReturn[0] = mockVariableStateInternal;
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
auto state = req.QueryState();
ASSERT_EQ(state.size(), 1);
}
TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillOnce(Return(toReturn));
auto state = req.QueryState();
ASSERT_EQ(state.size(), 0);
}
TEST_F(VariableStateTests, InferRequestCanConvert2VariableStatesFromCPPtoAPI) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
auto state = req.QueryState();
ASSERT_EQ(state.size(), 2);
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
auto state = req.QueryState();
state.front().Reset();
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
auto state = req.QueryState();
EXPECT_ANY_THROW(state.front().Reset());
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetName) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
auto state = req.QueryState();
EXPECT_STREQ(state.front().GetName().c_str(), "someName");
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
IVariableState::Ptr pState;
static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
char *name = reinterpret_cast<char *>(1);
EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
IVariableState::Ptr pState;
static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
char name[1];
EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
EXPECT_STREQ(name, "");
}
TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
IVariableState::Ptr pState;
static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
char name[2];
EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
EXPECT_STREQ(name, "s");
}
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
Blob::Ptr saver;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
EXPECT_NO_THROW(req.QueryState().front().SetState(stateBlob));
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
}
TEST_F(VariableStateTests, InfReqVariableStateCanPropagateGetLastState) {
auto req = make_infer_request(mockInferRequestInternal);
std::vector<IVariableStateInternal::Ptr> toReturn;
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
auto saver = req.QueryState().front().GetState();
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
}

View File

@ -127,6 +127,7 @@ TEST_F(ExecutableNetworkTests, OperatorAmpersand) {
ASSERT_EQ(exeNet_p, mockIExeNet_p);
}
IE_SUPPRESS_DEPRECATED_START
TEST_F(ExecutableNetworkTests, QueryStateThrowsIfReturnErr) {
EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
.Times(1)
@ -138,21 +139,22 @@ TEST_F(ExecutableNetworkTests, QueryStateIfReturnOutOfBounds) {
EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
.Times(1)
.WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS));
std::vector<InferenceEngine::MemoryState> MemState_;
std::vector<InferenceEngine::VariableState> MemState_;
EXPECT_NO_THROW(MemState_ = exeNetwork->QueryState());
EXPECT_EQ(MemState_.size(), 0);
}
TEST_F(ExecutableNetworkTests, QueryState) {
std::shared_ptr<MockIMemoryState> mockIMemState_p = std::make_shared<MockIMemoryState>();
std::shared_ptr<MockIVariableState> mockIMemState_p = std::make_shared<MockIVariableState>();
EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
.Times(2)
.WillOnce(DoAll(SetArgReferee<0>(mockIMemState_p), Return(InferenceEngine::OK)))
.WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS));
std::vector<InferenceEngine::MemoryState> MemState_v;
std::vector<InferenceEngine::VariableState> MemState_v;
EXPECT_NO_THROW(MemState_v = exeNetwork->QueryState());
EXPECT_EQ(MemState_v.size(), 1);
}
IE_SUPPRESS_DEPRECATED_END
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
protected:

View File

@ -207,6 +207,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() {
}
}
IE_SUPPRESS_DEPRECATED_START
if (fetchResult.reset) {
auto states = executableApi.QueryState();
ASSERT_FALSE(states.empty());
@ -218,6 +219,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() {
outputs["reset"] = nullptr;
//continue;
}
IE_SUPPRESS_DEPRECATED_END
//FAIL()<<"stop after one frame";

View File

@ -808,6 +808,7 @@ void GNAQueryStateMatcher :: match() {
EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
#endif
IE_SUPPRESS_DEPRECATED_START
try {
loadNetwork();
if (GnaPluginTestEnvironment::kAnyNotNull == _env.numberOfStates) {
@ -830,6 +831,7 @@ void GNAQueryStateMatcher :: match() {
catch(...) {
FAIL() << "unknown exception thrown";
}
IE_SUPPRESS_DEPRECATED_END
}