Use evaluation context for the inference (#16492)

Ilya Churaev 2023-03-23 13:52:03 +04:00 committed by GitHub
parent 982e1c1192
commit a3958d6ddf
4 changed files with 40 additions and 19 deletions

View File

@@ -24,6 +24,14 @@ public:
     /// \returns true if iteration is successful, false otherwise
     virtual bool call(std::vector<ov::Tensor>& outputs, const std::vector<ov::Tensor>& inputs) = 0;
 
+    /// \brief Executes a single iteration of a Function.
+    /// \param outputs vector of runtime::Tensor used as outputs
+    /// \param inputs vector of runtime::Tensor used as inputs
+    /// \param context Evaluation context
+    /// \returns true if iteration is successful, false otherwise
+    virtual bool call(std::vector<ov::Tensor>& outputs,
+                      const std::vector<ov::Tensor>& inputs,
+                      const ov::EvaluationContext& context) = 0;
     /// \brief Executes a single iteration of a Function.
     /// \param outputs vector of runtime::Tensor used as outputs
     /// \param inputs vector of runtime::Tensor used as inputs
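
For reference, a minimal sketch of how a caller might drive the new overload. Everything here except the "VariableContext" key and the OpenVINO types is illustrative: exec, the shapes, and the element type are made up, and ov::EvaluationContext is assumed to be the string-to-ov::Any map declared in openvino/core/node.hpp.

#include <openvino/core/node.hpp>                 // ov::EvaluationContext (assumed location)
#include <openvino/op/util/variable_context.hpp>  // ov::op::util::VariableContext
#include <openvino/runtime/tensor.hpp>            // ov::Tensor

// Build an evaluation context carrying the variable state for this call;
// "VariableContext" is the key the backend looks up (see the diff below).
ov::EvaluationContext ctx;
ctx.emplace("VariableContext", ov::op::util::VariableContext{});

// Illustrative I/O: a single f32 input of shape {1, 3} and one output slot.
std::vector<ov::Tensor> inputs{ov::Tensor(ov::element::f32, ov::Shape{1, 3})};
std::vector<ov::Tensor> outputs(1);
bool ok = exec->call(outputs, inputs, ctx);  // the new three-argument overload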

View File

@@ -105,7 +105,31 @@ ov::runtime::interpreter::INTExecutable::INTExecutable(const std::shared_ptr<ov:
 
 bool ov::runtime::interpreter::INTExecutable::call(std::vector<ov::Tensor>& outputs,
                                                    const std::vector<ov::Tensor>& inputs) {
-    // map function params -> HostTensor
+    EvaluationContext eval_context;
+    ov::op::util::VariableContext variable_context;
+    eval_context.emplace("VariableContext", variable_context);
+
+    // for each ordered op in the graph
+    for (const auto& op : m_nodes) {
+        if (auto var_extension = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(op)) {
+            auto variable = var_extension->get_variable();
+            if (!variable_context.get_variable_value(variable)) {
+                auto h_tensor = ov::Tensor(op->get_input_element_type(0), op->get_input_shape(0));
+                // h_tensor->write(h_tensor->get_data_ptr(), h_tensor->get_size_in_bytes());
+                const auto tensor_input = make_tmp_host_tensor(h_tensor);
+                variable_context.set_variable_value(variable,
+                                                    std::make_shared<ov::op::util::VariableValue>(tensor_input));
+            }
+        }
+    }
+
+    return call(outputs, inputs, eval_context);
+}
+
+bool ov::runtime::interpreter::INTExecutable::call(std::vector<ov::Tensor>& outputs,
+                                                   const std::vector<ov::Tensor>& inputs,
+                                                   const ov::EvaluationContext& context) {
+    // map function params -> ov::Tensor
     std::unordered_map<std::shared_ptr<ov::descriptor::Tensor>, ov::Tensor> tensor_map;
     size_t input_count = 0;
     for (const auto& param : get_parameters()) {
@@ -116,17 +140,13 @@ bool ov::runtime::interpreter::INTExecutable::call(std::vector<ov::Tensor>& outp
     }
 
     std::unordered_map<std::shared_ptr<ov::descriptor::Tensor>, size_t> results_map;
-    // map function outputs -> HostTensor
+    // map function outputs -> ov::Tensor
     for (size_t output_count = 0; output_count < get_results().size(); ++output_count) {
         auto output = get_results()[output_count]->output(0).get_tensor_ptr();
         if (!results_map.count(output))
             results_map.emplace(output, output_count);
     }
 
-    EvaluationContext eval_context;
-    ov::op::util::VariableContext variable_context;
-    eval_context.emplace("VariableContext", variable_context);
-
     // for each ordered op in the graph
     for (const auto& op : m_nodes) {
         if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(op)) {
@@ -165,19 +185,9 @@ bool ov::runtime::interpreter::INTExecutable::call(std::vector<ov::Tensor>& outp
             op_outputs.push_back(host_tensor);
         }
-        if (auto var_extension = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(cloned_node)) {
-            auto variable = var_extension->get_variable();
-            if (!variable_context.get_variable_value(variable)) {
-                auto h_tensor = ov::Tensor(cloned_node->get_input_element_type(0), cloned_node->get_input_shape(0));
-                // h_tensor->write(h_tensor->get_data_ptr(), h_tensor->get_size_in_bytes());
-                const auto tensor_input = make_tmp_host_tensor(h_tensor);
-                variable_context.set_variable_value(variable,
-                                                    std::make_shared<ov::op::util::VariableValue>(tensor_input));
-            }
-        }
-
+
         // Call evaluate for cloned_node with static shapes
-        if (!cloned_node->evaluate(op_outputs, op_inputs, eval_context)) {
+        if (!cloned_node->evaluate(op_outputs, op_inputs, context)) {
            // TODO: extend evaluate map for the context
            evaluate_node(cloned_node, op_outputs, op_inputs);
         }
         // Update tensors in tensor map
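
Net effect in this file: the two-argument call() is now a thin wrapper that seeds a default context (one freshly initialized VariableValue per VariableExtension op) and forwards to the new three-argument overload, which does the actual graph walk. A sketch of what that enables, reusing the illustrative names and includes from the earlier sketch: a caller can keep one context alive across iterations, so state written by stateful ops has a chance to survive between calls (whether it actually persists depends on the ops mutating the VariableContext stored inside the ov::Any in place, which this sketch assumes).

// Reuse one evaluation context across iterations instead of letting the
// two-argument overload rebuild a default-initialized one every time.
ov::EvaluationContext ctx;
ctx.emplace("VariableContext", ov::op::util::VariableContext{});

for (int step = 0; step < 10; ++step) {
    // Every iteration sees the same VariableContext object.
    exec->call(outputs, inputs, ctx);
}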

View File

@@ -29,6 +29,9 @@ public:
     INTExecutable(const std::shared_ptr<ov::Model>& model);
 
     bool call(std::vector<ov::Tensor>& outputs, const std::vector<ov::Tensor>& inputs) override;
+    bool call(std::vector<ov::Tensor>& outputs,
+              const std::vector<ov::Tensor>& inputs,
+              const ov::EvaluationContext& context) override;
 
     ov::Tensor create_input_tensor(size_t input_index) override;

View File

@@ -207,7 +207,7 @@ void ov::template_plugin::InferRequest::infer_preprocess() {
 void ov::template_plugin::InferRequest::start_pipeline() {
     OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, m_profiling_task[StartPipeline])
     auto start = Time::now();
-    m_executable->call(m_backend_output_tensors, m_backend_input_tensors);
+    m_executable->call(m_backend_output_tensors, m_backend_input_tensors, m_eval_context);
     m_durations[StartPipeline] = Time::now() - start;
 }
 // ! [infer_request:start_pipeline]
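
Since the request now owns the context it passes down, per-request state management reduces to editing m_eval_context, the member used in the call above. A hypothetical helper sketching that pattern; reset_variable_states() is not part of this commit, and the "VariableContext" entry is only meaningful once something has seeded it:

// Hypothetical helper (not in this commit): drop accumulated variable state
// for this request by swapping in a default-constructed VariableContext.
void ov::template_plugin::InferRequest::reset_variable_states() {
    m_eval_context.erase("VariableContext");
    m_eval_context.emplace("VariableContext", ov::op::util::VariableContext{});
}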