Added ov::IVariableState (#15843)

* Added ov::IVariableState

* Added variable state

* Tried to fix the Windows build

* Fixed export
Ilya Churaev 2023-02-22 14:30:46 +04:00 committed by GitHub
parent c8643a9a30
commit 548f972e19
12 changed files with 167 additions and 22 deletions

View File

@@ -20,6 +20,7 @@
namespace InferenceEngine {
class Blob;
class IAsyncInferRequestWrapper;
class IVariableStateWrapper;
} // namespace InferenceEngine
namespace ov {
@@ -31,6 +32,7 @@ class RemoteContext;
class VariableState;
class ISyncInferRequest;
class IInferRequestInternalWrapper;
class IVariableStateInternalWrapper;
/**
* @brief Tensor API holding host memory
@@ -57,7 +59,9 @@ protected:
friend class ov::VariableState;
friend class ov::ISyncInferRequest;
friend class ov::IInferRequestInternalWrapper;
friend class ov::IVariableStateInternalWrapper;
friend class InferenceEngine::IAsyncInferRequestWrapper;
friend class InferenceEngine::IVariableStateWrapper;
public:
/// @brief Default constructor

View File

@@ -129,7 +129,7 @@ public:
* State control is essential for recurrent models.
* @return Vector of Variable State objects.
*/
std::vector<ov::VariableState> query_state() const override;
std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;
/**
* @brief Gets a pointer to the compiled model (usually the synchronous request holds the compiled model)

View File

@@ -15,6 +15,7 @@
#include <vector>
#include "openvino/runtime/common.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/tensor.hpp"
@@ -87,7 +88,7 @@ public:
* State control is essential for recurrent models.
* @return Vector of Variable State objects.
*/
virtual std::vector<ov::VariableState> query_state() const = 0;
virtual std::vector<std::shared_ptr<ov::IVariableState>> query_state() const = 0;
/**
* @brief Gets a pointer to the compiled model (usually the synchronous request holds the compiled model)

View File

@@ -0,0 +1,63 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
/**
* @brief OpenVINO Runtime IVariableState interface
* @file openvino/runtime/ivariable_state.hpp
*/
#pragma once
#include <memory>
#include <string>
#include "openvino/runtime/common.hpp"
#include "openvino/runtime/tensor.hpp"
namespace ov {
/**
* @interface IVariableState
* @brief Minimal interface for variable state implementation
* @ingroup ov_dev_api_variable_state_api
*/
class OPENVINO_RUNTIME_API IVariableState : public std::enable_shared_from_this<IVariableState> {
public:
explicit IVariableState(const std::string& name);
/**
* @brief Gets the variable state name
* @return A string representing the variable state name
*/
virtual const std::string& get_name() const;
/**
* @brief Resets the internal variable state of the relevant infer request to the value specified as
* the default for the corresponding `ReadValue` node
*/
virtual void reset();
/**
* @brief Sets the new state for the next inference
* @param state The new state
*/
virtual void set_state(const ov::Tensor& state);
/**
* @brief Returns the value of the variable state.
* @return The value of the variable state
*/
virtual const ov::Tensor& get_state() const;
protected:
/**
* @brief Default destructor
*/
virtual ~IVariableState();
std::string m_name;
ov::Tensor m_state;
};
} // namespace ov
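For orientation, below is a minimal sketch (not part of this commit) of how a plugin-side state could derive from the new interface; the class name and the zero-valued default are illustrative assumptions. The base class keeps the current value in the protected m_state member, so such a class mainly allocates the tensor and overrides reset():

#include <cstring>

#include "openvino/runtime/ivariable_state.hpp"

// Hypothetical plugin-side state. set_state()/get_state() are inherited from the
// base class, which simply stores/returns m_state; only reset() has to be
// overridden, because the base implementation throws OPENVINO_NOT_IMPLEMENTED.
class ExampleVariableState : public ov::IVariableState {
public:
    ExampleVariableState(const std::string& name, const ov::element::Type& type, const ov::Shape& shape)
        : ov::IVariableState(name) {
        m_state = ov::Tensor(type, shape);
        reset();
    }

    void reset() override {
        // the ReadValue default is assumed to be all zeros here
        std::memset(m_state.data(), 0, m_state.get_byte_size());
    }
};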

View File

@@ -16,13 +16,13 @@
#include "openvino/runtime/tensor.hpp"
namespace InferenceEngine {
class IVariableStateInternal;
class IAsyncInferRequestWrapper;
} // namespace InferenceEngine
namespace ov {
class InferRequest;
class IVariableState;
class IInferRequestInternalWrapper;
/**
@@ -30,7 +30,7 @@ class IInferRequestInternalWrapper;
* @ingroup ov_runtime_cpp_api
*/
class OPENVINO_RUNTIME_API VariableState {
std::shared_ptr<InferenceEngine::IVariableStateInternal> _impl;
std::shared_ptr<ov::IVariableState> _impl;
std::vector<std::shared_ptr<void>> _so;
/**
@@ -39,8 +39,7 @@ class OPENVINO_RUNTIME_API VariableState {
* @param so Optional: plugin to use. This is required to ensure that VariableState can work properly even if a
* plugin object is destroyed.
*/
VariableState(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& impl,
const std::vector<std::shared_ptr<void>>& so);
VariableState(const std::shared_ptr<ov::IVariableState>& impl, const std::vector<std::shared_ptr<void>>& so);
friend class ov::InferRequest;
friend class ov::IInferRequestInternalWrapper;

View File

@@ -5,6 +5,7 @@
#include "cpp/ie_memory_state.hpp"
#include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
#include "openvino/core/except.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/variable_state.hpp"
#define VARIABLE_CALL_STATEMENT(...) \
@@ -65,26 +66,27 @@ VariableState::~VariableState() {
_impl = {};
}
VariableState::VariableState(const ie::IVariableStateInternal::Ptr& impl, const std::vector<std::shared_ptr<void>>& so)
VariableState::VariableState(const std::shared_ptr<ov::IVariableState>& impl,
const std::vector<std::shared_ptr<void>>& so)
: _impl{impl},
_so{so} {
OPENVINO_ASSERT(_impl != nullptr, "VariableState was not initialized.");
}
void VariableState::reset() {
OV_VARIABLE_CALL_STATEMENT(_impl->Reset());
OV_VARIABLE_CALL_STATEMENT(_impl->reset());
}
std::string VariableState::get_name() const {
OV_VARIABLE_CALL_STATEMENT(return _impl->GetName());
OV_VARIABLE_CALL_STATEMENT(return _impl->get_name());
}
Tensor VariableState::get_state() const {
OV_VARIABLE_CALL_STATEMENT(return {std::const_pointer_cast<ie::Blob>(_impl->GetState()), {_so}});
OV_VARIABLE_CALL_STATEMENT(return _impl->get_state());
}
void VariableState::set_state(const Tensor& state) {
OV_VARIABLE_CALL_STATEMENT(_impl->SetState(state._impl));
OV_VARIABLE_CALL_STATEMENT(_impl->set_state(state));
}
} // namespace ov
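For context, a short application-side usage sketch of the public wrapper (the model path and device name are placeholders); VariableState now simply forwards get_name()/reset()/get_state()/set_state() to the underlying ov::IVariableState:

#include <iostream>

#include "openvino/runtime/core.hpp"

int main() {
    ov::Core core;
    // "stateful_model.xml" and "CPU" are placeholders for a real stateful model and device
    auto compiled = core.compile_model("stateful_model.xml", "CPU");
    ov::InferRequest request = compiled.create_infer_request();

    request.infer();
    for (auto&& state : request.query_state()) {
        std::cout << state.get_name() << std::endl;
        ov::Tensor value = state.get_state();  // current value of the variable
        state.reset();                         // back to the ReadValue default
    }
    return 0;
}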

View File

@@ -12,6 +12,7 @@
#include "cnn_network_ngraph_impl.hpp"
#include "cpp_interfaces/interface/ie_iexecutable_network_internal.hpp"
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
#include "cpp_interfaces/interface/ie_ivariable_state_internal.hpp"
#include "icompiled_model_wrapper.hpp"
#include "ie_blob.h"
#include "ie_common.h"
@@ -29,6 +30,7 @@
#include "openvino/runtime/icompiled_model.hpp"
#include "openvino/runtime/iinfer_request.hpp"
#include "openvino/runtime/iplugin.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/remote_context.hpp"
#include "openvino/runtime/tensor.hpp"
@@ -185,6 +187,31 @@ std::shared_ptr<const ov::Model> ov::legacy_convert::convert_model(const Inferen
namespace ov {
class IVariableStateInternalWrapper : public InferenceEngine::IVariableStateInternal {
std::shared_ptr<ov::IVariableState> m_state;
public:
IVariableStateInternalWrapper(const std::shared_ptr<ov::IVariableState>& state)
: InferenceEngine::IVariableStateInternal(state->get_name()),
m_state(state) {}
std::string GetName() const override {
return m_state->get_name();
}
void Reset() override {
m_state->reset();
}
void SetState(const InferenceEngine::Blob::Ptr& newState) override {
m_state->set_state(ov::Tensor(newState, {}));
}
InferenceEngine::Blob::CPtr GetState() const override {
return m_state->get_state()._impl;
}
};
class IInferencePluginWrapper : public InferenceEngine::IInferencePlugin {
public:
IInferencePluginWrapper(const std::shared_ptr<ov::IPlugin>& plugin) : m_plugin(plugin) {
@@ -521,7 +548,7 @@ public:
auto res = m_request->query_state();
std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> ret;
for (const auto& state : res) {
ret.emplace_back(state._impl);
ret.emplace_back(std::make_shared<ov::IVariableStateInternalWrapper>(state));
}
return ret;
}
@@ -558,6 +585,30 @@ private:
namespace InferenceEngine {
class IVariableStateWrapper : public ov::IVariableState {
private:
std::shared_ptr<InferenceEngine::IVariableStateInternal> m_state;
mutable ov::Tensor m_converted_state;
public:
explicit IVariableStateWrapper(const std::shared_ptr<InferenceEngine::IVariableStateInternal>& state)
: ov::IVariableState(state->GetName()),
m_state(state) {}
void reset() override {
m_state->Reset();
}
void set_state(const ov::Tensor& state) override {
m_state->SetState(state._impl);
}
const ov::Tensor& get_state() const override {
m_converted_state = ov::Tensor(std::const_pointer_cast<InferenceEngine::Blob>(m_state->GetState()), {});
return m_converted_state;
}
};
class IAsyncInferRequestWrapper : public ov::IAsyncInferRequest {
public:
IAsyncInferRequestWrapper(const std::shared_ptr<InferenceEngine::IInferRequestInternal>& request)
@@ -676,12 +727,11 @@ public:
m_request->SetBlobs(get_legacy_name_from_port(port), blobs);
}
std::vector<ov::VariableState> query_state() const override {
std::vector<ov::VariableState> variable_states;
std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override {
std::vector<std::shared_ptr<ov::IVariableState>> variable_states;
std::vector<std::shared_ptr<void>> soVec;
soVec = {m_request->getPointerToSo()};
for (auto&& state : m_request->QueryState()) {
variable_states.emplace_back(ov::VariableState{state, soVec});
variable_states.emplace_back(std::make_shared<InferenceEngine::IVariableStateWrapper>(state));
}
return variable_states;
}
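Taken together, the two adapters in this file close the loop between the APIs: ov::IVariableStateInternalWrapper exposes a new-style ov::IVariableState to legacy callers through the InferenceEngine::IVariableStateInternal interface, while InferenceEngine::IVariableStateWrapper does the reverse, so query_state() can hand states across the legacy/2.0 boundary in either direction.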

View File

@@ -7,6 +7,7 @@
#include <memory>
#include "openvino/runtime/isync_infer_request.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/variable_state.hpp"
#include "threading/ie_immediate_executor.hpp"
#include "threading/ie_istreams_executor.hpp"
@@ -101,7 +102,7 @@ void ov::IAsyncInferRequest::set_callback(std::function<void(std::exception_ptr)
m_callback = std::move(callback);
}
std::vector<ov::VariableState> ov::IAsyncInferRequest::query_state() const {
std::vector<std::shared_ptr<ov::IVariableState>> ov::IAsyncInferRequest::query_state() const {
check_state();
return m_sync_request->query_state();
}

View File

@@ -0,0 +1,27 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/core/except.hpp"
ov::IVariableState::IVariableState(const std::string& name) : m_name(name) {}
ov::IVariableState::~IVariableState() = default;
const std::string& ov::IVariableState::get_name() const {
return m_name;
}
void ov::IVariableState::reset() {
OPENVINO_NOT_IMPLEMENTED;
}
void ov::IVariableState::set_state(const ov::Tensor& state) {
m_state = state;
}
const ov::Tensor& ov::IVariableState::get_state() const {
return m_state;
}

View File

@@ -264,9 +264,7 @@ std::vector<VariableState> InferRequest::query_state() {
std::vector<VariableState> variable_states;
OV_INFER_REQ_CALL_STATEMENT({
for (auto&& state : _impl->query_state()) {
auto soVec = state._so;
soVec.emplace_back(_so);
variable_states.emplace_back(ov::VariableState{state._impl, soVec});
variable_states.emplace_back(ov::VariableState{state, {_so}});
}
})
return variable_states;

View File

@@ -75,7 +75,7 @@ TemplatePlugin::InferRequest::InferRequest(const std::shared_ptr<const TemplateP
}
// ! [infer_request:ctor]
std::vector<ov::VariableState> TemplatePlugin::InferRequest::query_state() const {
std::vector<std::shared_ptr<ov::IVariableState>> TemplatePlugin::InferRequest::query_state() const {
OPENVINO_NOT_IMPLEMENTED;
}

View File

@@ -28,7 +28,7 @@ public:
~InferRequest();
void infer() override;
std::vector<ov::VariableState> query_state() const override;
std::vector<std::shared_ptr<ov::IVariableState>> query_state() const override;
std::vector<ov::ProfilingInfo> get_profiling_info() const override;
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor