Implement VariableState support for the Template plugin (#16356)
* Implement VariableState support for the Template plugin * Suppress some warnings * Try to fix Windows
This commit is contained in:
@@ -39,7 +39,7 @@ OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
|
||||
OPENVINO_API TensorVector wrap_tensors(const std::vector<ngraph::HostTensorPtr>& tensors);
|
||||
|
||||
/**
|
||||
* @brief Update output host tensors if they got dynamic shapee before evaluation (not allocated).
|
||||
* @brief Update output host tensors if they got dynamic shape before evaluation (not allocated).
|
||||
*
|
||||
* Other tensor not requires update as they are created from outputs and points to same data blob.
|
||||
*
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "openvino/core/core_visibility.hpp"
|
||||
#include "openvino/core/deprecated.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
@@ -18,42 +19,69 @@ class OPENVINO_API VariableValue {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<VariableValue>;
|
||||
/// \brief Constructs an uninitialized VariableValue.
|
||||
VariableValue() = default;
|
||||
VariableValue();
|
||||
|
||||
/// \brief Constructor for VariableValue.
|
||||
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
|
||||
/// instead
|
||||
/// \param value The data for Variable.
|
||||
explicit VariableValue(ngraph::HostTensorPtr value) : m_value(std::move(value)) {}
|
||||
OPENVINO_DEPRECATED(
|
||||
"This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor instead.")
|
||||
explicit VariableValue(ngraph::HostTensorPtr value);
|
||||
|
||||
/// \brief Constructor for VariableValue.
|
||||
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
|
||||
/// instead
|
||||
/// \param value Data for Variable.
|
||||
/// \param reset The current state of the reset flag.
|
||||
VariableValue(ngraph::HostTensorPtr value, bool reset) : m_reset(reset), m_value(std::move(value)) {}
|
||||
OPENVINO_DEPRECATED(
|
||||
"This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor instead.")
|
||||
VariableValue(ngraph::HostTensorPtr value, bool reset);
|
||||
|
||||
/// \brief Returns the current stored data.
|
||||
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
|
||||
/// instead
|
||||
OPENVINO_DEPRECATED("This method is deprecated and will be removed in 2024.0 release. Please get_state() instead.")
|
||||
ngraph::HostTensorPtr get_value() const;
|
||||
|
||||
/// \brief Sets new values for Variable.
|
||||
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
|
||||
/// instead
|
||||
/// \param value New data for Variable.
|
||||
OPENVINO_DEPRECATED(
|
||||
"This method is deprecated and will be removed in 2024.0 release. Please use set_state() instead.")
|
||||
void set_value(const ngraph::HostTensorPtr& value);
|
||||
|
||||
/// \brief Sets the reset flag to a new state.
|
||||
/// \param reset The new state of the reset flag.
|
||||
void set_reset(bool reset) {
|
||||
m_reset = reset;
|
||||
}
|
||||
void set_reset(bool reset);
|
||||
|
||||
/// \brief Returns the current reset flag state.
|
||||
bool get_reset() const {
|
||||
return m_reset;
|
||||
}
|
||||
bool get_reset() const;
|
||||
|
||||
explicit VariableValue(const ov::Tensor& value);
|
||||
|
||||
/// \brief Constructor for VariableValue.
|
||||
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
|
||||
/// instead
|
||||
/// \param value Data for Variable.
|
||||
/// \param reset The current state of the reset flag.
|
||||
VariableValue(const ov::Tensor& value, bool reset);
|
||||
|
||||
/// \brief Returns the current stored data.
|
||||
const ngraph::HostTensorPtr& get_value() const {
|
||||
return m_value;
|
||||
}
|
||||
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
|
||||
/// instead
|
||||
const ov::Tensor& get_state() const;
|
||||
|
||||
/// \brief Sets new values for Variable.
|
||||
/// \deprecated This method is deprecated and will be removed in 2024.0 release. Please use method with ov::Tensor
|
||||
/// instead
|
||||
/// \param value New data for Variable.
|
||||
void set_value(const ngraph::HostTensorPtr& value) {
|
||||
m_value = value;
|
||||
}
|
||||
void set_state(const ov::Tensor& value);
|
||||
|
||||
private:
|
||||
bool m_reset = true;
|
||||
ngraph::HostTensorPtr m_value;
|
||||
ov::Tensor m_value;
|
||||
};
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
|
||||
@@ -35,6 +35,12 @@ class IVariableStateInternalWrapper;
|
||||
class ITensor;
|
||||
class RemoteTensor;
|
||||
|
||||
namespace op {
|
||||
namespace util {
|
||||
class VariableValue;
|
||||
}
|
||||
} // namespace op
|
||||
|
||||
/**
|
||||
* @brief Tensor API holding host memory
|
||||
* It can throw exceptions safely for the application, where it is properly handled.
|
||||
@@ -64,6 +70,7 @@ protected:
|
||||
friend class ov::IVariableStateInternalWrapper;
|
||||
friend class InferenceEngine::IAsyncInferRequestWrapper;
|
||||
friend class InferenceEngine::IVariableStateWrapper;
|
||||
friend class ov::op::util::VariableValue;
|
||||
|
||||
public:
|
||||
/// @brief Default constructor
|
||||
|
||||
@@ -96,6 +96,7 @@ bool op::v6::Assign::evaluate(const HostTensorVector& outputs,
|
||||
|
||||
const auto& variable_values = variable_context.get_variable_values();
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
// automatically allocate memory if not provided by user
|
||||
if (variable_values.find(m_variable) == variable_values.end()) {
|
||||
auto host_tensor =
|
||||
@@ -106,6 +107,7 @@ bool op::v6::Assign::evaluate(const HostTensorVector& outputs,
|
||||
const auto var_value = variable_values.find(m_variable)->second;
|
||||
var_value->set_reset(false);
|
||||
const auto& buffer = var_value->get_value();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
buffer->set_unary(inputs[0]);
|
||||
outputs[0]->set_unary(inputs[0]);
|
||||
|
||||
|
||||
@@ -108,7 +108,9 @@ bool op::v6::ReadValue::evaluate(const HostTensorVector& outputs,
|
||||
// initial value (inputs[0]) is not supported, use zeros
|
||||
auto zero_const = make_shared<v0::Constant>(inputs[0]->get_element_type(), inputs[0]->get_shape(), 0);
|
||||
auto zero_tensor = make_shared<HostTensor>(zero_const);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto& input_tensor = use_context ? var_value->second->get_value() : zero_tensor;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
outputs[0]->set_unary(input_tensor);
|
||||
|
||||
void* input = input_tensor->get_data_ptr();
|
||||
|
||||
143
src/core/src/op/util/variable_value.cpp
Normal file
143
src/core/src/op/util/variable_value.cpp
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/util/variable_value.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "openvino/core/deprecated.hpp"
|
||||
#include "openvino/core/shape.hpp"
|
||||
#include "openvino/runtime/allocator.hpp"
|
||||
#include "openvino/runtime/itensor.hpp"
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
#include "shape_util.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
class TensorWrapper : public ngraph::runtime::HostTensor {
|
||||
public:
|
||||
TensorWrapper(const ov::Tensor& tensor)
|
||||
: ngraph::runtime::HostTensor(tensor.get_element_type(), tensor.get_shape(), tensor.data()),
|
||||
tensor(tensor) {}
|
||||
|
||||
ov::Tensor tensor;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Tensor what contains HostTensorPtr inside
|
||||
*/
|
||||
class HostTensorWrapper : public ov::ITensor {
|
||||
public:
|
||||
ngraph::HostTensorPtr tensor;
|
||||
|
||||
HostTensorWrapper(const ngraph::HostTensorPtr& tensor) : tensor{tensor}, m_type(tensor->get_element_type()) {
|
||||
const auto& p_shape = tensor->get_partial_shape();
|
||||
if (p_shape.is_static()) {
|
||||
m_shape = p_shape.to_shape();
|
||||
} else {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
m_shape = ov::util::make_dynamic_shape();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
update_strides();
|
||||
}
|
||||
|
||||
const ov::element::Type& get_element_type() const override {
|
||||
return m_type;
|
||||
}
|
||||
|
||||
void set_shape(ov::Shape shape) override {
|
||||
tensor->set_shape(shape);
|
||||
m_shape = shape;
|
||||
update_strides();
|
||||
}
|
||||
|
||||
const ov::Shape& get_shape() const override {
|
||||
return m_shape;
|
||||
}
|
||||
|
||||
const ov::Strides& get_strides() const override {
|
||||
OPENVINO_ASSERT(get_element_type().bitwidth() >= 8,
|
||||
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
|
||||
get_element_type());
|
||||
return m_strides;
|
||||
}
|
||||
|
||||
size_t get_size() const override {
|
||||
return ov::shape_size(m_shape);
|
||||
}
|
||||
|
||||
size_t get_byte_size() const override {
|
||||
return get_size() * m_type.size();
|
||||
}
|
||||
|
||||
void* data(const ov::element::Type& element_type) const override {
|
||||
return tensor->get_data_ptr();
|
||||
}
|
||||
|
||||
private:
|
||||
ov::element::Type m_type;
|
||||
ov::Shape m_shape;
|
||||
ov::Strides m_strides;
|
||||
|
||||
void update_strides() {
|
||||
if (m_type.bitwidth() >= 8) {
|
||||
m_strides.clear();
|
||||
m_strides.resize(m_shape.size());
|
||||
auto size = m_strides.size();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
size_t value(m_type.size());
|
||||
size_t dim(m_shape[size - 1 - i]);
|
||||
if (i) {
|
||||
value = m_strides[size - i] * dim;
|
||||
}
|
||||
m_strides[size - i - 1] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
ov::op::util::VariableValue::VariableValue() = default;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
ov::op::util::VariableValue::VariableValue(ngraph::HostTensorPtr value)
|
||||
: m_value(ov::Tensor{std::make_shared<HostTensorWrapper>(value), {}}) {}
|
||||
|
||||
ov::op::util::VariableValue::VariableValue(ngraph::HostTensorPtr value, bool reset)
|
||||
: m_reset(reset),
|
||||
m_value(ov::Tensor{std::make_shared<HostTensorWrapper>(value), {}}) {}
|
||||
|
||||
ngraph::HostTensorPtr ov::op::util::VariableValue::get_value() const {
|
||||
if (auto wrapper = std::dynamic_pointer_cast<HostTensorWrapper>(m_value._impl))
|
||||
return wrapper->tensor;
|
||||
return std::make_shared<TensorWrapper>(m_value);
|
||||
}
|
||||
|
||||
void ov::op::util::VariableValue::set_value(const ngraph::HostTensorPtr& value) {
|
||||
m_value = ov::Tensor{std::make_shared<HostTensorWrapper>(value), {}};
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
void ov::op::util::VariableValue::set_reset(bool reset) {
|
||||
m_reset = reset;
|
||||
}
|
||||
|
||||
bool ov::op::util::VariableValue::get_reset() const {
|
||||
return m_reset;
|
||||
}
|
||||
|
||||
ov::op::util::VariableValue::VariableValue(const ov::Tensor& value) : m_value(value) {}
|
||||
|
||||
ov::op::util::VariableValue::VariableValue(const ov::Tensor& value, bool reset) : m_reset(reset), m_value(value) {}
|
||||
|
||||
const ov::Tensor& ov::op::util::VariableValue::get_state() const {
|
||||
return m_value;
|
||||
}
|
||||
|
||||
void ov::op::util::VariableValue::set_state(const ov::Tensor& value) {
|
||||
m_value = value;
|
||||
}
|
||||
@@ -13,8 +13,12 @@
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "openvino/core/except.hpp"
|
||||
#include "openvino/op/util/variable_context.hpp"
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
#include "openvino/runtime/profiling_info.hpp"
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
#include "plugin.hpp"
|
||||
#include "variable_state.hpp"
|
||||
|
||||
using Time = std::chrono::high_resolution_clock;
|
||||
|
||||
@@ -72,11 +76,27 @@ ov::template_plugin::InferRequest::InferRequest(const std::shared_ptr<const ov::
|
||||
output.get_partial_shape().is_dynamic() ? ov::Shape{0} : output.get_shape());
|
||||
});
|
||||
}
|
||||
|
||||
// Save variable states
|
||||
ov::op::util::VariableContext variable_context;
|
||||
for (const auto& variable : get_template_model()->m_model->get_variables()) {
|
||||
auto value = std::make_shared<ov::op::util::VariableValue>();
|
||||
if (!variable_context.get_variable_value(variable)) {
|
||||
auto shape = variable->get_info().data_shape.is_dynamic() ? ov::Shape{0}
|
||||
: variable->get_info().data_shape.to_shape();
|
||||
ov::Tensor tensor = ov::Tensor(variable->get_info().data_type, shape);
|
||||
variable_context.set_variable_value(variable, std::make_shared<ov::op::util::VariableValue>(tensor));
|
||||
}
|
||||
auto state = std::make_shared<VariableState>(variable->get_info().variable_id,
|
||||
variable_context.get_variable_value(variable)->get_state());
|
||||
m_variable_states.emplace_back(state);
|
||||
}
|
||||
m_eval_context.emplace("VariableContext", variable_context);
|
||||
}
|
||||
// ! [infer_request:ctor]
|
||||
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> ov::template_plugin::InferRequest::query_state() const {
|
||||
OPENVINO_NOT_IMPLEMENTED;
|
||||
return m_variable_states;
|
||||
}
|
||||
|
||||
std::shared_ptr<const ov::template_plugin::CompiledModel> ov::template_plugin::InferRequest::get_template_model()
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
|
||||
#include "executable.hpp"
|
||||
#include "ngraph/runtime/tensor.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/itt.hpp"
|
||||
#include "openvino/runtime/isync_infer_request.hpp"
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace template_plugin {
|
||||
@@ -52,6 +54,8 @@ private:
|
||||
std::vector<ov::Tensor> m_backend_input_tensors;
|
||||
std::vector<ov::Tensor> m_backend_output_tensors;
|
||||
std::shared_ptr<ov::runtime::Executable> m_executable;
|
||||
ov::EvaluationContext m_eval_context;
|
||||
std::vector<std::shared_ptr<ov::IVariableState>> m_variable_states;
|
||||
};
|
||||
// ! [infer_request:header]
|
||||
|
||||
|
||||
25
src/plugins/template/src/variable_state.hpp
Normal file
25
src/plugins/template/src/variable_state.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/runtime/ivariable_state.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace template_plugin {
|
||||
|
||||
class VariableState : public ov::IVariableState {
|
||||
public:
|
||||
VariableState(const std::string& name, const ov::Tensor& tensor) : ov::IVariableState(name) {
|
||||
set_state(tensor);
|
||||
}
|
||||
void reset() override {
|
||||
std::memset(m_state.data(), 0, m_state.get_byte_size());
|
||||
}
|
||||
|
||||
~VariableState() override = default;
|
||||
};
|
||||
|
||||
} // namespace template_plugin
|
||||
} // namespace ov
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "behavior/infer_request/memory_states.hpp"
|
||||
|
||||
using namespace BehaviorTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
std::vector<memoryStateParams> memoryStateTestCases = {memoryStateParams(InferRequestVariableStateTest::getNetwork(),
|
||||
{"c_1-3", "r_1-3"},
|
||||
CommonTestUtils::DEVICE_TEMPLATE,
|
||||
{})};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Template_BehaviorTests,
|
||||
InferRequestVariableStateTest,
|
||||
::testing::ValuesIn(memoryStateTestCases),
|
||||
InferRequestVariableStateTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Template_BehaviorTests,
|
||||
InferRequestQueryStateExceptionTest,
|
||||
::testing::ValuesIn(memoryStateTestCases),
|
||||
InferRequestQueryStateExceptionTest::getTestCaseName);
|
||||
} // namespace
|
||||
|
||||
@@ -133,6 +133,8 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*InferRequestIOBBlobTest.*secondCallGetOutputAfterInferSync.*)",
|
||||
// Old API cannot deallocate tensor
|
||||
R"(.*InferRequestIOBBlobTest.*canProcessDeallocatedOutputBlobAfterGetAndSetBlob.*)",
|
||||
// Why query state should throw an exception
|
||||
R"(.*InferRequestQueryStateExceptionTest.*inferreq_smoke_QueryState_ExceptionTest.*)",
|
||||
};
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
Reference in New Issue
Block a user