Implement VariableState support for the Template plugin (#16356)

* Implement VariableState support for the Template plugin

* Suppress some warnings

* Try to fix Windows
This commit is contained in:
Ilya Churaev
2023-03-17 22:10:47 +04:00
committed by GitHub
parent e6ceed0bb9
commit b2a2266f60
11 changed files with 276 additions and 18 deletions

View File

@@ -39,7 +39,7 @@ OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
OPENVINO_API TensorVector wrap_tensors(const std::vector<ngraph::HostTensorPtr>& tensors);
/**
* @brief Update output host tensors if they got dynamic shapee before evaluation (not allocated).
* @brief Update output host tensors if they got dynamic shape before evaluation (not allocated).
*
* Other tensor not requires update as they are created from outputs and points to same data blob.
*

View File

@@ -8,6 +8,7 @@
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/core/core_visibility.hpp"
#include "openvino/core/deprecated.hpp"
namespace ov {
namespace op {
@@ -18,42 +19,69 @@ class OPENVINO_API VariableValue {
public:
using Ptr = std::shared_ptr<VariableValue>;
/// \brief Constructs an uninitialized VariableValue.
VariableValue() = default;
VariableValue();
/// \brief Constructor for VariableValue.
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
/// instead
/// \param value The data for Variable.
explicit VariableValue(ngraph::HostTensorPtr value) : m_value(std::move(value)) {}
OPENVINO_DEPRECATED(
"This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor instead.")
explicit VariableValue(ngraph::HostTensorPtr value);
/// \brief Constructor for VariableValue.
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
/// instead
/// \param value Data for Variable.
/// \param reset The current state of the reset flag.
VariableValue(ngraph::HostTensorPtr value, bool reset) : m_reset(reset), m_value(std::move(value)) {}
OPENVINO_DEPRECATED(
"This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor instead.")
VariableValue(ngraph::HostTensorPtr value, bool reset);
/// \brief Returns the current stored data.
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
/// instead
OPENVINO_DEPRECATED("This method is deprecated and will be removed in 2024.0 release. Please get_state() instead.")
ngraph::HostTensorPtr get_value() const;
/// \brief Sets new values for Variable.
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
/// instead
/// \param value New data for Variable.
OPENVINO_DEPRECATED(
"This method is deprecated and will be removed in 2024.0 release. Please use set_state() instead.")
void set_value(const ngraph::HostTensorPtr& value);
/// \brief Sets the reset flag to a new state.
/// \param reset The new state of the reset flag.
void set_reset(bool reset) {
m_reset = reset;
}
void set_reset(bool reset);
/// \brief Returns the current reset flag state.
bool get_reset() const {
return m_reset;
}
bool get_reset() const;
explicit VariableValue(const ov::Tensor& value);
/// \brief Constructor for VariableValue.
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
/// instead
/// \param value Data for Variable.
/// \param reset The current state of the reset flag.
VariableValue(const ov::Tensor& value, bool reset);
/// \brief Returns the current stored data.
const ngraph::HostTensorPtr& get_value() const {
return m_value;
}
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
/// instead
const ov::Tensor& get_state() const;
/// \brief Sets new values for Variable.
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
/// instead
/// \param value New data for Variable.
void set_value(const ngraph::HostTensorPtr& value) {
m_value = value;
}
void set_state(const ov::Tensor& value);
private:
bool m_reset = true;
ngraph::HostTensorPtr m_value;
ov::Tensor m_value;
};
} // namespace util
} // namespace op

View File

@@ -35,6 +35,12 @@ class IVariableStateInternalWrapper;
class ITensor;
class RemoteTensor;
namespace op {
namespace util {
class VariableValue;
}
} // namespace op
/**
* @brief Tensor API holding host memory
* It can throw exceptions safely for the application, where it is properly handled.
@@ -64,6 +70,7 @@ protected:
friend class ov::IVariableStateInternalWrapper;
friend class InferenceEngine::IAsyncInferRequestWrapper;
friend class InferenceEngine::IVariableStateWrapper;
friend class ov::op::util::VariableValue;
public:
/// @brief Default constructor

View File

@@ -96,6 +96,7 @@ bool op::v6::Assign::evaluate(const HostTensorVector& outputs,
const auto& variable_values = variable_context.get_variable_values();
OPENVINO_SUPPRESS_DEPRECATED_START
// automatically allocate memory if not provided by user
if (variable_values.find(m_variable) == variable_values.end()) {
auto host_tensor =
@@ -106,6 +107,7 @@ bool op::v6::Assign::evaluate(const HostTensorVector& outputs,
const auto var_value = variable_values.find(m_variable)->second;
var_value->set_reset(false);
const auto& buffer = var_value->get_value();
OPENVINO_SUPPRESS_DEPRECATED_END
buffer->set_unary(inputs[0]);
outputs[0]->set_unary(inputs[0]);

View File

@@ -108,7 +108,9 @@ bool op::v6::ReadValue::evaluate(const HostTensorVector& outputs,
// initial value (inputs[0]) is not supported, use zeros
auto zero_const = make_shared<v0::Constant>(inputs[0]->get_element_type(), inputs[0]->get_shape(), 0);
auto zero_tensor = make_shared<HostTensor>(zero_const);
OPENVINO_SUPPRESS_DEPRECATED_START
const auto& input_tensor = use_context ? var_value->second->get_value() : zero_tensor;
OPENVINO_SUPPRESS_DEPRECATED_END
outputs[0]->set_unary(input_tensor);
void* input = input_tensor->get_data_ptr();

View File

@@ -0,0 +1,143 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/util/variable_value.hpp"
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/core/deprecated.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/runtime/allocator.hpp"
#include "openvino/runtime/itensor.hpp"
#include "openvino/runtime/tensor.hpp"
#include "shape_util.hpp"
namespace {
class TensorWrapper : public ngraph::runtime::HostTensor {
public:
TensorWrapper(const ov::Tensor& tensor)
: ngraph::runtime::HostTensor(tensor.get_element_type(), tensor.get_shape(), tensor.data()),
tensor(tensor) {}
ov::Tensor tensor;
};
/**
* @brief Tensor what contains HostTensorPtr inside
*/
class HostTensorWrapper : public ov::ITensor {
public:
ngraph::HostTensorPtr tensor;
HostTensorWrapper(const ngraph::HostTensorPtr& tensor) : tensor{tensor}, m_type(tensor->get_element_type()) {
const auto& p_shape = tensor->get_partial_shape();
if (p_shape.is_static()) {
m_shape = p_shape.to_shape();
} else {
OPENVINO_SUPPRESS_DEPRECATED_START
m_shape = ov::util::make_dynamic_shape();
OPENVINO_SUPPRESS_DEPRECATED_END
}
update_strides();
}
const ov::element::Type& get_element_type() const override {
return m_type;
}
void set_shape(ov::Shape shape) override {
tensor->set_shape(shape);
m_shape = shape;
update_strides();
}
const ov::Shape& get_shape() const override {
return m_shape;
}
const ov::Strides& get_strides() const override {
OPENVINO_ASSERT(get_element_type().bitwidth() >= 8,
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
get_element_type());
return m_strides;
}
size_t get_size() const override {
return ov::shape_size(m_shape);
}
size_t get_byte_size() const override {
return get_size() * m_type.size();
}
void* data(const ov::element::Type& element_type) const override {
return tensor->get_data_ptr();
}
private:
ov::element::Type m_type;
ov::Shape m_shape;
ov::Strides m_strides;
void update_strides() {
if (m_type.bitwidth() >= 8) {
m_strides.clear();
m_strides.resize(m_shape.size());
auto size = m_strides.size();
for (size_t i = 0; i < size; i++) {
size_t value(m_type.size());
size_t dim(m_shape[size - 1 - i]);
if (i) {
value = m_strides[size - i] * dim;
}
m_strides[size - i - 1] = value;
}
}
}
};
} // namespace
ov::op::util::VariableValue::VariableValue() = default;
OPENVINO_SUPPRESS_DEPRECATED_START
ov::op::util::VariableValue::VariableValue(ngraph::HostTensorPtr value)
: m_value(ov::Tensor{std::make_shared<HostTensorWrapper>(value), {}}) {}
ov::op::util::VariableValue::VariableValue(ngraph::HostTensorPtr value, bool reset)
: m_reset(reset),
m_value(ov::Tensor{std::make_shared<HostTensorWrapper>(value), {}}) {}
ngraph::HostTensorPtr ov::op::util::VariableValue::get_value() const {
if (auto wrapper = std::dynamic_pointer_cast<HostTensorWrapper>(m_value._impl))
return wrapper->tensor;
return std::make_shared<TensorWrapper>(m_value);
}
void ov::op::util::VariableValue::set_value(const ngraph::HostTensorPtr& value) {
m_value = ov::Tensor{std::make_shared<HostTensorWrapper>(value), {}};
}
OPENVINO_SUPPRESS_DEPRECATED_END
void ov::op::util::VariableValue::set_reset(bool reset) {
m_reset = reset;
}
bool ov::op::util::VariableValue::get_reset() const {
return m_reset;
}
ov::op::util::VariableValue::VariableValue(const ov::Tensor& value) : m_value(value) {}
ov::op::util::VariableValue::VariableValue(const ov::Tensor& value, bool reset) : m_reset(reset), m_value(value) {}
const ov::Tensor& ov::op::util::VariableValue::get_state() const {
return m_value;
}
void ov::op::util::VariableValue::set_state(const ov::Tensor& value) {
m_value = value;
}

View File

@@ -13,8 +13,12 @@
#include "itt.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/core/except.hpp"
#include "openvino/op/util/variable_context.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/tensor.hpp"
#include "plugin.hpp"
#include "variable_state.hpp"
using Time = std::chrono::high_resolution_clock;
@@ -72,11 +76,27 @@ ov::template_plugin::InferRequest::InferRequest(const std::shared_ptr<const ov::
output.get_partial_shape().is_dynamic() ? ov::Shape{0} : output.get_shape());
});
}
// Save variable states
ov::op::util::VariableContext variable_context;
for (const auto& variable : get_template_model()->m_model->get_variables()) {
auto value = std::make_shared<ov::op::util::VariableValue>();
if (!variable_context.get_variable_value(variable)) {
auto shape = variable->get_info().data_shape.is_dynamic() ? ov::Shape{0}
: variable->get_info().data_shape.to_shape();
ov::Tensor tensor = ov::Tensor(variable->get_info().data_type, shape);
variable_context.set_variable_value(variable, std::make_shared<ov::op::util::VariableValue>(tensor));
}
auto state = std::make_shared<VariableState>(variable->get_info().variable_id,
variable_context.get_variable_value(variable)->get_state());
m_variable_states.emplace_back(state);
}
m_eval_context.emplace("VariableContext", variable_context);
}
// ! [infer_request:ctor]
std::vector<std::shared_ptr<ov::IVariableState>> ov::template_plugin::InferRequest::query_state() const {
OPENVINO_NOT_IMPLEMENTED;
return m_variable_states;
}
std::shared_ptr<const ov::template_plugin::CompiledModel> ov::template_plugin::InferRequest::get_template_model()

View File

@@ -13,8 +13,10 @@
#include "executable.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "openvino/core/node.hpp"
#include "openvino/itt.hpp"
#include "openvino/runtime/isync_infer_request.hpp"
#include "openvino/runtime/ivariable_state.hpp"
namespace ov {
namespace template_plugin {
@@ -52,6 +54,8 @@ private:
std::vector<ov::Tensor> m_backend_input_tensors;
std::vector<ov::Tensor> m_backend_output_tensors;
std::shared_ptr<ov::runtime::Executable> m_executable;
ov::EvaluationContext m_eval_context;
std::vector<std::shared_ptr<ov::IVariableState>> m_variable_states;
};
// ! [infer_request:header]

View File

@@ -0,0 +1,25 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/runtime/ivariable_state.hpp"
namespace ov {
namespace template_plugin {
class VariableState : public ov::IVariableState {
public:
VariableState(const std::string& name, const ov::Tensor& tensor) : ov::IVariableState(name) {
set_state(tensor);
}
void reset() override {
std::memset(m_state.data(), 0, m_state.get_byte_size());
}
~VariableState() override = default;
};
} // namespace template_plugin
} // namespace ov

View File

@@ -0,0 +1,25 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "behavior/infer_request/memory_states.hpp"
using namespace BehaviorTestsDefinitions;
namespace {
std::vector<memoryStateParams> memoryStateTestCases = {memoryStateParams(InferRequestVariableStateTest::getNetwork(),
{"c_1-3", "r_1-3"},
CommonTestUtils::DEVICE_TEMPLATE,
{})};
INSTANTIATE_TEST_SUITE_P(smoke_Template_BehaviorTests,
InferRequestVariableStateTest,
::testing::ValuesIn(memoryStateTestCases),
InferRequestVariableStateTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Template_BehaviorTests,
InferRequestQueryStateExceptionTest,
::testing::ValuesIn(memoryStateTestCases),
InferRequestQueryStateExceptionTest::getTestCaseName);
} // namespace

View File

@@ -133,6 +133,8 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*InferRequestIOBBlobTest.*secondCallGetOutputAfterInferSync.*)",
// Old API cannot deallocate tensor
R"(.*InferRequestIOBBlobTest.*canProcessDeallocatedOutputBlobAfterGetAndSetBlob.*)",
// Why query state should throw an exception
R"(.*InferRequestQueryStateExceptionTest.*inferreq_smoke_QueryState_ExceptionTest.*)",
};
#ifdef _WIN32