Added ov::IVariableState (#15843)
* Added ov::IVariableState * Added variable state * Try to fix Windows * Fixed export
This commit is contained in:
parent
c8643a9a30
commit
548f972e19
@ -20,6 +20,7 @@
|
||||
namespace InferenceEngine {
|
||||
class Blob;
|
||||
class IAsyncInferRequestWrapper;
|
||||
class IVariableStateWrapper;
|
||||
} // namespace InferenceEngine
|
||||
|
||||
namespace ov {
|
||||
@ -31,6 +32,7 @@ class RemoteContext;
|
||||
class VariableState;
|
||||
class ISyncInferRequest;
|
||||
class IInferRequestInternalWrapper;
|
||||
class IVariableStateInternalWrapper;
|
||||
|
||||
/**
|
||||
* @brief Tensor API holding host memory
|
||||
@ -57,7 +59,9 @@ protected:
|
||||
friend class ov::VariableState;
|
||||
friend class ov::ISyncInferRequest;
|
||||
friend class ov::IInferRequestInternalWrapper;
|
||||
friend class ov::IVariableStateInternalWrapper;
|
||||
friend class InferenceEngine::IAsyncInferRequestWrapper;
|
||||
friend class InferenceEngine::IVariableStateWrapper;
|
||||
|
||||
public:
|
||||
/// @brief Default constructor
|
||||
|
@ -129,7 +129,7 @@ public:
|
||||
* State control essential for recurrent models.
|
||||
* @return Vector of Variable State objects.
|
||||
*/
|
||||
std::vector<ov::VariableState> query_state() const override;
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;
|
||||
|
||||
/**
|
||||
* @brief Gets pointer to compiled model (usually synchronous request holds the compiled model)
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "openvino/runtime/common.hpp"
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
#include "openvino/runtime/profiling_info.hpp"
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
|
||||
@ -87,7 +88,7 @@ public:
|
||||
* State control essential for recurrent models.
|
||||
* @return Vector of Variable State objects.
|
||||
*/
|
||||
virtual std::vector<ov::VariableState> query_state() const = 0;
|
||||
virtual std::vector<std::shared_ptr<ov::IVariableState>> query_state() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets pointer to compiled model (usually synchronous request holds the compiled model)
|
||||
|
63
src/inference/dev_api/openvino/runtime/ivariable_state.hpp
Normal file
63
src/inference/dev_api/openvino/runtime/ivariable_state.hpp
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief OpenVINO Runtime IVariableState interface
|
||||
* @file openvino/runtime/ivariable_state.hpp
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "openvino/runtime/common.hpp"
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
|
||||
namespace ov {
|
||||
|
||||
/**
|
||||
* @interface IVariableState
|
||||
* @brief Minimal interface for variable state implementation
|
||||
* @ingroup ov_dev_api_variable_state_api
|
||||
*/
|
||||
class OPENVINO_RUNTIME_API IVariableState : public std::enable_shared_from_this<IVariableState> {
|
||||
public:
|
||||
explicit IVariableState(const std::string& name);
|
||||
|
||||
/**
|
||||
* @brief Gets a variable state name
|
||||
* @return A string representing variable state name
|
||||
*/
|
||||
virtual const std::string& get_name() const;
|
||||
|
||||
/**
|
||||
* @brief Reset internal variable state for relevant infer request, to a value specified as
|
||||
* default for according `ReadValue` node
|
||||
*/
|
||||
virtual void reset();
|
||||
|
||||
/**
|
||||
* @brief Sets the new state for the next inference
|
||||
* @param newState A new state
|
||||
*/
|
||||
virtual void set_state(const ov::Tensor& state);
|
||||
|
||||
/**
|
||||
* @brief Returns the value of the variable state.
|
||||
* @return The value of the variable state
|
||||
*/
|
||||
virtual const ov::Tensor& get_state() const;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief A default dtor
|
||||
*/
|
||||
virtual ~IVariableState();
|
||||
|
||||
std::string m_name;
|
||||
ov::Tensor m_state;
|
||||
};
|
||||
|
||||
} // namespace ov
|
@ -16,13 +16,13 @@
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
class IVariableStateInternal;
|
||||
class IAsyncInferRequestWrapper;
|
||||
} // namespace InferenceEngine
|
||||
|
||||
namespace ov {
|
||||
|
||||
class InferRequest;
|
||||
class IVariableState;
|
||||
class IInferRequestInternalWrapper;
|
||||
|
||||
/**
|
||||
@ -30,7 +30,7 @@ class IInferRequestInternalWrapper;
|
||||
* @ingroup ov_runtime_cpp_api
|
||||
*/
|
||||
class OPENVINO_RUNTIME_API VariableState {
|
||||
std::shared_ptr<InferenceEngine::IVariableStateInternal> _impl;
|
||||
std::shared_ptr<ov::IVariableState> _impl;
|
||||
std::vector<std::shared_ptr<void>> _so;
|
||||
|
||||
/**
|
||||
@ -39,8 +39,7 @@ class OPENVINO_RUNTIME_API VariableState {
|
||||
* @param so Optional: plugin to use. This is required to ensure that VariableState can work properly even if a
|
||||
* plugin object is destroyed.
|
||||
*/
|
||||
VariableState(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& impl,
|
||||
const std::vector<std::shared_ptr<void>>& so);
|
||||
VariableState(const std::shared_ptr<ov::IVariableState>& impl, const std::vector<std::shared_ptr<void>>& so);
|
||||
|
||||
friend class ov::InferRequest;
|
||||
friend class ov::IInferRequestInternalWrapper;
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "cpp/ie_memory_state.hpp"
|
||||
#include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
|
||||
#include "openvino/core/except.hpp"
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
#include "openvino/runtime/variable_state.hpp"
|
||||
|
||||
#define VARIABLE_CALL_STATEMENT(...) \
|
||||
@ -65,26 +66,27 @@ VariableState::~VariableState() {
|
||||
_impl = {};
|
||||
}
|
||||
|
||||
VariableState::VariableState(const ie::IVariableStateInternal::Ptr& impl, const std::vector<std::shared_ptr<void>>& so)
|
||||
VariableState::VariableState(const std::shared_ptr<ov::IVariableState>& impl,
|
||||
const std::vector<std::shared_ptr<void>>& so)
|
||||
: _impl{impl},
|
||||
_so{so} {
|
||||
OPENVINO_ASSERT(_impl != nullptr, "VariableState was not initialized.");
|
||||
}
|
||||
|
||||
void VariableState::reset() {
|
||||
OV_VARIABLE_CALL_STATEMENT(_impl->Reset());
|
||||
OV_VARIABLE_CALL_STATEMENT(_impl->reset());
|
||||
}
|
||||
|
||||
std::string VariableState::get_name() const {
|
||||
OV_VARIABLE_CALL_STATEMENT(return _impl->GetName());
|
||||
OV_VARIABLE_CALL_STATEMENT(return _impl->get_name());
|
||||
}
|
||||
|
||||
Tensor VariableState::get_state() const {
|
||||
OV_VARIABLE_CALL_STATEMENT(return {std::const_pointer_cast<ie::Blob>(_impl->GetState()), {_so}});
|
||||
OV_VARIABLE_CALL_STATEMENT(return _impl->get_state());
|
||||
}
|
||||
|
||||
void VariableState::set_state(const Tensor& state) {
|
||||
OV_VARIABLE_CALL_STATEMENT(_impl->SetState(state._impl));
|
||||
OV_VARIABLE_CALL_STATEMENT(_impl->set_state(state));
|
||||
}
|
||||
|
||||
} // namespace ov
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
|
||||
#include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
|
||||
#include "icompiled_model_wrapper.hpp"
|
||||
#include "ie_blob.h"
|
||||
#include "ie_common.h"
|
||||
@ -29,6 +30,7 @@
|
||||
#include "openvino/runtime/icompiled_model.hpp"
|
||||
#include "openvino/runtime/iinfer_request.hpp"
|
||||
#include "openvino/runtime/iplugin.hpp"
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
#include "openvino/runtime/profiling_info.hpp"
|
||||
#include "openvino/runtime/remote_context.hpp"
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
@ -185,6 +187,31 @@ std::shared_ptr<const ov::Model> ov::legacy_convert::convert_model(const Inferen
|
||||
|
||||
namespace ov {
|
||||
|
||||
class IVariableStateInternalWrapper : public InferenceEngine::IVariableStateInternal {
|
||||
std::shared_ptr<ov::IVariableState> m_state;
|
||||
|
||||
public:
|
||||
IVariableStateInternalWrapper(const std::shared_ptr<ov::IVariableState>& state)
|
||||
: InferenceEngine::IVariableStateInternal(state->get_name()),
|
||||
m_state(state) {}
|
||||
|
||||
std::string GetName() const override {
|
||||
return m_state->get_name();
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
m_state->reset();
|
||||
}
|
||||
|
||||
void SetState(const InferenceEngine::Blob::Ptr& newState) override {
|
||||
m_state->set_state(ov::Tensor(newState, {}));
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::CPtr GetState() const override {
|
||||
return m_state->get_state()._impl;
|
||||
}
|
||||
};
|
||||
|
||||
class IInferencePluginWrapper : public InferenceEngine::IInferencePlugin {
|
||||
public:
|
||||
IInferencePluginWrapper(const std::shared_ptr<ov::IPlugin>& plugin) : m_plugin(plugin) {
|
||||
@ -521,7 +548,7 @@ public:
|
||||
auto res = m_request->query_state();
|
||||
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> ret;
|
||||
for (const auto& state : res) {
|
||||
ret.emplace_back(state._impl);
|
||||
ret.emplace_back(std::make_shared<ov::IVariableStateInternalWrapper>(state));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@ -558,6 +585,30 @@ private:
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
class IVariableStateWrapper : public ov::IVariableState {
|
||||
private:
|
||||
std::shared_ptr<InferenceEngine::IVariableStateInternal> m_state;
|
||||
mutable ov::Tensor m_converted_state;
|
||||
|
||||
public:
|
||||
explicit IVariableStateWrapper(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& state)
|
||||
: ov::IVariableState(state->GetName()),
|
||||
m_state(state) {}
|
||||
|
||||
void reset() override {
|
||||
m_state->Reset();
|
||||
}
|
||||
|
||||
void set_state(const ov::Tensor& state) override {
|
||||
m_state->SetState(state._impl);
|
||||
}
|
||||
|
||||
const ov::Tensor& get_state() const override {
|
||||
m_converted_state = ov::Tensor(std::const_pointer_cast<InferenceEngine::Blob>(m_state->GetState()), {});
|
||||
return m_converted_state;
|
||||
}
|
||||
};
|
||||
|
||||
class IAsyncInferRequestWrapper : public ov::IAsyncInferRequest {
|
||||
public:
|
||||
IAsyncInferRequestWrapper(const std::shared_ptr<InferenceEngine::IInferRequestInternal>& request)
|
||||
@ -676,12 +727,11 @@ public:
|
||||
m_request->SetBlobs(get_legacy_name_from_port(port), blobs);
|
||||
}
|
||||
|
||||
std::vector<ov::VariableState> query_state() const override {
|
||||
std::vector<ov::VariableState> variable_states;
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override {
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> variable_states;
|
||||
std::vector<std::shared_ptr<void>> soVec;
|
||||
soVec = {m_request->getPointerToSo()};
|
||||
for (auto&& state : m_request->QueryState()) {
|
||||
variable_states.emplace_back(ov::VariableState{state, soVec});
|
||||
variable_states.emplace_back(std::make_shared<InferenceEngine::IVariableStateWrapper>(state));
|
||||
}
|
||||
return variable_states;
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "openvino/runtime/isync_infer_request.hpp"
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
#include "openvino/runtime/variable_state.hpp"
|
||||
#include "threading/ie_immediate_executor.hpp"
|
||||
#include "threading/ie_istreams_executor.hpp"
|
||||
@ -101,7 +102,7 @@ void ov::IAsyncInferRequest::set_callback(std::function<void(std::exception_ptr)
|
||||
m_callback = std::move(callback);
|
||||
}
|
||||
|
||||
std::vector<ov::VariableState> ov::IAsyncInferRequest::query_state() const {
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> ov::IAsyncInferRequest::query_state() const {
|
||||
check_state();
|
||||
return m_sync_request->query_state();
|
||||
}
|
||||
|
27
src/inference/src/dev/ivariable_state.cpp
Normal file
27
src/inference/src/dev/ivariable_state.cpp
Normal file
@ -0,0 +1,27 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
|
||||
#include "openvino/core/except.hpp"
|
||||
|
||||
ov::IVariableState::IVariableState(const std::string& name) : m_name(name) {}
|
||||
|
||||
ov::IVariableState::~IVariableState() = default;
|
||||
|
||||
const std::string& ov::IVariableState::get_name() const {
|
||||
return m_name;
|
||||
}
|
||||
|
||||
void ov::IVariableState::reset() {
|
||||
OPENVINO_NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
void ov::IVariableState::set_state(const ov::Tensor& state) {
|
||||
m_state = state;
|
||||
}
|
||||
|
||||
const ov::Tensor& ov::IVariableState::get_state() const {
|
||||
return m_state;
|
||||
}
|
@ -264,9 +264,7 @@ std::vector<VariableState> InferRequest::query_state() {
|
||||
std::vector<VariableState> variable_states;
|
||||
OV_INFER_REQ_CALL_STATEMENT({
|
||||
for (auto&& state : _impl->query_state()) {
|
||||
auto soVec = state._so;
|
||||
soVec.emplace_back(_so);
|
||||
variable_states.emplace_back(ov::VariableState{state._impl, soVec});
|
||||
variable_states.emplace_back(ov::VariableState{state, {_so}});
|
||||
}
|
||||
})
|
||||
return variable_states;
|
||||
|
@ -75,7 +75,7 @@ TemplatePlugin::InferRequest::InferRequest(const std::shared_ptr<const TemplateP
|
||||
}
|
||||
// ! [infer_request:ctor]
|
||||
|
||||
std::vector<ov::VariableState> TemplatePlugin::InferRequest::query_state() const {
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> TemplatePlugin::InferRequest::query_state() const {
|
||||
OPENVINO_NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ public:
|
||||
~InferRequest();
|
||||
|
||||
void infer() override;
|
||||
std::vector<ov::VariableState> query_state() const override;
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;
|
||||
std::vector<ov::ProfilingInfo> get_profiling_info() const override;
|
||||
|
||||
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
|
||||
|
Loading…
Reference in New Issue
Block a user