Add variable state for proxy plugin (#18426)

* Added proxy variable state

* Added documentation for some methods

* Fixed code style
This commit is contained in:
Ilya Churaev
2023-07-07 19:44:57 +04:00
committed by GitHub
parent 205de6106b
commit 79c37882aa
4 changed files with 79 additions and 1 deletions

View File

@@ -8,6 +8,7 @@
#include "openvino/runtime/remote_context.hpp"
#include "openvino/runtime/so_ptr.hpp"
#include "remote_context.hpp"
#include "variable_state.hpp"
ov::proxy::InferRequest::InferRequest(ov::SoPtr<ov::IAsyncInferRequest>&& request,
const std::shared_ptr<const ov::ICompiledModel>& compiled_model)
@@ -75,7 +76,11 @@ void ov::proxy::InferRequest::set_tensors(const ov::Output<const ov::Node>& port
}
std::vector<std::shared_ptr<ov::IVariableState>> ov::proxy::InferRequest::query_state() const {
return m_infer_request->query_state();
auto states = m_infer_request->query_state();
for (auto&& state : states) {
state = std::make_shared<ov::proxy::VariableState>(state, m_infer_request._so);
}
return states;
}
const std::shared_ptr<const ov::ICompiledModel>& ov::proxy::InferRequest::get_compiled_model() const {

View File

@@ -12,8 +12,25 @@
namespace ov {
namespace proxy {
/**
* @brief Proxy remote context implementation
* This class wraps hardware specific remote context and replace the context name
*/
class RemoteContext : public ov::IRemoteContext {
public:
/**
* @brief Constructs the proxy remote context
*
* @param ctx hardware context
* @param dev_name device name without index
* @param dev_index device index if exists else 0
* @param has_index flag is true if device has an index and false in another case
* @param is_new_api flag reports which API is used
*
* These arguments are needed to support the difference between legacy and 2.0 APIs.
* In legacy API remote context doesn't contain the index in the name but Blob contains.
* In 2.0 API Tensor and Context always contain device index
*/
RemoteContext(ov::RemoteContext&& ctx,
const std::string& dev_name,
size_t dev_index,

View File

@@ -11,6 +11,10 @@
namespace ov {
namespace proxy {
/**
* @brief Proxy remote tensor class.
* This class wraps the original remote tensor and change the name of RemoteTensor
*/
class RemoteTensor : public ov::IRemoteTensor {
public:
RemoteTensor(ov::RemoteTensor&& ctx, const std::string& dev_name);

View File

@@ -0,0 +1,52 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include "openvino/runtime/ivariable_state.hpp"
namespace ov {
namespace proxy {
/**
* @brief Simple wrapper for hardware variable states which holds plugin shared object
*/
class VariableState : public ov::IVariableState {
std::shared_ptr<ov::IVariableState> m_state;
std::shared_ptr<void> m_so;
public:
/**
* @brief Constructor of proxy VariableState
*
* @param state hardware state
* @param so shared object
*/
VariableState(const std::shared_ptr<ov::IVariableState>& state, const std::shared_ptr<void>& so)
: IVariableState(""),
m_state(state),
m_so(so) {
OPENVINO_ASSERT(m_state);
}
const std::string& get_name() const override {
return m_state->get_name();
}
void reset() override {
m_state->reset();
}
void set_state(const ov::Tensor& state) override {
m_state->set_state(state);
}
const ov::Tensor& get_state() const override {
return m_state->get_state();
}
};
} // namespace proxy
} // namespace ov