diff --git a/inference-engine/include/cpp/ie_executable_network.hpp b/inference-engine/include/cpp/ie_executable_network.hpp index 31d246c0779..d748820c120 100644 --- a/inference-engine/include/cpp/ie_executable_network.hpp +++ b/inference-engine/include/cpp/ie_executable_network.hpp @@ -175,11 +175,13 @@ public: * Wraps IExecutableNetwork::QueryState * @return A vector of Memory State objects */ - std::vector QueryState() { + INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead") + std::vector QueryState() { + IE_SUPPRESS_DEPRECATED_START if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized."; - IMemoryState::Ptr pState = nullptr; + IVariableState::Ptr pState = nullptr; auto res = OK; - std::vector controller; + std::vector controller; for (size_t idx = 0; res == OK; ++idx) { ResponseDesc resp; res = actual->QueryState(pState, idx, &resp); @@ -187,10 +189,11 @@ public: THROW_IE_EXCEPTION << resp.msg; } if (res != OUT_OF_BOUNDS) { - controller.push_back(MemoryState(pState)); + controller.push_back(VariableState(pState, plg)); } } + IE_SUPPRESS_DEPRECATED_END return controller; } diff --git a/inference-engine/include/cpp/ie_infer_request.hpp b/inference-engine/include/cpp/ie_infer_request.hpp index 18daaf79c3c..c750a5d4c90 100644 --- a/inference-engine/include/cpp/ie_infer_request.hpp +++ b/inference-engine/include/cpp/ie_infer_request.hpp @@ -13,6 +13,7 @@ #include #include +#include "cpp/ie_memory_state.hpp" #include "ie_iinfer_request.hpp" #include "details/ie_exception_conversion.hpp" #include "details/ie_so_loader.h" @@ -250,6 +251,31 @@ public: actual->SetCompletionCallback(callWrapper); } + /** + * @copybrief IExecutableNetwork::QueryState + * + * Wraps IExecutableNetwork::QueryState + * @return A vector of Memory State objects + */ + std::vector QueryState() { + if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized."; + IVariableState::Ptr pState = nullptr; + auto res = OK; + std::vector controller; + for (size_t idx = 0; res == OK; ++idx) { + ResponseDesc resp; + res = actual->QueryState(pState, idx, &resp); + if (res != OK && res != OUT_OF_BOUNDS) { + THROW_IE_EXCEPTION << resp.msg; + } + if (res != OUT_OF_BOUNDS) { + controller.push_back(VariableState(pState, plg)); + } + } + + return controller; + } + /** * @brief IInferRequest pointer to be used directly in CreateInferRequest functions * @return A shared pointer to underlying IInferRequest interface diff --git a/inference-engine/include/cpp/ie_memory_state.hpp b/inference-engine/include/cpp/ie_memory_state.hpp index 3fe79d0da38..24fd4d7fa1f 100644 --- a/inference-engine/include/cpp/ie_memory_state.hpp +++ b/inference-engine/include/cpp/ie_memory_state.hpp @@ -11,39 +11,42 @@ #include #include "ie_imemory_state.hpp" +#include "details/ie_exception_conversion.hpp" +#include "details/ie_so_loader.h" namespace InferenceEngine { /** - * @brief C++ exception based error reporting wrapper of API class IMemoryState + * @brief C++ exception based error reporting wrapper of API class IVariableState */ -class MemoryState { - IMemoryState::Ptr actual = nullptr; +class VariableState { + IVariableState::Ptr actual = nullptr; + details::SharedObjectLoader::Ptr plugin = {}; public: /** - * constructs MemoryState from the initialized shared_pointer + * constructs VariableState from the initialized shared_pointer * @param pState Initialized shared pointer */ - explicit MemoryState(IMemoryState::Ptr pState): actual(pState) { + explicit VariableState(IVariableState::Ptr pState, details::SharedObjectLoader::Ptr plg = {}) : actual(pState), plugin(plg) { if (actual == nullptr) { - THROW_IE_EXCEPTION << "MemoryState wrapper was not initialized."; + THROW_IE_EXCEPTION << "VariableState wrapper was not initialized."; } } /** - * @copybrief IMemoryState::Reset + * @copybrief IVariableState::Reset * - * Wraps IMemoryState::Reset + * Wraps IVariableState::Reset */ void Reset() { CALL_STATUS_FNC_NO_ARGS(Reset); } /** - * @copybrief IMemoryState::GetName + * @copybrief IVariableState::GetName * - * Wraps IMemoryState::GetName + * Wraps IVariableState::GetName * @return A string representing a state name */ std::string GetName() const { @@ -53,21 +56,26 @@ public: } /** - * @copybrief IMemoryState::GetLastState + * @copybrief IVariableState::GetState * - * Wraps IMemoryState::GetLastState + * Wraps IVariableState::GetState * @return A blob representing a last state */ - Blob::CPtr GetLastState() const { + Blob::CPtr GetState() const { Blob::CPtr stateBlob; - CALL_STATUS_FNC(GetLastState, stateBlob); + CALL_STATUS_FNC(GetState, stateBlob); return stateBlob; } + INFERENCE_ENGINE_DEPRECATED("Use GetState function instead") + Blob::CPtr GetLastState() const { + return GetState(); + } + /** - * @copybrief IMemoryState::SetState + * @copybrief IVariableState::SetState * - * Wraps IMemoryState::SetState + * Wraps IVariableState::SetState * @param state The current state to set */ void SetState(Blob::Ptr state) { @@ -75,4 +83,8 @@ public: } }; -} // namespace InferenceEngine \ No newline at end of file +/* + * @brief For compatibility reasons. + */ +using MemoryState = VariableState; +} // namespace InferenceEngine diff --git a/inference-engine/include/ie_iexecutable_network.hpp b/inference-engine/include/ie_iexecutable_network.hpp index 491c24e0736..8e7c5fa0bef 100644 --- a/inference-engine/include/ie_iexecutable_network.hpp +++ b/inference-engine/include/ie_iexecutable_network.hpp @@ -118,7 +118,7 @@ public: * @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for * given index */ - virtual StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0; + virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0; /** * @brief Sets configuration for current executable network diff --git a/inference-engine/include/ie_iinfer_request.hpp b/inference-engine/include/ie_iinfer_request.hpp index a83b613348d..e5674fe7929 100644 --- a/inference-engine/include/ie_iinfer_request.hpp +++ b/inference-engine/include/ie_iinfer_request.hpp @@ -17,6 +17,7 @@ #include "ie_blob.h" #include "ie_common.h" #include "ie_preprocess.hpp" +#include "ie_imemory_state.hpp" #include "details/ie_irelease.hpp" namespace InferenceEngine { @@ -177,6 +178,18 @@ public: * @return Enumeration of the resulted action: InferenceEngine::OK (0) for success */ virtual InferenceEngine::StatusCode SetBatch(int batch_size, ResponseDesc* resp) noexcept = 0; -}; -} // namespace InferenceEngine + /** + * @brief Gets state control interface for given infer request. + * + * State control essential for recurrent networks + * + * @param pState reference to a pointer that receives internal states + * @param idx requested index for receiving memory state + * @param resp Optional: pointer to an already allocated object to contain information in case of failure + * @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for + * given index + */ + virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0; +}; +} // namespace InferenceEngine \ No newline at end of file diff --git a/inference-engine/include/ie_imemory_state.hpp b/inference-engine/include/ie_imemory_state.hpp index 98c9f345445..2e44350b5fa 100644 --- a/inference-engine/include/ie_imemory_state.hpp +++ b/inference-engine/include/ie_imemory_state.hpp @@ -3,7 +3,7 @@ // /** - * @brief a header file for IMemoryState interface + * @brief a header file for IVariableState interface * * @file ie_imemory_state.hpp */ @@ -19,19 +19,19 @@ namespace InferenceEngine { /** - * @interface IMemoryState + * @interface IVariableState * @brief manages data for reset operations */ -class IMemoryState : public details::no_copy { +class IVariableState : public details::no_copy { public: /** - * @brief A shared pointer to the IMemoryState interface + * @brief A shared pointer to the IVariableState interface */ - using Ptr = std::shared_ptr; + using Ptr = std::shared_ptr; /** * @brief Gets name of current memory state, if length of array is not enough name is truncated by len, null - * terminator is inserted as well. + * terminator is inserted as well. As memory state name variable_id from according ReadValue used. * * @param name preallocated buffer for receiving name * @param len Length of the buffer @@ -41,7 +41,7 @@ public: virtual StatusCode GetName(char* name, size_t len, ResponseDesc* resp) const noexcept = 0; /** - * @brief reset internal memory state for relevant iexecutable network, to a value specified in SetState + * @brief Reset internal memory state for relevant infer request, to a value specified as default for according ReadValue node * * @param resp Optional: pointer to an already allocated object to contain information in case of failure * @return Status code of the operation: InferenceEngine::OK (0) for success* @@ -49,25 +49,30 @@ public: virtual StatusCode Reset(ResponseDesc* resp) noexcept = 0; /** - * @brief Sets the new state that is used for all future Reset() operations as a base. + * @brief Sets the new state for the next inference. * * This method can fail if Blob size does not match the internal state size or precision * - * @param newState is the data to use as base state + * @param newState is the data to use as new state * @param resp Optional: pointer to an already allocated object to contain information in case of failure * @return Status code of the operation: InferenceEngine::OK (0) for success */ virtual StatusCode SetState(Blob::Ptr newState, ResponseDesc* resp) noexcept = 0; /** - * @brief returns the value of the last memory state. + * @brief Returns the value of the memory state. * - * @details Since we roll memory after each infer, we can query the input state always and still get the last state. * @param lastState * @param resp Optional: pointer to an already allocated object to contain information in case of failure * @return Status code of the operation: InferenceEngine::OK (0) for success * */ - virtual StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept = 0; + INFERENCE_ENGINE_DEPRECATED("Use GetState function instead") + virtual StatusCode GetLastState(Blob::CPtr& state, ResponseDesc* resp) const noexcept {return GetState(state, resp);} + virtual StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept = 0; }; +/* + * @brief For compatibility reasons. + */ +using IMemoryState = IVariableState; } // namespace InferenceEngine \ No newline at end of file diff --git a/inference-engine/samples/speech_sample/main.cpp b/inference-engine/samples/speech_sample/main.cpp index c9174739b86..c7db028a73e 100644 --- a/inference-engine/samples/speech_sample/main.cpp +++ b/inference-engine/samples/speech_sample/main.cpp @@ -845,7 +845,7 @@ int main(int argc, char *argv[]) { ptrUtterances.resize(inputArkFiles.size()); // initialize memory state before starting - for (auto &&state : executableNet.QueryState()) { + for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) { state.Reset(); } @@ -1080,7 +1080,7 @@ int main(int argc, char *argv[]) { totalTime += d.count(); // resetting state between utterances - for (auto &&state : executableNet.QueryState()) { + for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) { state.Reset(); } diff --git a/inference-engine/src/gna_plugin/gna_executable_network.hpp b/inference-engine/src/gna_plugin/gna_executable_network.hpp index b7a108821de..d240c7863ed 100644 --- a/inference-engine/src/gna_plugin/gna_executable_network.hpp +++ b/inference-engine/src/gna_plugin/gna_executable_network.hpp @@ -59,12 +59,13 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe return std::make_shared(plg, networkInputs, networkOutputs); } - - - std::vector QueryState() override { + INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead") + std::vector QueryState() override { + IE_SUPPRESS_DEPRECATED_START auto pluginStates = plg->QueryState(); - std::vector state(pluginStates.begin(), pluginStates.end()); + std::vector state(pluginStates.begin(), pluginStates.end()); return plg->QueryState(); + IE_SUPPRESS_DEPRECATED_END } void Export(const std::string &modelFileName) override { diff --git a/inference-engine/src/gna_plugin/gna_infer_request.hpp b/inference-engine/src/gna_plugin/gna_infer_request.hpp index fd2cc69d61d..fcdc92b8d4f 100644 --- a/inference-engine/src/gna_plugin/gna_infer_request.hpp +++ b/inference-engine/src/gna_plugin/gna_infer_request.hpp @@ -111,5 +111,13 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal { } return InferenceEngine::OK; } + + IE_SUPPRESS_DEPRECATED_START + std::vector QueryState() override { + auto pluginStates = plg->QueryState(); + std::vector state(pluginStates.begin(), pluginStates.end()); + return plg->QueryState(); + } + IE_SUPPRESS_DEPRECATED_END }; } // namespace GNAPluginNS diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index 5f0a04f93dd..7d6e6768ce9 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -1186,11 +1186,11 @@ Blob::Ptr GNAPlugin::GetInputBlob(const std::string& name, InferenceEngine::Prec return inputBlob; } -std::vector GNAPlugin::QueryState() { +std::vector GNAPlugin::QueryState() { if (memoryStates.size() != graphCompiler.memory_connection.size()) { memoryStates.clear(); for (auto& connection : graphCompiler.memory_connection) { - auto state = std::make_shared(connection.first, std::make_shared (connection.second)); + auto state = std::make_shared(connection.first, std::make_shared (connection.second)); memoryStates.emplace_back(state); } } diff --git a/inference-engine/src/gna_plugin/gna_plugin.hpp b/inference-engine/src/gna_plugin/gna_plugin.hpp index 1e4c4fd4828..dbe98fd37a4 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.hpp +++ b/inference-engine/src/gna_plugin/gna_plugin.hpp @@ -84,7 +84,7 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin { InferenceEngine::InputsDataMap inputsDataMap; InferenceEngine::OutputsDataMap outputsDataMap; - std::vector memoryStates; + std::vector memoryStates; public: explicit GNAPlugin(const std::map& configMap); @@ -159,7 +159,8 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin { * QueryState API * @return */ - std::vector QueryState(); + INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead") + std::vector QueryState(); /** * test-wise API diff --git a/inference-engine/src/gna_plugin/memory/gna_memory_state.cpp b/inference-engine/src/gna_plugin/memory/gna_memory_state.cpp index bb25cd941a0..27e938468d7 100644 --- a/inference-engine/src/gna_plugin/memory/gna_memory_state.cpp +++ b/inference-engine/src/gna_plugin/memory/gna_memory_state.cpp @@ -12,15 +12,15 @@ namespace GNAPluginNS { namespace memory { - std::string GNAMemoryState::GetName() const { + std::string GNAVariableState::GetName() const { return name; } - void GNAMemoryState::Reset() { + void GNAVariableState::Reset() { state->Reset(); } - InferenceEngine::Precision GNAMemoryState::getPrecision() const { + InferenceEngine::Precision GNAVariableState::getPrecision() const { InferenceEngine::Precision state_precision; if (state->getInput()) { @@ -36,14 +36,14 @@ namespace memory { break; default: THROW_GNA_EXCEPTION << "Incorrect state element size " << element_size << - " to determine precision for MemoryState " << name; + " to determine precision for VariableState " << name; } } return state_precision; } - void GNAMemoryState::SetState(InferenceEngine::Blob::Ptr newState) { + void GNAVariableState::SetState(InferenceEngine::Blob::Ptr newState) { IE_ASSERT(newState != nullptr); auto data_ptr = newState->cbuffer().as(); @@ -78,20 +78,20 @@ namespace memory { data_elements, scale_factor); } else { - THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name + THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name << ". If old state precision is I16 only I16 and FP32 are allowed as new state precisions." << " Old state: " << state_precision << " New state: " << new_state_precision; } break; } default: - THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name + THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name << ". Incorrect new/old precision pair" << " Old state: " << state_precision << " New state: " << new_state_precision; } } - InferenceEngine::Blob::CPtr GNAMemoryState::GetLastState() const { + InferenceEngine::Blob::CPtr GNAVariableState::GetState() const { auto elements = state->reserved_size / state->elementSizeBytes(); InferenceEngine::Precision state_precision = getPrecision(); diff --git a/inference-engine/src/gna_plugin/memory/gna_memory_state.hpp b/inference-engine/src/gna_plugin/memory/gna_memory_state.hpp index 499c4c9e82d..2a7c83d6dae 100644 --- a/inference-engine/src/gna_plugin/memory/gna_memory_state.hpp +++ b/inference-engine/src/gna_plugin/memory/gna_memory_state.hpp @@ -11,14 +11,14 @@ namespace GNAPluginNS { namespace memory { -class GNAMemoryState : public InferenceEngine::IMemoryStateInternal { +class GNAVariableState : public InferenceEngine::IVariableStateInternal { public: - GNAMemoryState(std::string name, std::shared_ptr state) + GNAVariableState(std::string name, std::shared_ptr state) : name(name), state(state) { IE_ASSERT(state != nullptr); } void Reset() override; void SetState(InferenceEngine::Blob::Ptr newState) override; - InferenceEngine::Blob::CPtr GetLastState() const override; + InferenceEngine::Blob::CPtr GetState() const override; std::string GetName() const override; private: diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp index e6bd3b265e2..94919a18732 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp @@ -183,14 +183,14 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::ICNNNetwork &network if (node->getType() == MemoryInput) { auto memoryNode = dynamic_cast(node.get()); auto state_store = memoryNode->getStore(); - auto state_name = node->getName(); + auto state_name = memoryNode->getId(); // Remove suffix with pair ID. Internal information. auto suffix_idx = state_name.find("/id="); if (suffix_idx != std::string::npos) state_name = state_name.substr(0, suffix_idx); - memoryStates.emplace_back(new MKLDNNMemoryState(state_name, state_store)); + memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store)); } } } @@ -314,6 +314,8 @@ bool MKLDNNExecNetwork::CanProcessDynBatch(const InferenceEngine::ICNNNetwork &n return check_result; } -std::vector MKLDNNExecNetwork::QueryState() { +IE_SUPPRESS_DEPRECATED_START +std::vector MKLDNNExecNetwork::QueryState() { return memoryStates; } +IE_SUPPRESS_DEPRECATED_END diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h index 8ea85bbbba9..4247503c134 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h @@ -42,14 +42,15 @@ public: InferenceEngine::CNNNetwork GetExecGraphInfo() override; - std::vector QueryState() override; + INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead") + std::vector QueryState() override; InferenceEngine::ThreadLocal _graphs; protected: friend class MKLDNNInferRequest; MKLDNNExtensionManager::Ptr extensionManager; - std::vector memoryStates; + std::vector memoryStates; InferenceEngine::details::CNNNetworkImplPtr _clonedNetwork; std::mutex _cfgMutex; Config _cfg; diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp index e07dfcaacda..ae7db884339 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp @@ -14,6 +14,8 @@ #include "mkldnn_exec_network.h" #include "mkldnn_itt.h" #include "nodes/common/cpu_convert.h" +#include "mkldnn_memory_state.h" +#include "nodes/mkldnn_memory_node.hpp" MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsDataMap networkInputs, InferenceEngine::OutputsDataMap networkOutputs, @@ -35,6 +37,30 @@ MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsData InferenceEngine::Blob::Ptr blob; MKLDNNInferRequest::GetBlob(it.first.c_str(), blob); } + + // Save all MemoryLayer data tensors. Will use insight about mechanics + // of MemoryLayer implementation. It uses output edge of MemoryLayer + // producer as storage for tensor to keep it between infer calls. + IE_SUPPRESS_DEPRECATED_START + if (execNetwork->QueryState().size() == 0) { + for (auto &node : graph->GetNodes()) { + if (node->getType() == MemoryInput) { + auto memoryNode = dynamic_cast(node.get()); + auto state_store = memoryNode->getStore(); + auto state_name = memoryNode->getId(); + + // Remove suffix with pair ID. Internal information. + auto suffix_idx = state_name.find("/id="); + if (suffix_idx != std::string::npos) + state_name = state_name.substr(0, suffix_idx); + + memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store)); + } + } + } else { + memoryStates = execNetwork->QueryState(); + } + IE_SUPPRESS_DEPRECATED_END } MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() { @@ -390,3 +416,7 @@ void MKLDNNPlugin::MKLDNNInferRequest::SetBatch(int new_batch) { m_curBatch = new_batch; } + +std::vector MKLDNNPlugin::MKLDNNInferRequest::QueryState() { + return memoryStates; +} diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.h b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.h index 4c058a479e2..e9863be75f0 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.h @@ -43,6 +43,8 @@ public: void SetBatch(int batch = -1) override; + std::vector QueryState() override; + private: void PushInputData(); @@ -53,5 +55,6 @@ private: MKLDNNGraph* graph = nullptr; std::map externalPtr; openvino::itt::handle_t profilingTask; + std::vector memoryStates; }; } // namespace MKLDNNPlugin diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.cpp index 56d74d25889..4af75508d55 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.cpp @@ -4,20 +4,21 @@ #include "mkldnn_memory_state.h" #include "mkldnn_extension_utils.h" +#include "blob_factory.hpp" using namespace InferenceEngine; namespace MKLDNNPlugin { -std::string MKLDNNMemoryState::GetName() const { +std::string MKLDNNVariableState::GetName() const { return name; } -void MKLDNNMemoryState::Reset() { +void MKLDNNVariableState::Reset() { storage->FillZero(); } -void MKLDNNMemoryState::SetState(Blob::Ptr newState) { +void MKLDNNVariableState::SetState(Blob::Ptr newState) { auto prec = newState->getTensorDesc().getPrecision(); auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(prec); auto data_layout = MKLDNNMemory::Convert(newState->getTensorDesc().getLayout()); @@ -27,9 +28,11 @@ void MKLDNNMemoryState::SetState(Blob::Ptr newState) { storage->SetData(data_type, data_layout, data_ptr, data_size); } -InferenceEngine::Blob::CPtr MKLDNNMemoryState::GetLastState() const { - THROW_IE_EXCEPTION << "GetLastState method is not implemented for MemoryState"; - return nullptr; +InferenceEngine::Blob::CPtr MKLDNNVariableState::GetState() const { + auto result_blob = make_blob_with_precision(MKLDNNMemoryDesc(storage->GetDescriptor())); + result_blob->allocate(); + std::memcpy(result_blob->buffer(), storage->GetData(), storage->GetSize()); + return result_blob; } -} // namespace MKLDNNPlugin \ No newline at end of file +} // namespace MKLDNNPlugin diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.h b/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.h index cb024dfb08c..751635b7709 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_memory_state.h @@ -11,15 +11,15 @@ namespace MKLDNNPlugin { -class MKLDNNMemoryState : public InferenceEngine::IMemoryStateInternal { +class MKLDNNVariableState : public InferenceEngine::IVariableStateInternal { public: - MKLDNNMemoryState(std::string name, MKLDNNMemoryPtr storage) : + MKLDNNVariableState(std::string name, MKLDNNMemoryPtr storage) : name(name), storage(storage) {} std::string GetName() const override; void Reset() override; void SetState(InferenceEngine::Blob::Ptr newState) override; - InferenceEngine::Blob::CPtr GetLastState() const override; + InferenceEngine::Blob::CPtr GetState() const override; private: std::string name; diff --git a/inference-engine/src/plugin_api/cpp_interfaces/base/ie_executable_network_base.hpp b/inference-engine/src/plugin_api/cpp_interfaces/base/ie_executable_network_base.hpp index fd867805d74..b9d7833e357 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/base/ie_executable_network_base.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/base/ie_executable_network_base.hpp @@ -66,19 +66,22 @@ public: TO_STATUS(graphPtr = _impl->GetExecGraphInfo()); } - StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override { + INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead") + StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override { + IE_SUPPRESS_DEPRECATED_START try { auto v = _impl->QueryState(); if (idx >= v.size()) { return OUT_OF_BOUNDS; } - pState = std::make_shared>(v[idx]); + pState = std::make_shared>(v[idx]); return OK; } catch (const std::exception& ex) { return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); } catch (...) { return InferenceEngine::DescriptionBuffer(UNEXPECTED); } + IE_SUPPRESS_DEPRECATED_END } void Release() noexcept override { diff --git a/inference-engine/src/plugin_api/cpp_interfaces/base/ie_infer_async_request_base.hpp b/inference-engine/src/plugin_api/cpp_interfaces/base/ie_infer_async_request_base.hpp index e350e6e5cba..5892ef02846 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/base/ie_infer_async_request_base.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/base/ie_infer_async_request_base.hpp @@ -10,6 +10,7 @@ #include "cpp_interfaces/exception2status.hpp" #include "cpp_interfaces/plugin_itt.hpp" +#include #include "ie_iinfer_request.hpp" #include "ie_preprocess.hpp" #include "ie_profiling.hpp" @@ -88,6 +89,21 @@ public: TO_STATUS(_impl->SetBatch(batch_size)); } + StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override { + try { + auto v = _impl->QueryState(); + if (idx >= v.size()) { + return OUT_OF_BOUNDS; + } + pState = std::make_shared>(v[idx]); + return OK; + } catch (const std::exception& ex) { + return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what(); + } catch (...) { + return InferenceEngine::DescriptionBuffer(UNEXPECTED); + } + } + private: ~InferRequestBase() = default; }; diff --git a/inference-engine/src/plugin_api/cpp_interfaces/base/ie_memory_state_base.hpp b/inference-engine/src/plugin_api/cpp_interfaces/base/ie_memory_state_base.hpp index 2b88ee5f70d..fe191a8e7f6 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/base/ie_memory_state_base.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/base/ie_memory_state_base.hpp @@ -7,23 +7,24 @@ #include #include "cpp_interfaces/exception2status.hpp" +#include "cpp_interfaces/impl/ie_memory_state_internal.hpp" #include "ie_imemory_state.hpp" namespace InferenceEngine { /** - * @brief default implementation for IMemoryState + * @brief default implementation for IVariableState * @ingroup ie_dev_api_mem_state_api */ template -class MemoryStateBase : public IMemoryState { +class VariableStateBase : public IVariableState { protected: std::shared_ptr impl; public: - explicit MemoryStateBase(std::shared_ptr impl): impl(impl) { + explicit VariableStateBase(std::shared_ptr impl): impl(impl) { if (impl == nullptr) { - THROW_IE_EXCEPTION << "MemoryStateBase implementation not defined"; + THROW_IE_EXCEPTION << "VariableStateBase implementation not defined"; } } @@ -44,9 +45,9 @@ public: TO_STATUS(impl->SetState(newState)); } - StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept override { - TO_STATUS(lastState = impl->GetLastState()); + StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept override { + TO_STATUS(state = impl->GetState()); } }; -} // namespace InferenceEngine \ No newline at end of file +} // namespace InferenceEngine diff --git a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp index 41f5d16fe06..c2e70b5bf73 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp @@ -88,7 +88,7 @@ public: _plugin = plugin; } - std::vector QueryState() override { + std::vector QueryState() override { THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str; } diff --git a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp index d7b2da1d01d..71d2f5a75f8 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp @@ -152,6 +152,10 @@ public: _publicInterface = std::shared_ptr(ptr.get(), [](IInferRequest*) {}); } + std::vector QueryState() override { + return _syncRequest->QueryState(); + } + protected: /** * @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation diff --git a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_request_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_request_internal.hpp index 7fe1c30b84b..50671b4a3c2 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_request_internal.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_request_internal.hpp @@ -223,6 +223,12 @@ public: } } + std::vector QueryState() override { + // meaning base plugin reports as no state available - plugin owners need to create proper override of this + THROW_IE_EXCEPTION << "Plugin doesn't override QueryState"; + return {}; + } + protected: InferenceEngine::InputsDataMap _networkInputs; //!< Holds information about network inputs info InferenceEngine::OutputsDataMap _networkOutputs; //!< Holds information about network outputs data diff --git a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_memory_state_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_memory_state_internal.hpp index 5da62e3c068..05f96d5f4e7 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_memory_state_internal.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_memory_state_internal.hpp @@ -13,21 +13,25 @@ namespace InferenceEngine { * @brief minimal interface for memory state implementation * @ingroup ie_dev_api_mem_state_api */ -class MemoryStateInternal : public IMemoryStateInternal { +class VariableStateInternal : public IVariableStateInternal { std::string name; Blob::Ptr state; public: - explicit MemoryStateInternal(std::string name): name(name) {} + explicit VariableStateInternal(std::string name): name(name) {} std::string GetName() const override { return name; } void SetState(Blob::Ptr newState) override { state = newState; } - Blob::CPtr GetLastState() const override { + Blob::CPtr GetState() const override { return state; } }; -} // namespace InferenceEngine \ No newline at end of file +/* + * @brief For compatibility reasons. + */ +using MemoryStateInternal = VariableStateInternal; +} // namespace InferenceEngine diff --git a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iexecutable_network_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iexecutable_network_internal.hpp index 9efdb66004a..17cc927813f 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iexecutable_network_internal.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iexecutable_network_internal.hpp @@ -79,7 +79,7 @@ public: * @brief Queries memory states. * @return Returns memory states */ - virtual std::vector QueryState() = 0; + virtual std::vector QueryState() = 0; /** * @brief Sets configuration for current executable network diff --git a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iinfer_request_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iinfer_request_internal.hpp index d62cd42b3c2..c09a15aa25b 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iinfer_request_internal.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iinfer_request_internal.hpp @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -83,6 +84,12 @@ public: * @param batch - new batch size to be used by all the following inference calls for this request. */ virtual void SetBatch(int batch) = 0; + + /** + * @brief Queries memory states. + * @return Returns memory states + */ + virtual std::vector QueryState() = 0; }; } // namespace InferenceEngine diff --git a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_imemory_state_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_imemory_state_internal.hpp index aa81e695c15..ef37d8b8241 100644 --- a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_imemory_state_internal.hpp +++ b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_imemory_state_internal.hpp @@ -11,19 +11,25 @@ namespace InferenceEngine { /** - * @interface IMemoryStateInternal + * @interface IVariableStateInternal * @brief minimal interface for memory state implementation * @ingroup ie_dev_api_mem_state_api */ -class IMemoryStateInternal { +class IVariableStateInternal { public: - using Ptr = std::shared_ptr; + using Ptr = std::shared_ptr; - virtual ~IMemoryStateInternal() = default; + virtual ~IVariableStateInternal() = default; virtual std::string GetName() const = 0; virtual void Reset() = 0; virtual void SetState(Blob::Ptr newState) = 0; - virtual Blob::CPtr GetLastState() const = 0; + virtual Blob::CPtr GetState() const = 0; + INFERENCE_ENGINE_DEPRECATED("Use GetState function instead") + virtual Blob::CPtr GetLastState() const {return GetState();} }; +/* + * @brief For compatibility reasons. + */ +using IMemoryStateInternal = IVariableStateInternal; } // namespace InferenceEngine diff --git a/inference-engine/tests/functional/inference_engine/async_infer_request_test.cpp b/inference-engine/tests/functional/inference_engine/async_infer_request_test.cpp index 861ceee547f..1ce30df9738 100644 --- a/inference-engine/tests/functional/inference_engine/async_infer_request_test.cpp +++ b/inference-engine/tests/functional/inference_engine/async_infer_request_test.cpp @@ -83,3 +83,8 @@ TEST(InferRequestCPPTests, throwsOnUninitializedCast) { InferRequest req; ASSERT_THROW(auto &ireq = static_cast(req), InferenceEngine::details::InferenceEngineException); } + +TEST(InferRequestCPPTests, throwsOnUninitializedQueryState) { + InferRequest req; + ASSERT_THROW(req.QueryState(), InferenceEngine::details::InferenceEngineException); +} diff --git a/inference-engine/tests/functional/inference_engine/executable_network.cpp b/inference-engine/tests/functional/inference_engine/executable_network.cpp index f449c4cc806..89f3b75b98f 100644 --- a/inference-engine/tests/functional/inference_engine/executable_network.cpp +++ b/inference-engine/tests/functional/inference_engine/executable_network.cpp @@ -46,8 +46,10 @@ TEST(ExecutableNetworkTests, throwsOnUninitializedGetExecGraphInfo) { } TEST(ExecutableNetworkTests, throwsOnUninitializedQueryState) { + IE_SUPPRESS_DEPRECATED_START ExecutableNetwork exec; ASSERT_THROW(exec.QueryState(), InferenceEngine::details::InferenceEngineException); + IE_SUPPRESS_DEPRECATED_END } TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) { diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/cpp_holders.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/cpp_holders.cpp index 17de549f2f0..f01442e9278 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/cpp_holders.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/cpp_holders.cpp @@ -10,12 +10,15 @@ namespace { // 0 - plugin // 1 - executable_network // 2 - infer_request - {0, 1, 2}, - {0, 2, 1}, - {1, 0, 2}, - {1, 2, 0}, - {2, 0, 1}, - {2, 1, 0} + // 3 - variable state + {3, 0, 1, 2}, + {3, 0, 2, 1}, + {3, 1, 0, 2}, + {3, 1, 2, 0}, + {3, 2, 0, 1}, + {3, 2, 1, 0}, + {0, 3, 1, 2}, + {0, 1, 3, 2} }; INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest, @@ -24,4 +27,4 @@ namespace { ::testing::ValuesIn(orders)), HoldersTest::getTestCaseName); -} // namespace \ No newline at end of file +} // namespace diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/memory_states.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/memory_states.cpp new file mode 100644 index 00000000000..0a7bc37cb63 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/memory_states.cpp @@ -0,0 +1,22 @@ +// Copyright (C) 2020 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "behavior/memory_states.hpp" +#include "functional_test_utils/test_model/test_model.hpp" +#include "functional_test_utils/plugin_cache.hpp" + +InferenceEngine::CNNNetwork getNetwork() { + auto model = FuncTestUtils::TestModel::getModelWithMultipleMemoryConnections(InferenceEngine::Precision::FP32); + auto ie = PluginCache::get().ie(); + return ie->ReadNetwork(model.model_xml_str, model.weights_blob); +} +std::vector memoryStateTestCases = { + memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_CPU) +}; + +INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest, + ::testing::ValuesIn(memoryStateTestCases), + VariableStateTest::getTestCaseName); diff --git a/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/cpp_holders.cpp b/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/cpp_holders.cpp index 2c27cccaffd..c14c4dc1729 100644 --- a/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/cpp_holders.cpp +++ b/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/cpp_holders.cpp @@ -10,12 +10,15 @@ namespace { // 0 - plugin // 1 - executable_network // 2 - infer_request - {0, 1, 2}, - {0, 2, 1}, - {1, 0, 2}, - {1, 2, 0}, - {2, 0, 1}, - {2, 1, 0} + // 3 - variable state + {3, 0, 1, 2}, + {3, 0, 2, 1}, + {3, 1, 0, 2}, + {3, 1, 2, 0}, + {3, 2, 0, 1}, + {3, 2, 1, 0}, + {0, 3, 1, 2}, + {0, 1, 3, 2} }; INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest, diff --git a/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/memory_states.cpp b/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/memory_states.cpp index 62ab38b7a37..6a7b4ccb69b 100644 --- a/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/memory_states.cpp +++ b/inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/memory_states.cpp @@ -17,6 +17,6 @@ std::vector memoryStateTestCases = { memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_GNA) }; -INSTANTIATE_TEST_CASE_P(smoke_MemoryStateBasic, MemoryStateTest, +INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest, ::testing::ValuesIn(memoryStateTestCases), - MemoryStateTest::getTestCaseName); + VariableStateTest::getTestCaseName); diff --git a/inference-engine/tests/functional/plugin/shared/include/behavior/memory_states.hpp b/inference-engine/tests/functional/plugin/shared/include/behavior/memory_states.hpp index bac01c1ccb8..a9718b1b6c8 100644 --- a/inference-engine/tests/functional/plugin/shared/include/behavior/memory_states.hpp +++ b/inference-engine/tests/functional/plugin/shared/include/behavior/memory_states.hpp @@ -14,7 +14,7 @@ typedef std::tuple< std::string> // Target device name memoryStateParams; -class MemoryStateTest : public CommonTestUtils::TestsCommon, +class VariableStateTest : public CommonTestUtils::TestsCommon, public testing::WithParamInterface { protected: InferenceEngine::CNNNetwork net; diff --git a/inference-engine/tests/functional/plugin/shared/src/behavior/cpp_holders.cpp b/inference-engine/tests/functional/plugin/shared/src/behavior/cpp_holders.cpp index 61f9eb2ece3..62db86b1cd3 100644 --- a/inference-engine/tests/functional/plugin/shared/src/behavior/cpp_holders.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/behavior/cpp_holders.cpp @@ -25,7 +25,11 @@ namespace BehaviorTestsDefinitions { if (deathTestStyle == "fast") { ::testing::GTEST_FLAG(death_test_style) = "threadsafe"; } - function = ngraph::builder::subgraph::makeConvPoolRelu(); + if (targetDevice == CommonTestUtils::DEVICE_CPU) { + function = ngraph::builder::subgraph::makeReadConcatSplitAssign(); + } else { + function = ngraph::builder::subgraph::makeConvPoolRelu(); + } } void HoldersTest::TearDown() { @@ -42,6 +46,12 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "") InferenceEngine::Core core; auto exe_net = core.LoadNetwork(cnnNet, deviceName); auto request = exe_net.CreateInferRequest(); + std::vector states = {}; + try { + states = request.QueryState(); + } catch(...) { + // do nothing + } auto release = [&](int i) { switch (i) { @@ -54,6 +64,9 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "") case 2: request = {}; break; + case 3: + states = {}; + break; default: break; } @@ -67,4 +80,4 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "") // Test failed if crash happens EXPECT_NO_CRASH(release_order_test(order, targetDevice, function)); } -} // namespace BehaviorTestsDefinitions \ No newline at end of file +} // namespace BehaviorTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/behavior/memory_states.cpp b/inference-engine/tests/functional/plugin/shared/src/behavior/memory_states.cpp index 2aa6694839a..4ef378eb5e4 100644 --- a/inference-engine/tests/functional/plugin/shared/src/behavior/memory_states.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/behavior/memory_states.cpp @@ -7,7 +7,7 @@ #include "behavior/memory_states.hpp" #include "functional_test_utils/plugin_cache.hpp" -std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfo &obj) { +std::string VariableStateTest::getTestCaseName(const testing::TestParamInfo &obj) { std::ostringstream result; InferenceEngine::CNNNetwork net; std::string targetDevice; @@ -17,22 +17,108 @@ std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfoLoadNetwork(net, deviceName); } -TEST_P(MemoryStateTest, smoke_MemoryState_QueryState) { +TEST_P(VariableStateTest, smoke_VariableState_QueryState) { + IE_SUPPRESS_DEPRECATED_START auto executableNet = PrepareNetwork(); auto states = executableNet.QueryState(); - ASSERT_TRUE(states.size() == 2) << "Incorrect number of MemoryStates"; + ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates"; + + for (auto&& state : states) { + auto name = state.GetName(); + ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end()) + << "State " << name << "expected to be in memory states but it is not!"; + } + IE_SUPPRESS_DEPRECATED_END +} + +TEST_P(VariableStateTest, smoke_VariableState_SetState) { + IE_SUPPRESS_DEPRECATED_START + auto executableNet = PrepareNetwork(); + const float new_state_val = 13.0f; + for (auto&& state : executableNet.QueryState()) { + state.Reset(); + auto state_val = state.GetState(); + auto element_count = state_val->size(); + + std::vector new_state_data(element_count, new_state_val); + auto stateBlob = InferenceEngine::make_shared_blob( + { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() }, + new_state_data.data(), new_state_data.size()); + + state.SetState(stateBlob); + } + + for (auto&& state : executableNet.QueryState()) { + auto lastState = state.GetState(); + auto last_state_size = lastState->size(); + auto last_state_data = lastState->cbuffer().as(); + ASSERT_TRUE(last_state_size != 0) << "State size should not be 0"; + + for (int i = 0; i < last_state_size; i++) { + EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5); + } + } + IE_SUPPRESS_DEPRECATED_END +} + +TEST_P(VariableStateTest, smoke_VariableState_Reset) { + IE_SUPPRESS_DEPRECATED_START + auto executableNet = PrepareNetwork(); + const float new_state_val = 13.0f; + for (auto&& state : executableNet.QueryState()) { + state.Reset(); + auto state_val = state.GetState(); + auto element_count = state_val->size(); + + std::vector new_state_data(element_count, new_state_val); + auto stateBlob = InferenceEngine::make_shared_blob( + { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() }, + new_state_data.data(), new_state_data.size()); + + state.SetState(stateBlob); + } + + executableNet.QueryState().front().Reset(); + + auto states = executableNet.QueryState(); + for (int i = 0; i < states.size(); ++i) { + auto lastState = states[i].GetState(); + auto last_state_size = lastState->size(); + auto last_state_data = lastState->cbuffer().as(); + + ASSERT_TRUE(last_state_size != 0) << "State size should not be 0"; + + if (i == 0) { + for (int j = 0; j < last_state_size; ++j) { + EXPECT_NEAR(0, last_state_data[j], 1e-5); + } + } else { + for (int j = 0; j < last_state_size; ++j) { + EXPECT_NEAR(13.0f, last_state_data[j], 1e-5); + } + } + } + IE_SUPPRESS_DEPRECATED_END +} + +TEST_P(VariableStateTest, inferreq_smoke_VariableState_QueryState) { + auto executableNet = PrepareNetwork(); + auto inferReq = executableNet.CreateInferRequest(); + + auto states = inferReq.QueryState(); + ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates"; for (auto&& state : states) { auto name = state.GetName(); @@ -41,23 +127,26 @@ TEST_P(MemoryStateTest, smoke_MemoryState_QueryState) { } } -TEST_P(MemoryStateTest, smoke_MemoryState_SetState) { +TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) { auto executableNet = PrepareNetwork(); + auto inferReq = executableNet.CreateInferRequest(); + const float new_state_val = 13.0f; - for (auto&& state : executableNet.QueryState()) { + for (auto&& state : inferReq.QueryState()) { state.Reset(); - auto element_count = state.GetLastState()->size(); + auto state_val = state.GetState(); + auto element_count = state_val->size(); std::vector new_state_data(element_count, new_state_val); auto stateBlob = InferenceEngine::make_shared_blob( - { InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C }, + { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() }, new_state_data.data(), new_state_data.size()); state.SetState(stateBlob); } - for (auto&& state : executableNet.QueryState()) { - auto lastState = state.GetLastState(); + for (auto&& state : inferReq.QueryState()) { + auto lastState = state.GetState(); auto last_state_size = lastState->size(); auto last_state_data = lastState->cbuffer().as(); ASSERT_TRUE(last_state_size != 0) << "State size should not be 0"; @@ -68,26 +157,29 @@ TEST_P(MemoryStateTest, smoke_MemoryState_SetState) { } } -TEST_P(MemoryStateTest, smoke_MemoryState_Reset) { +TEST_P(VariableStateTest, inferreq_smoke_VariableState_Reset) { auto executableNet = PrepareNetwork(); + auto inferReq = executableNet.CreateInferRequest(); + const float new_state_val = 13.0f; - for (auto&& state : executableNet.QueryState()) { + for (auto&& state : inferReq.QueryState()) { state.Reset(); - auto element_count = state.GetLastState()->size(); + auto state_val = state.GetState(); + auto element_count = state_val->size(); std::vector new_state_data(element_count, new_state_val); auto stateBlob = InferenceEngine::make_shared_blob( - { InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C }, + { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() }, new_state_data.data(), new_state_data.size()); state.SetState(stateBlob); } - executableNet.QueryState().front().Reset(); + inferReq.QueryState().front().Reset(); - auto states = executableNet.QueryState(); + auto states = inferReq.QueryState(); for (int i = 0; i < states.size(); ++i) { - auto lastState = states[i].GetLastState(); + auto lastState = states[i].GetState(); auto last_state_size = lastState->size(); auto last_state_data = lastState->cbuffer().as(); diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp index c412959cae1..d487b9d22ed 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp @@ -61,4 +61,5 @@ public: MOCK_METHOD1(SetBatch, void(int)); MOCK_METHOD1(SetBatch_ThreadUnsafe, void(int)); + MOCK_METHOD0(QueryState, std::vector>(void)); }; diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp index 1c60681c542..d544e6b0438 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp @@ -11,6 +11,7 @@ #include #include +#include class MockIAsyncInferRequestInternal : public InferenceEngine::IAsyncInferRequestInternal { public: @@ -26,4 +27,5 @@ public: MOCK_CONST_METHOD2(GetPreProcess, void(const char* name, const InferenceEngine::PreProcessInfo**)); MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback)); MOCK_METHOD1(SetBatch, void(int)); + MOCK_METHOD0(QueryState, std::vector()); }; diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp index 9cec0ff321b..54a26b977b0 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp @@ -28,7 +28,7 @@ public: MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void)); MOCK_METHOD1(Export, void(const std::string &)); void Export(std::ostream &) override {}; - MOCK_METHOD0(QueryState, std::vector(void)); + MOCK_METHOD0(QueryState, std::vector(void)); MOCK_METHOD0(GetExecGraphInfo, CNNNetwork(void)); MOCK_METHOD1(SetConfig, void(const std::map &config)); diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp index 30343f8ab40..0cc1e7f919f 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp @@ -11,6 +11,7 @@ #include #include +#include class MockIInferRequestInternal : public InferenceEngine::IInferRequestInternal { public: @@ -20,4 +21,5 @@ public: MOCK_METHOD2(GetBlob, void(const char *name, InferenceEngine::Blob::Ptr &)); MOCK_METHOD3(SetBlob, void(const char*, const InferenceEngine::Blob::Ptr&, const InferenceEngine::PreProcessInfo&)); MOCK_METHOD2(GetPreProcess, void(const char*, const InferenceEngine::PreProcessInfo**)); + MOCK_METHOD0(QueryState, std::vector()); }; diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp index c57cae8276e..13cd11033f5 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp @@ -11,10 +11,10 @@ #include -class MockIMemoryStateInternal : public InferenceEngine::IMemoryStateInternal { +class MockIVariableStateInternal : public InferenceEngine::IVariableStateInternal { public: MOCK_CONST_METHOD0(GetName, std::string()); MOCK_METHOD0(Reset, void()); MOCK_METHOD1(SetState, void(InferenceEngine::Blob::Ptr)); - MOCK_CONST_METHOD0(GetLastState, InferenceEngine::Blob::CPtr()); + MOCK_CONST_METHOD0(GetState, InferenceEngine::Blob::CPtr()); }; diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_ie_imemory_state.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_ie_imemory_state.hpp index 62dc9decc50..32135cc7dbe 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_ie_imemory_state.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_ie_imemory_state.hpp @@ -12,10 +12,10 @@ using namespace InferenceEngine; -class MockIMemoryState : public InferenceEngine::IMemoryState { +class MockIVariableState : public InferenceEngine::IVariableState { public: MOCK_QUALIFIED_METHOD3(GetName, const noexcept, StatusCode(char * , size_t, ResponseDesc *)); MOCK_QUALIFIED_METHOD1(Reset, noexcept, StatusCode(ResponseDesc *)); MOCK_QUALIFIED_METHOD2(SetState, noexcept, StatusCode(Blob::Ptr, ResponseDesc *)); - MOCK_QUALIFIED_METHOD2(GetLastState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *)); + MOCK_QUALIFIED_METHOD2(GetState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *)); }; diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iexecutable_network.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iexecutable_network.hpp index 2af9e8390b8..903cb04beb0 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iexecutable_network.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iexecutable_network.hpp @@ -30,6 +30,6 @@ public: MOCK_QUALIFIED_METHOD3(GetConfig, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp)); MOCK_QUALIFIED_METHOD3(GetMetric, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp)); MOCK_QUALIFIED_METHOD2(GetContext, const noexcept, StatusCode(RemoteContext::Ptr &pContext, ResponseDesc *resp)); - MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IMemoryState::Ptr &, size_t, ResponseDesc *)); + MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *)); MOCK_QUALIFIED_METHOD0(Release, noexcept, void()); }; diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iinfer_request.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iinfer_request.hpp index 3489c0e48b2..613898a6492 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iinfer_request.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iinfer_request.hpp @@ -34,4 +34,5 @@ public: MOCK_QUALIFIED_METHOD3(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, ResponseDesc*)); MOCK_QUALIFIED_METHOD4(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, const PreProcessInfo&, ResponseDesc*)); MOCK_QUALIFIED_METHOD2(SetBatch, noexcept, StatusCode(int batch, ResponseDesc*)); + MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *)); }; diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp index 8064ffb0164..91d43e21eda 100644 --- a/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp +++ b/inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp @@ -484,6 +484,33 @@ static std::shared_ptr makeConvBias(std::vector inputS fn_ptr->set_friendly_name("ConvBias"); return fn_ptr; } + +static std::shared_ptr makeReadConcatSplitAssign(std::vector inputShape = {1, 1, 2, 4}, + InferenceEngine::Precision prc = InferenceEngine::Precision::FP32) { + ngraph::element::Type type = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(prc); + auto parameter = ngraph::builder::makeParams(type, {inputShape}); + parameter[0]->set_friendly_name("parameter"); + auto init_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {0, 0, 0, 0}); + auto read = std::make_shared(init_const, "v0"); + read->set_friendly_name("read"); + std::vector> args = {parameter[0], read}; + auto conc = std::make_shared(args, 3); + conc->set_friendly_name("concat"); + auto res = std::make_shared(conc); + res->set_friendly_name("result"); + const auto axis = ngraph::op::Constant::create(element::i64, Shape{}, {3}); + axis->set_friendly_name("axis"); + auto crop = std::make_shared(conc, axis, 3); + crop->set_friendly_name("crop"); + auto assign = std::make_shared(crop, "v0"); + assign->set_friendly_name("assign"); + + std::shared_ptr fn_ptr = std::make_shared(ngraph::ResultVector({res}), + ngraph::SinkVector({assign}), + ngraph::ParameterVector{parameter}); + fn_ptr->set_friendly_name("ReadConcatSplitAssign"); + return fn_ptr; +} } // namespace subgraph } // namespace builder } // namespace ngraph diff --git a/inference-engine/tests/unit/inference_engine/cpp_interfaces/ie_memory_state_internal_test.cpp b/inference-engine/tests/unit/inference_engine/cpp_interfaces/ie_memory_state_internal_test.cpp index 64499f2f337..ec2cd9c0133 100644 --- a/inference-engine/tests/unit/inference_engine/cpp_interfaces/ie_memory_state_internal_test.cpp +++ b/inference-engine/tests/unit/inference_engine/cpp_interfaces/ie_memory_state_internal_test.cpp @@ -7,152 +7,183 @@ #include #include -#include +#include #include "unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp" #include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp" +#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp" using namespace ::testing; using namespace std; using namespace InferenceEngine; using namespace InferenceEngine::details; -class MemoryStateTests : public ::testing::Test { +template +inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr impl) { + typename InferRequestBase::Ptr req(new InferRequestBase(impl), [](IInferRequest* p) { + p->Release(); + }); + return InferenceEngine::InferRequest(req); +} + + +class VariableStateTests : public ::testing::Test { protected: shared_ptr mockExeNetworkInternal; - shared_ptr mockMemoryStateInternal; + shared_ptr mockInferRequestInternal; + shared_ptr mockVariableStateInternal; virtual void SetUp() { mockExeNetworkInternal = make_shared(); - mockMemoryStateInternal = make_shared(); + mockInferRequestInternal = make_shared(); + mockVariableStateInternal = make_shared(); } }; -TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) { +TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn(1); - toReturn[0] = mockMemoryStateInternal; + std::vector toReturn(1); + toReturn[0] = mockVariableStateInternal; EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); auto state = net.QueryState(); ASSERT_EQ(state.size(), 1); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) { +TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; + std::vector toReturn; EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn)); auto state = net.QueryState(); ASSERT_EQ(state.size(), 0); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) { +TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; - toReturn.push_back(mockMemoryStateInternal); - toReturn.push_back(mockMemoryStateInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn)); auto state = net.QueryState(); ASSERT_EQ(state.size(), 2); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStatePropagatesReset) { +TEST_F(VariableStateTests, VariableStatePropagatesReset) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; - toReturn.push_back(mockMemoryStateInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1); + EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1); auto state = net.QueryState(); state.front().Reset(); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) { +TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; - toReturn.push_back(mockMemoryStateInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error"))); + EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error"))); auto state = net.QueryState(); EXPECT_ANY_THROW(state.front().Reset()); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) { +TEST_F(VariableStateTests, VariableStatePropagatesGetName) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; - toReturn.push_back(mockMemoryStateInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName")); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); auto state = net.QueryState(); EXPECT_STREQ(state.front().GetName().c_str(), "someName"); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) { +TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; - toReturn.push_back(mockMemoryStateInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName")); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); - IMemoryState::Ptr pState; + IVariableState::Ptr pState; static_cast(net)->QueryState(pState, 0, nullptr); char *name = reinterpret_cast(1); EXPECT_NO_THROW(pState->GetName(name, 0, nullptr)); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) { +TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; - toReturn.push_back(mockMemoryStateInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName")); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); - IMemoryState::Ptr pState; + IVariableState::Ptr pState; static_cast(net)->QueryState(pState, 0, nullptr); char name[1]; EXPECT_NO_THROW(pState->GetName(name, 1, nullptr)); EXPECT_STREQ(name, ""); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) { +TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; - toReturn.push_back(mockMemoryStateInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName")); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); - IMemoryState::Ptr pState; + IVariableState::Ptr pState; static_cast(net)->QueryState(pState, 0, nullptr); char name[2]; EXPECT_NO_THROW(pState->GetName(name, 2, nullptr)); EXPECT_STREQ(name, "s"); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) { +TEST_F(VariableStateTests, VariableStateCanPropagateSetState) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; + std::vector toReturn; Blob::Ptr saver; - toReturn.push_back(mockMemoryStateInternal); + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver)); + EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver)); float data[] = {123, 124, 125}; auto stateBlob = make_shared_blob({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data)); @@ -161,47 +192,50 @@ TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) { ASSERT_FLOAT_EQ(saver->buffer().as()[0], 123); ASSERT_FLOAT_EQ(saver->buffer().as()[1], 124); ASSERT_FLOAT_EQ(saver->buffer().as()[2], 125); + IE_SUPPRESS_DEPRECATED_END } -TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) { +TEST_F(VariableStateTests, VariableStateCanPropagateGetLastState) { + IE_SUPPRESS_DEPRECATED_START auto net = make_executable_network(mockExeNetworkInternal); - std::vector toReturn; + std::vector toReturn; float data[] = {123, 124, 125}; auto stateBlob = make_shared_blob({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data)); - toReturn.push_back(mockMemoryStateInternal); + toReturn.push_back(mockVariableStateInternal); EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn)); - EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob)); + EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob)); - auto saver = net.QueryState().front().GetLastState(); + auto saver = net.QueryState().front().GetState(); ASSERT_FLOAT_EQ(saver->cbuffer().as()[0], 123); ASSERT_FLOAT_EQ(saver->cbuffer().as()[1], 124); ASSERT_FLOAT_EQ(saver->cbuffer().as()[2], 125); + IE_SUPPRESS_DEPRECATED_END } -class MemoryStateInternalMockImpl : public MemoryStateInternal { +class VariableStateInternalMockImpl : public VariableStateInternal { public: - using MemoryStateInternal::MemoryStateInternal; + using VariableStateInternal::VariableStateInternal; MOCK_METHOD0(Reset, void()); }; -TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) { - IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name")); +TEST_F(VariableStateTests, VariableStateInternalCanSaveName) { + IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name")); ASSERT_STREQ(pState->GetName().c_str(), "name"); } -TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) { - IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name")); +TEST_F(VariableStateTests, VariableStateInternalCanSaveState) { + IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name")); float data[] = {123, 124, 125}; auto stateBlob = make_shared_blob({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data)); pState->SetState(stateBlob); - auto saver = pState->GetLastState(); + auto saver = pState->GetState(); ASSERT_FLOAT_EQ(saver->cbuffer().as()[0], 123); ASSERT_FLOAT_EQ(saver->cbuffer().as()[1], 124); @@ -209,8 +243,8 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) { } -TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) { - IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name")); +TEST_F(VariableStateTests, VariableStateInternalCanSaveStateByReference) { + IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name")); float data[] = {123, 124, 125}; auto stateBlob = make_shared_blob({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data)); @@ -219,9 +253,162 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) { data[0] = 121; data[1] = 122; data[2] = 123; - auto saver = pState->GetLastState(); + auto saver = pState->GetState(); ASSERT_FLOAT_EQ(saver->cbuffer().as()[0], 121); ASSERT_FLOAT_EQ(saver->cbuffer().as()[1], 122); ASSERT_FLOAT_EQ(saver->cbuffer().as()[2], 123); } + +// Tests for InferRequest::QueryState +TEST_F(VariableStateTests, InferRequestCanConvertOneVariableStateFromCppToAPI) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn(1); + toReturn[0] = mockVariableStateInternal; + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); + + auto state = req.QueryState(); + ASSERT_EQ(state.size(), 1); +} + +TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillOnce(Return(toReturn)); + + auto state = req.QueryState(); + ASSERT_EQ(state.size(), 0); +} + +TEST_F(VariableStateTests, InferRequestCanConvert2VariableStatesFromCPPtoAPI) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn)); + + auto state = req.QueryState(); + ASSERT_EQ(state.size(), 2); +} + +TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1); + + auto state = req.QueryState(); + state.front().Reset(); +} + +TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error"))); + + auto state = req.QueryState(); + EXPECT_ANY_THROW(state.front().Reset()); +} + +TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetName) { +auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); + + auto state = req.QueryState(); + EXPECT_STREQ(state.front().GetName().c_str(), "someName"); +} + +TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); + + IVariableState::Ptr pState; + + static_cast(req)->QueryState(pState, 0, nullptr); + char *name = reinterpret_cast(1); + EXPECT_NO_THROW(pState->GetName(name, 0, nullptr)); +} + +TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); + + IVariableState::Ptr pState; + + static_cast(req)->QueryState(pState, 0, nullptr); + char name[1]; + EXPECT_NO_THROW(pState->GetName(name, 1, nullptr)); + EXPECT_STREQ(name, ""); +} + +TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName")); + + IVariableState::Ptr pState; + + static_cast(req)->QueryState(pState, 0, nullptr); + char name[2]; + EXPECT_NO_THROW(pState->GetName(name, 2, nullptr)); + EXPECT_STREQ(name, "s"); +} + +TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + Blob::Ptr saver; + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver)); + + float data[] = {123, 124, 125}; + auto stateBlob = make_shared_blob({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data)); + + EXPECT_NO_THROW(req.QueryState().front().SetState(stateBlob)); + ASSERT_FLOAT_EQ(saver->buffer().as()[0], 123); + ASSERT_FLOAT_EQ(saver->buffer().as()[1], 124); + ASSERT_FLOAT_EQ(saver->buffer().as()[2], 125); +} + +TEST_F(VariableStateTests, InfReqVariableStateCanPropagateGetLastState) { + auto req = make_infer_request(mockInferRequestInternal); + std::vector toReturn; + + float data[] = {123, 124, 125}; + auto stateBlob = make_shared_blob({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data)); + + toReturn.push_back(mockVariableStateInternal); + + EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn)); + EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob)); + + auto saver = req.QueryState().front().GetState(); + ASSERT_FLOAT_EQ(saver->cbuffer().as()[0], 123); + ASSERT_FLOAT_EQ(saver->cbuffer().as()[1], 124); + ASSERT_FLOAT_EQ(saver->cbuffer().as()[2], 125); +} diff --git a/inference-engine/tests/unit/inference_engine/ie_executable_network_test.cpp b/inference-engine/tests/unit/inference_engine/ie_executable_network_test.cpp index 0632a65c2a2..810341ff9b5 100644 --- a/inference-engine/tests/unit/inference_engine/ie_executable_network_test.cpp +++ b/inference-engine/tests/unit/inference_engine/ie_executable_network_test.cpp @@ -127,6 +127,7 @@ TEST_F(ExecutableNetworkTests, OperatorAmpersand) { ASSERT_EQ(exeNet_p, mockIExeNet_p); } +IE_SUPPRESS_DEPRECATED_START TEST_F(ExecutableNetworkTests, QueryStateThrowsIfReturnErr) { EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _)) .Times(1) @@ -138,21 +139,22 @@ TEST_F(ExecutableNetworkTests, QueryStateIfReturnOutOfBounds) { EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _)) .Times(1) .WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS)); - std::vector MemState_; + std::vector MemState_; EXPECT_NO_THROW(MemState_ = exeNetwork->QueryState()); EXPECT_EQ(MemState_.size(), 0); } TEST_F(ExecutableNetworkTests, QueryState) { - std::shared_ptr mockIMemState_p = std::make_shared(); + std::shared_ptr mockIMemState_p = std::make_shared(); EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _)) .Times(2) .WillOnce(DoAll(SetArgReferee<0>(mockIMemState_p), Return(InferenceEngine::OK))) .WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS)); - std::vector MemState_v; + std::vector MemState_v; EXPECT_NO_THROW(MemState_v = exeNetwork->QueryState()); EXPECT_EQ(MemState_v.size(), 1); } +IE_SUPPRESS_DEPRECATED_END class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests { protected: diff --git a/inference-engine/tests_deprecated/functional/ie_tests/src/custom_matcher.cpp b/inference-engine/tests_deprecated/functional/ie_tests/src/custom_matcher.cpp index 43f7f9a969b..658c1618a93 100644 --- a/inference-engine/tests_deprecated/functional/ie_tests/src/custom_matcher.cpp +++ b/inference-engine/tests_deprecated/functional/ie_tests/src/custom_matcher.cpp @@ -207,6 +207,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() { } } + IE_SUPPRESS_DEPRECATED_START if (fetchResult.reset) { auto states = executableApi.QueryState(); ASSERT_FALSE(states.empty()); @@ -218,6 +219,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() { outputs["reset"] = nullptr; //continue; } + IE_SUPPRESS_DEPRECATED_END //FAIL()<<"stop after one frame"; diff --git a/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp b/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp index 0e8443df233..947e248de6e 100644 --- a/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp +++ b/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp @@ -808,6 +808,7 @@ void GNAQueryStateMatcher :: match() { EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess)); #endif + IE_SUPPRESS_DEPRECATED_START try { loadNetwork(); if (GnaPluginTestEnvironment::kAnyNotNull == _env.numberOfStates) { @@ -830,6 +831,7 @@ void GNAQueryStateMatcher :: match() { catch(...) { FAIL() << "unknown exception thrown"; } + IE_SUPPRESS_DEPRECATED_END }