diff --git a/src/core/include/openvino/runtime/tensor.hpp b/src/core/include/openvino/runtime/tensor.hpp
index ed027da1a06..dfbf71e22db 100644
--- a/src/core/include/openvino/runtime/tensor.hpp
+++ b/src/core/include/openvino/runtime/tensor.hpp
@@ -20,6 +20,7 @@ namespace InferenceEngine {
 class Blob;
 class IAsyncInferRequestWrapper;
+class IVariableStateWrapper;
 }  // namespace InferenceEngine

 namespace ov {

@@ -31,6 +32,7 @@ class RemoteContext;
 class VariableState;
 class ISyncInferRequest;
 class IInferRequestInternalWrapper;
+class IVariableStateInternalWrapper;

 /**
  * @brief Tensor API holding host memory
@@ -57,7 +59,9 @@ protected:
     friend class ov::VariableState;
     friend class ov::ISyncInferRequest;
     friend class ov::IInferRequestInternalWrapper;
+    friend class ov::IVariableStateInternalWrapper;
     friend class InferenceEngine::IAsyncInferRequestWrapper;
+    friend class InferenceEngine::IVariableStateWrapper;

 public:
     /// @brief Default constructor
diff --git a/src/inference/dev_api/openvino/runtime/iasync_infer_request.hpp b/src/inference/dev_api/openvino/runtime/iasync_infer_request.hpp
index ab1ff490d4e..687b05030cd 100644
--- a/src/inference/dev_api/openvino/runtime/iasync_infer_request.hpp
+++ b/src/inference/dev_api/openvino/runtime/iasync_infer_request.hpp
@@ -129,7 +129,7 @@ public:
      * State control essential for recurrent models.
      * @return Vector of Variable State objects.
      */
-    std::vector<ov::VariableState> query_state() const override;
+    std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;

     /**
      * @brief Gets pointer to compiled model (usually synchronous request holds the compiled model)
diff --git a/src/inference/dev_api/openvino/runtime/iinfer_request.hpp b/src/inference/dev_api/openvino/runtime/iinfer_request.hpp
index a5475ab1641..e183087ddfe 100644
--- a/src/inference/dev_api/openvino/runtime/iinfer_request.hpp
+++ b/src/inference/dev_api/openvino/runtime/iinfer_request.hpp
@@ -15,6 +15,7 @@
 #include

 #include "openvino/runtime/common.hpp"
+#include "openvino/runtime/ivariable_state.hpp"
 #include "openvino/runtime/profiling_info.hpp"
 #include "openvino/runtime/tensor.hpp"

@@ -87,7 +88,7 @@ public:
      * State control essential for recurrent models.
      * @return Vector of Variable State objects.
      */
-    virtual std::vector<ov::VariableState> query_state() const = 0;
+    virtual std::vector<std::shared_ptr<ov::IVariableState>> query_state() const = 0;

     /**
      * @brief Gets pointer to compiled model (usually synchronous request holds the compiled model)
diff --git a/src/inference/dev_api/openvino/runtime/ivariable_state.hpp b/src/inference/dev_api/openvino/runtime/ivariable_state.hpp
new file mode 100644
index 00000000000..b2dea052060
--- /dev/null
+++ b/src/inference/dev_api/openvino/runtime/ivariable_state.hpp
@@ -0,0 +1,63 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+/**
+ * @brief OpenVINO Runtime IVariableState interface
+ * @file openvino/runtime/ivariable_state.hpp
+ */
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "openvino/runtime/common.hpp"
+#include "openvino/runtime/tensor.hpp"
+
+namespace ov {
+
+/**
+ * @interface IVariableState
+ * @brief Minimal interface for variable state implementation
+ * @ingroup ov_dev_api_variable_state_api
+ */
+class OPENVINO_RUNTIME_API IVariableState : public std::enable_shared_from_this<IVariableState> {
+public:
+    explicit IVariableState(const std::string& name);
+
+    /**
+     * @brief Gets a variable state name
+     * @return A string representing variable state name
+     */
+    virtual const std::string& get_name() const;
+
+    /**
+     * @brief Resets the internal variable state of the relevant infer request to the value specified as
+     * default for the corresponding `ReadValue` node
+     */
+    virtual void reset();
+
+    /**
+     * @brief Sets the new state for the next inference
+     * @param state The new state
+     */
+    virtual void set_state(const ov::Tensor& state);
+
+    /**
+     * @brief Returns the value of the variable state.
+     * @return The value of the variable state
+     */
+    virtual const ov::Tensor& get_state() const;
+
+protected:
+    /**
+     * @brief A default dtor
+     */
+    virtual ~IVariableState();
+
+    std::string m_name;
+    ov::Tensor m_state;
+};
+
+}  // namespace ov
diff --git a/src/inference/include/openvino/runtime/variable_state.hpp b/src/inference/include/openvino/runtime/variable_state.hpp
index 533393da581..29c1baa6f39 100644
--- a/src/inference/include/openvino/runtime/variable_state.hpp
+++ b/src/inference/include/openvino/runtime/variable_state.hpp
@@ -16,13 +16,13 @@
 #include "openvino/runtime/tensor.hpp"

 namespace InferenceEngine {
-class IVariableStateInternal;
 class IAsyncInferRequestWrapper;
 }  // namespace InferenceEngine

 namespace ov {

 class InferRequest;
+class IVariableState;
 class IInferRequestInternalWrapper;

 /**
@@ -30,7 +30,7 @@ class IInferRequestInternalWrapper;
  * @ingroup ov_runtime_cpp_api
  */
 class OPENVINO_RUNTIME_API VariableState {
-    std::shared_ptr<InferenceEngine::IVariableStateInternal> _impl;
+    std::shared_ptr<ov::IVariableState> _impl;
     std::vector<std::shared_ptr<void>> _so;

     /**
@@ -39,8 +39,7 @@ class OPENVINO_RUNTIME_API VariableState {
      * @param so Optional: plugin to use. This is required to ensure that VariableState can work properly even if a
      * plugin object is destroyed.
      */
-    VariableState(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& impl,
-                  const std::vector<std::shared_ptr<void>>& so);
+    VariableState(const std::shared_ptr<ov::IVariableState>& impl, const std::vector<std::shared_ptr<void>>& so);

     friend class ov::InferRequest;
     friend class ov::IInferRequestInternalWrapper;
diff --git a/src/inference/src/cpp/ie_variable_state.cpp b/src/inference/src/cpp/ie_variable_state.cpp
index 212fa91158a..88d2221edfa 100644
--- a/src/inference/src/cpp/ie_variable_state.cpp
+++ b/src/inference/src/cpp/ie_variable_state.cpp
@@ -5,6 +5,7 @@
 #include "cpp/ie_memory_state.hpp"
 #include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
 #include "openvino/core/except.hpp"
+#include "openvino/runtime/ivariable_state.hpp"
 #include "openvino/runtime/variable_state.hpp"

 #define VARIABLE_CALL_STATEMENT(...) \
@@ -65,26 +66,27 @@ VariableState::~VariableState() {
     _impl = {};
 }

-VariableState::VariableState(const ie::IVariableStateInternal::Ptr& impl, const std::vector<std::shared_ptr<void>>& so)
+VariableState::VariableState(const std::shared_ptr<ov::IVariableState>& impl,
+                             const std::vector<std::shared_ptr<void>>& so)
     : _impl{impl},
       _so{so} {
     OPENVINO_ASSERT(_impl != nullptr, "VariableState was not initialized.");
 }

 void VariableState::reset() {
-    OV_VARIABLE_CALL_STATEMENT(_impl->Reset());
+    OV_VARIABLE_CALL_STATEMENT(_impl->reset());
 }

 std::string VariableState::get_name() const {
-    OV_VARIABLE_CALL_STATEMENT(return _impl->GetName());
+    OV_VARIABLE_CALL_STATEMENT(return _impl->get_name());
 }

 Tensor VariableState::get_state() const {
-    OV_VARIABLE_CALL_STATEMENT(return {std::const_pointer_cast<ie::Blob>(_impl->GetState()), {_so}});
+    OV_VARIABLE_CALL_STATEMENT(return _impl->get_state());
 }

 void VariableState::set_state(const Tensor& state) {
-    OV_VARIABLE_CALL_STATEMENT(_impl->SetState(state._impl));
+    OV_VARIABLE_CALL_STATEMENT(_impl->set_state(state));
 }

 }  // namespace ov
diff --git a/src/inference/src/dev/converter_utils.cpp b/src/inference/src/dev/converter_utils.cpp
index cdb47de5e8c..88bded83881 100644
--- a/src/inference/src/dev/converter_utils.cpp
+++ b/src/inference/src/dev/converter_utils.cpp
@@ -12,6 +12,7 @@
 #include "cnn_network_ngraph_impl.hpp"
 #include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
 #include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
+#include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
 #include "icompiled_model_wrapper.hpp"
 #include "ie_blob.h"
 #include "ie_common.h"
@@ -29,6 +30,7 @@
 #include "openvino/runtime/icompiled_model.hpp"
 #include "openvino/runtime/iinfer_request.hpp"
 #include "openvino/runtime/iplugin.hpp"
+#include "openvino/runtime/ivariable_state.hpp"
 #include "openvino/runtime/profiling_info.hpp"
 #include "openvino/runtime/remote_context.hpp"
 #include "openvino/runtime/tensor.hpp"
@@ -185,6 +187,31 @@ std::shared_ptr ov::legacy_convert::convert_model(const Inferen

 namespace ov {

+class IVariableStateInternalWrapper : public InferenceEngine::IVariableStateInternal {
+    std::shared_ptr<ov::IVariableState> m_state;
+
+public:
+    IVariableStateInternalWrapper(const std::shared_ptr<ov::IVariableState>& state)
+        : InferenceEngine::IVariableStateInternal(state->get_name()),
+          m_state(state) {}
+
+    std::string GetName() const override {
+        return m_state->get_name();
+    }
+
+    void Reset() override {
+        m_state->reset();
+    }
+
+    void SetState(const InferenceEngine::Blob::Ptr& newState) override {
+        m_state->set_state(ov::Tensor(newState, {}));
+    }
+
+    InferenceEngine::Blob::CPtr GetState() const override {
+        return m_state->get_state()._impl;
+    }
+};
+
 class IInferencePluginWrapper : public InferenceEngine::IInferencePlugin {
 public:
     IInferencePluginWrapper(const std::shared_ptr<ov::IPlugin>& plugin) : m_plugin(plugin) {
@@ -521,7 +548,7 @@ public:
         auto res = m_request->query_state();
         std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> ret;
         for (const auto& state : res) {
-            ret.emplace_back(state._impl);
+            ret.emplace_back(std::make_shared<IVariableStateInternalWrapper>(state));
         }
         return ret;
     }
@@ -558,6 +585,30 @@ private:

 namespace InferenceEngine {

+class IVariableStateWrapper : public ov::IVariableState {
+private:
+    std::shared_ptr<InferenceEngine::IVariableStateInternal> m_state;
+    mutable ov::Tensor m_converted_state;
+
+public:
+    explicit IVariableStateWrapper(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& state)
+        : ov::IVariableState(state->GetName()),
+          m_state(state) {}
+
+    void reset() override {
+        m_state->Reset();
+    }
+
+    void set_state(const ov::Tensor& state) override {
+        m_state->SetState(state._impl);
+    }
+
+    const ov::Tensor& get_state() const override {
+        m_converted_state = ov::Tensor(std::const_pointer_cast<InferenceEngine::Blob>(m_state->GetState()), {});
+        return m_converted_state;
+    }
+};
+
 class IAsyncInferRequestWrapper : public ov::IAsyncInferRequest {
 public:
     IAsyncInferRequestWrapper(const std::shared_ptr<InferenceEngine::IInferRequestInternal>& request)
@@ -676,12 +727,11 @@ public:
         m_request->SetBlobs(get_legacy_name_from_port(port), blobs);
     }

-    std::vector<ov::VariableState> query_state() const override {
-        std::vector<ov::VariableState> variable_states;
+    std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override {
+        std::vector<std::shared_ptr<ov::IVariableState>> variable_states;
         std::vector<std::shared_ptr<void>> soVec;
-        soVec = {m_request->getPointerToSo()};
         for (auto&& state : m_request->QueryState()) {
-            variable_states.emplace_back(ov::VariableState{state, soVec});
+            variable_states.emplace_back(std::make_shared<IVariableStateWrapper>(state));
         }
         return variable_states;
     }
diff --git a/src/inference/src/dev/iasync_infer_request.cpp b/src/inference/src/dev/iasync_infer_request.cpp
index e1c5bd16852..385baba838c 100644
--- a/src/inference/src/dev/iasync_infer_request.cpp
+++ b/src/inference/src/dev/iasync_infer_request.cpp
@@ -7,6 +7,7 @@
 #include

 #include "openvino/runtime/isync_infer_request.hpp"
+#include "openvino/runtime/ivariable_state.hpp"
 #include "openvino/runtime/variable_state.hpp"
 #include "threading/ie_immediate_executor.hpp"
 #include "threading/ie_istreams_executor.hpp"
@@ -101,7 +102,7 @@ void ov::IAsyncInferRequest::set_callback(std::function
-std::vector<ov::VariableState> ov::IAsyncInferRequest::query_state() const {
+std::vector<std::shared_ptr<ov::IVariableState>> ov::IAsyncInferRequest::query_state() const {
     check_state();
     return m_sync_request->query_state();
 }
diff --git a/src/inference/src/dev/ivariable_state.cpp b/src/inference/src/dev/ivariable_state.cpp
new file mode 100644
index 00000000000..d65409024a9
--- /dev/null
+++ b/src/inference/src/dev/ivariable_state.cpp
@@ -0,0 +1,27 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/runtime/ivariable_state.hpp"
+
+#include "openvino/core/except.hpp"
+
+ov::IVariableState::IVariableState(const std::string& name) : m_name(name) {}
+
+ov::IVariableState::~IVariableState() = default;
+
+const std::string& ov::IVariableState::get_name() const {
+    return m_name;
+}
+
+void ov::IVariableState::reset() {
+    OPENVINO_NOT_IMPLEMENTED;
+}
+
+void ov::IVariableState::set_state(const ov::Tensor& state) {
+    m_state = state;
+}
+
+const ov::Tensor& ov::IVariableState::get_state() const {
+    return m_state;
+}
diff --git a/src/inference/src/infer_request.cpp b/src/inference/src/infer_request.cpp
index 3bd40ce89be..ca407c6545b 100644
--- a/src/inference/src/infer_request.cpp
+++ b/src/inference/src/infer_request.cpp
@@ -264,9 +264,7 @@ std::vector<VariableState> InferRequest::query_state() {
     std::vector<VariableState> variable_states;
     OV_INFER_REQ_CALL_STATEMENT({
         for (auto&& state : _impl->query_state()) {
-            auto soVec = state._so;
-            soVec.emplace_back(_so);
-            variable_states.emplace_back(ov::VariableState{state._impl, soVec});
+            variable_states.emplace_back(ov::VariableState{state, {_so}});
         }
     })
     return variable_states;
diff --git a/src/plugins/template/src/infer_request.cpp b/src/plugins/template/src/infer_request.cpp
index 0b82095db30..8d2863c224c 100644
--- a/src/plugins/template/src/infer_request.cpp
+++ b/src/plugins/template/src/infer_request.cpp
@@ -75,7 +75,7 @@ TemplatePlugin::InferRequest::InferRequest(const std::shared_ptr
-std::vector<ov::VariableState> TemplatePlugin::InferRequest::query_state() const {
+std::vector<std::shared_ptr<ov::IVariableState>> TemplatePlugin::InferRequest::query_state() const {
     OPENVINO_NOT_IMPLEMENTED;
 }
diff --git a/src/plugins/template/src/infer_request.hpp b/src/plugins/template/src/infer_request.hpp
index a79398bdb90..35b9c6f75b5 100644
--- a/src/plugins/template/src/infer_request.hpp
+++ b/src/plugins/template/src/infer_request.hpp
@@ -28,7 +28,7 @@ public:
     ~InferRequest();

     void infer() override;
-    std::vector<ov::VariableState> query_state() const override;
+    std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;
     std::vector<ov::ProfilingInfo> get_profiling_info() const override;

     // pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
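
For illustration only, and not part of the patch: a plugin-side state could derive from the new ov::IVariableState dev interface added above. In the sketch below, the class name ExampleState and its zero-fill reset policy are assumptions; only the constructor, the virtual reset(), and the protected m_name/m_state members are taken from the interface in this diff.

    #include <cstring>
    #include <string>

    #include "openvino/runtime/ivariable_state.hpp"

    // Hypothetical plugin-side state; only the base-class surface comes from the patch.
    class ExampleState : public ov::IVariableState {
    public:
        ExampleState(const std::string& name, const ov::Tensor& initial) : ov::IVariableState(name) {
            m_state = initial;  // m_state and m_name are protected members of ov::IVariableState
        }

        // The base reset() throws OPENVINO_NOT_IMPLEMENTED, so a concrete state overrides it;
        // here the held tensor is simply zero-filled as a stand-in for the ReadValue default.
        void reset() override {
            std::memset(m_state.data(), 0, m_state.get_byte_size());
        }

        // set_state()/get_state() are inherited: the base implementation stores and returns m_state.
    };

A plugin's IInferRequest::query_state() would then return such objects as std::vector<std::shared_ptr<ov::IVariableState>>, and ov::InferRequest::query_state() wraps each of them back into an ov::VariableState together with the {_so} shared-object guard, as shown in the infer_request.cpp hunk above.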