Added ov::IVariableState (#15843)

* Added ov::IVariableState

* Added variable state

* Try to fix Windows

* Fixed export
This commit is contained in:
Ilya Churaev 2023-02-22 14:30:46 +04:00 committed by GitHub
parent c8643a9a30
commit 548f972e19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 167 additions and 22 deletions

View File

@ -20,6 +20,7 @@
namespace InferenceEngine { namespace InferenceEngine {
class Blob; class Blob;
class IAsyncInferRequestWrapper; class IAsyncInferRequestWrapper;
class IVariableStateWrapper;
} // namespace InferenceEngine } // namespace InferenceEngine
namespace ov { namespace ov {
@ -31,6 +32,7 @@ class RemoteContext;
class VariableState; class VariableState;
class ISyncInferRequest; class ISyncInferRequest;
class IInferRequestInternalWrapper; class IInferRequestInternalWrapper;
class IVariableStateInternalWrapper;
/** /**
* @brief Tensor API holding host memory * @brief Tensor API holding host memory
@ -57,7 +59,9 @@ protected:
friend class ov::VariableState; friend class ov::VariableState;
friend class ov::ISyncInferRequest; friend class ov::ISyncInferRequest;
friend class ov::IInferRequestInternalWrapper; friend class ov::IInferRequestInternalWrapper;
friend class ov::IVariableStateInternalWrapper;
friend class InferenceEngine::IAsyncInferRequestWrapper; friend class InferenceEngine::IAsyncInferRequestWrapper;
friend class InferenceEngine::IVariableStateWrapper;
public: public:
/// @brief Default constructor /// @brief Default constructor

View File

@ -129,7 +129,7 @@ public:
* State control essential for recurrent models. * State control essential for recurrent models.
* @return Vector of Variable State objects. * @return Vector of Variable State objects.
*/ */
std::vector<ov::VariableState> query_state() const override; std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;
/** /**
* @brief Gets pointer to compiled model (usually synchronous request holds the compiled model) * @brief Gets pointer to compiled model (usually synchronous request holds the compiled model)

View File

@ -15,6 +15,7 @@
#include <vector> #include <vector>
#include "openvino/runtime/common.hpp" #include "openvino/runtime/common.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/profiling_info.hpp" #include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/tensor.hpp" #include "openvino/runtime/tensor.hpp"
@ -87,7 +88,7 @@ public:
* State control essential for recurrent models. * State control essential for recurrent models.
* @return Vector of Variable State objects. * @return Vector of Variable State objects.
*/ */
virtual std::vector<ov::VariableState> query_state() const = 0; virtual std::vector<std::shared_ptr<ov::IVariableState>> query_state() const = 0;
/** /**
* @brief Gets pointer to compiled model (usually synchronous request holds the compiled model) * @brief Gets pointer to compiled model (usually synchronous request holds the compiled model)

View File

@ -0,0 +1,63 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
/**
* @brief OpenVINO Runtime IVariableState interface
* @file openvino/runtime/ivariable_state.hpp
*/
#pragma once
#include <memory>
#include <string>
#include "openvino/runtime/common.hpp"
#include "openvino/runtime/tensor.hpp"
namespace ov {
/**
* @interface IVariableState
* @brief Minimal interface for variable state implementation
* @ingroup ov_dev_api_variable_state_api
*/
class OPENVINO_RUNTIME_API IVariableState : public std::enable_shared_from_this<IVariableState> {
public:
explicit IVariableState(const std::string& name);
/**
* @brief Gets a variable state name
* @return A string representing variable state name
*/
virtual const std::string& get_name() const;
/**
* @brief Reset internal variable state for relevant infer request, to a value specified as
* default for according `ReadValue` node
*/
virtual void reset();
/**
* @brief Sets the new state for the next inference
* @param newState A new state
*/
virtual void set_state(const ov::Tensor& state);
/**
* @brief Returns the value of the variable state.
* @return The value of the variable state
*/
virtual const ov::Tensor& get_state() const;
protected:
/**
* @brief A default dtor
*/
virtual ~IVariableState();
std::string m_name;
ov::Tensor m_state;
};
} // namespace ov

View File

@ -16,13 +16,13 @@
#include "openvino/runtime/tensor.hpp" #include "openvino/runtime/tensor.hpp"
namespace InferenceEngine { namespace InferenceEngine {
class IVariableStateInternal;
class IAsyncInferRequestWrapper; class IAsyncInferRequestWrapper;
} // namespace InferenceEngine } // namespace InferenceEngine
namespace ov { namespace ov {
class InferRequest; class InferRequest;
class IVariableState;
class IInferRequestInternalWrapper; class IInferRequestInternalWrapper;
/** /**
@ -30,7 +30,7 @@ class IInferRequestInternalWrapper;
* @ingroup ov_runtime_cpp_api * @ingroup ov_runtime_cpp_api
*/ */
class OPENVINO_RUNTIME_API VariableState { class OPENVINO_RUNTIME_API VariableState {
std::shared_ptr<InferenceEngine::IVariableStateInternal> _impl; std::shared_ptr<ov::IVariableState> _impl;
std::vector<std::shared_ptr<void>> _so; std::vector<std::shared_ptr<void>> _so;
/** /**
@ -39,8 +39,7 @@ class OPENVINO_RUNTIME_API VariableState {
* @param so Optional: plugin to use. This is required to ensure that VariableState can work properly even if a * @param so Optional: plugin to use. This is required to ensure that VariableState can work properly even if a
* plugin object is destroyed. * plugin object is destroyed.
*/ */
VariableState(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& impl, VariableState(const std::shared_ptr<ov::IVariableState>& impl, const std::vector<std::shared_ptr<void>>& so);
const std::vector<std::shared_ptr<void>>& so);
friend class ov::InferRequest; friend class ov::InferRequest;
friend class ov::IInferRequestInternalWrapper; friend class ov::IInferRequestInternalWrapper;

View File

@ -5,6 +5,7 @@
#include "cpp/ie_memory_state.hpp" #include "cpp/ie_memory_state.hpp"
#include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp" #include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
#include "openvino/core/except.hpp" #include "openvino/core/except.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/variable_state.hpp" #include "openvino/runtime/variable_state.hpp"
#define VARIABLE_CALL_STATEMENT(...) \ #define VARIABLE_CALL_STATEMENT(...) \
@ -65,26 +66,27 @@ VariableState::~VariableState() {
_impl = {}; _impl = {};
} }
VariableState::VariableState(const ie::IVariableStateInternal::Ptr& impl, const std::vector<std::shared_ptr<void>>& so) VariableState::VariableState(const std::shared_ptr<ov::IVariableState>& impl,
const std::vector<std::shared_ptr<void>>& so)
: _impl{impl}, : _impl{impl},
_so{so} { _so{so} {
OPENVINO_ASSERT(_impl != nullptr, "VariableState was not initialized."); OPENVINO_ASSERT(_impl != nullptr, "VariableState was not initialized.");
} }
void VariableState::reset() { void VariableState::reset() {
OV_VARIABLE_CALL_STATEMENT(_impl->Reset()); OV_VARIABLE_CALL_STATEMENT(_impl->reset());
} }
std::string VariableState::get_name() const { std::string VariableState::get_name() const {
OV_VARIABLE_CALL_STATEMENT(return _impl->GetName()); OV_VARIABLE_CALL_STATEMENT(return _impl->get_name());
} }
Tensor VariableState::get_state() const { Tensor VariableState::get_state() const {
OV_VARIABLE_CALL_STATEMENT(return {std::const_pointer_cast<ie::Blob>(_impl->GetState()), {_so}}); OV_VARIABLE_CALL_STATEMENT(return _impl->get_state());
} }
void VariableState::set_state(const Tensor& state) { void VariableState::set_state(const Tensor& state) {
OV_VARIABLE_CALL_STATEMENT(_impl->SetState(state._impl)); OV_VARIABLE_CALL_STATEMENT(_impl->set_state(state));
} }
} // namespace ov } // namespace ov

View File

@ -12,6 +12,7 @@
#include "cnn_network_ngraph_impl.hpp" #include "cnn_network_ngraph_impl.hpp"
#include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp" #include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp" #include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
#include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
#include "icompiled_model_wrapper.hpp" #include "icompiled_model_wrapper.hpp"
#include "ie_blob.h" #include "ie_blob.h"
#include "ie_common.h" #include "ie_common.h"
@ -29,6 +30,7 @@
#include "openvino/runtime/icompiled_model.hpp" #include "openvino/runtime/icompiled_model.hpp"
#include "openvino/runtime/iinfer_request.hpp" #include "openvino/runtime/iinfer_request.hpp"
#include "openvino/runtime/iplugin.hpp" #include "openvino/runtime/iplugin.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/profiling_info.hpp" #include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/remote_context.hpp" #include "openvino/runtime/remote_context.hpp"
#include "openvino/runtime/tensor.hpp" #include "openvino/runtime/tensor.hpp"
@ -185,6 +187,31 @@ std::shared_ptr<const ov::Model> ov::legacy_convert::convert_model(const Inferen
namespace ov { namespace ov {
class IVariableStateInternalWrapper : public InferenceEngine::IVariableStateInternal {
std::shared_ptr<ov::IVariableState> m_state;
public:
IVariableStateInternalWrapper(const std::shared_ptr<ov::IVariableState>& state)
: InferenceEngine::IVariableStateInternal(state->get_name()),
m_state(state) {}
std::string GetName() const override {
return m_state->get_name();
}
void Reset() override {
m_state->reset();
}
void SetState(const InferenceEngine::Blob::Ptr& newState) override {
m_state->set_state(ov::Tensor(newState, {}));
}
InferenceEngine::Blob::CPtr GetState() const override {
return m_state->get_state()._impl;
}
};
class IInferencePluginWrapper : public InferenceEngine::IInferencePlugin { class IInferencePluginWrapper : public InferenceEngine::IInferencePlugin {
public: public:
IInferencePluginWrapper(const std::shared_ptr<ov::IPlugin>& plugin) : m_plugin(plugin) { IInferencePluginWrapper(const std::shared_ptr<ov::IPlugin>& plugin) : m_plugin(plugin) {
@ -521,7 +548,7 @@ public:
auto res = m_request->query_state(); auto res = m_request->query_state();
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> ret; std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> ret;
for (const auto& state : res) { for (const auto& state : res) {
ret.emplace_back(state._impl); ret.emplace_back(std::make_shared<ov::IVariableStateInternalWrapper>(state));
} }
return ret; return ret;
} }
@ -558,6 +585,30 @@ private:
namespace InferenceEngine { namespace InferenceEngine {
class IVariableStateWrapper : public ov::IVariableState {
private:
std::shared_ptr<InferenceEngine::IVariableStateInternal> m_state;
mutable ov::Tensor m_converted_state;
public:
explicit IVariableStateWrapper(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& state)
: ov::IVariableState(state->GetName()),
m_state(state) {}
void reset() override {
m_state->Reset();
}
void set_state(const ov::Tensor& state) override {
m_state->SetState(state._impl);
}
const ov::Tensor& get_state() const override {
m_converted_state = ov::Tensor(std::const_pointer_cast<InferenceEngine::Blob>(m_state->GetState()), {});
return m_converted_state;
}
};
class IAsyncInferRequestWrapper : public ov::IAsyncInferRequest { class IAsyncInferRequestWrapper : public ov::IAsyncInferRequest {
public: public:
IAsyncInferRequestWrapper(const std::shared_ptr<InferenceEngine::IInferRequestInternal>& request) IAsyncInferRequestWrapper(const std::shared_ptr<InferenceEngine::IInferRequestInternal>& request)
@ -676,12 +727,11 @@ public:
m_request->SetBlobs(get_legacy_name_from_port(port), blobs); m_request->SetBlobs(get_legacy_name_from_port(port), blobs);
} }
std::vector<ov::VariableState> query_state() const override { std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override {
std::vector<ov::VariableState> variable_states; std::vector<std::shared_ptr<ov::IVariableState>> variable_states;
std::vector<std::shared_ptr<void>> soVec; std::vector<std::shared_ptr<void>> soVec;
soVec = {m_request->getPointerToSo()};
for (auto&& state : m_request->QueryState()) { for (auto&& state : m_request->QueryState()) {
variable_states.emplace_back(ov::VariableState{state, soVec}); variable_states.emplace_back(std::make_shared<InferenceEngine::IVariableStateWrapper>(state));
} }
return variable_states; return variable_states;
} }

View File

@ -7,6 +7,7 @@
#include <memory> #include <memory>
#include "openvino/runtime/isync_infer_request.hpp" #include "openvino/runtime/isync_infer_request.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/variable_state.hpp" #include "openvino/runtime/variable_state.hpp"
#include "threading/ie_immediate_executor.hpp" #include "threading/ie_immediate_executor.hpp"
#include "threading/ie_istreams_executor.hpp" #include "threading/ie_istreams_executor.hpp"
@ -101,7 +102,7 @@ void ov::IAsyncInferRequest::set_callback(std::function<void(std::exception_ptr)
m_callback = std::move(callback); m_callback = std::move(callback);
} }
std::vector<ov::VariableState> ov::IAsyncInferRequest::query_state() const { std::vector<std::shared_ptr<ov::IVariableState>> ov::IAsyncInferRequest::query_state() const {
check_state(); check_state();
return m_sync_request->query_state(); return m_sync_request->query_state();
} }

View File

@ -0,0 +1,27 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/core/except.hpp"
ov::IVariableState::IVariableState(const std::string& name) : m_name(name) {}
ov::IVariableState::~IVariableState() = default;
const std::string& ov::IVariableState::get_name() const {
return m_name;
}
void ov::IVariableState::reset() {
OPENVINO_NOT_IMPLEMENTED;
}
void ov::IVariableState::set_state(const ov::Tensor& state) {
m_state = state;
}
const ov::Tensor& ov::IVariableState::get_state() const {
return m_state;
}

View File

@ -264,9 +264,7 @@ std::vector<VariableState> InferRequest::query_state() {
std::vector<VariableState> variable_states; std::vector<VariableState> variable_states;
OV_INFER_REQ_CALL_STATEMENT({ OV_INFER_REQ_CALL_STATEMENT({
for (auto&& state : _impl->query_state()) { for (auto&& state : _impl->query_state()) {
auto soVec = state._so; variable_states.emplace_back(ov::VariableState{state, {_so}});
soVec.emplace_back(_so);
variable_states.emplace_back(ov::VariableState{state._impl, soVec});
} }
}) })
return variable_states; return variable_states;

View File

@ -75,7 +75,7 @@ TemplatePlugin::InferRequest::InferRequest(const std::shared_ptr<const TemplateP
} }
// ! [infer_request:ctor] // ! [infer_request:ctor]
std::vector<ov::VariableState> TemplatePlugin::InferRequest::query_state() const { std::vector<std::shared_ptr<ov::IVariableState>> TemplatePlugin::InferRequest::query_state() const {
OPENVINO_NOT_IMPLEMENTED; OPENVINO_NOT_IMPLEMENTED;
} }

View File

@ -28,7 +28,7 @@ public:
~InferRequest(); ~InferRequest();
void infer() override; void infer() override;
std::vector<ov::VariableState> query_state() const override; std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;
std::vector<ov::ProfilingInfo> get_profiling_info() const override; std::vector<ov::ProfilingInfo> get_profiling_info() const override;
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor // pipeline methods-stages which are used in async infer request implementation and assigned to particular executor