[core] Migrate the Assign operator to new API (#19664)

* Migrate the Assign operator to new API

* Use memcpy instead of tensor copy_to
This commit is contained in:
Pawel Raasz 2023-09-12 13:10:23 +02:00 committed by GitHub
parent adf7a24ec0
commit 4af1fd087c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 43 deletions

View File

@ -63,11 +63,9 @@ public:
OPENVINO_ASSERT(m_variable, "Variable is not initialized. Variable_id is unavailable");
return m_variable->get_info().variable_id;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs,
bool evaluate(TensorVector& outputs,
const TensorVector& inputs,
const EvaluationContext& evaluation_context) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool has_evaluate() const override;
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
};

View File

@ -2,26 +2,24 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/assign.hpp"
#include <assign_shape_inference.hpp>
#include "openvino/op/assign.hpp"
#include "assign_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/op/read_value.hpp"
#include "ngraph/op/util/variable.hpp"
#include "ngraph/op/util/variable_context.hpp"
#include "ngraph/ops.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/util/variable.hpp"
#include "openvino/op/util/variable_context.hpp"
using namespace std;
using namespace ngraph;
op::v3::Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
namespace ov {
namespace op {
namespace v3 {
Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
: AssignBase({new_value}),
m_variable_id(variable_id) {
constructor_validate_and_infer_types();
}
void op::v3::Assign::validate_and_infer_types() {
void Assign::validate_and_infer_types() {
OV_OP_SCOPE(v3_Assign_validate_and_infer_types);
auto value = input_value(0);
auto arg_t = get_input_element_type(0);
@ -33,93 +31,95 @@ void op::v3::Assign::validate_and_infer_types() {
}
auto nodes = topological_sort(start_nodes);
for (const auto& node : nodes) {
if (auto read_value = ov::as_type_ptr<op::v3::ReadValue>(node)) {
if (auto read_value = ov::as_type_ptr<v3::ReadValue>(node)) {
if (read_value->get_variable_id() == m_variable_id)
m_variable = read_value->get_variable();
}
}
NODE_VALIDATION_CHECK(this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id);
}
std::vector<ov::PartialShape> input_shapes = {input_shape};
const auto input_shapes = std::vector<ov::PartialShape>{input_shape};
const auto output_shapes = shape_infer(this, input_shapes);
set_output_type(0, arg_t, output_shapes[0]);
}
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Assign::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v3_Assign_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<op::v3::Assign>(new_args.at(0), m_variable_id);
return std::make_shared<Assign>(new_args.at(0), m_variable_id);
}
bool op::v3::Assign::visit_attributes(AttributeVisitor& visitor) {
bool Assign::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v3_Assign_visit_attributes);
visitor.on_attribute("variable_id", m_variable_id);
return true;
}
} // namespace v3
op::v6::Assign::Assign(const Output<Node>& new_value, const std::shared_ptr<Variable>& variable)
namespace v6 {
Assign::Assign(const Output<Node>& new_value, const std::shared_ptr<util::Variable>& variable)
: AssignBase({new_value}) {
m_variable = variable;
constructor_validate_and_infer_types();
}
void op::v6::Assign::validate_and_infer_types() {
void Assign::validate_and_infer_types() {
OV_OP_SCOPE(v6_Assign_validate_and_infer_types);
m_variable->update({get_input_partial_shape(0), get_input_element_type(0), m_variable->get_info().variable_id});
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::v6::Assign::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Assign::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v6_Assign_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<op::v6::Assign>(new_args.at(0), m_variable);
return std::make_shared<Assign>(new_args.at(0), m_variable);
}
bool op::v6::Assign::visit_attributes(AttributeVisitor& visitor) {
bool Assign::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v6_Assign_visit_attributes);
visitor.on_attribute("variable_id", m_variable);
return true;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool op::v6::Assign::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs,
const EvaluationContext& evaluation_context) const {
bool Assign::evaluate(TensorVector& outputs,
const TensorVector& inputs,
const EvaluationContext& evaluation_context) const {
OV_OP_SCOPE(v6_Assign_evaluate);
const auto& found_context = evaluation_context.find("VariableContext");
NODE_VALIDATION_CHECK(this, found_context != evaluation_context.end(), "VariableContext not found.");
auto& variable_context = const_cast<VariableContext&>(found_context->second.as<VariableContext>());
auto& variable_context = const_cast<util::VariableContext&>(found_context->second.as<util::VariableContext>());
const auto& variable_values = variable_context.get_variable_values();
// automatically allocate memory if not provided by user
if (variable_values.find(m_variable) == variable_values.end()) {
auto host_tensor =
std::make_shared<ngraph::HostTensor>(m_variable->get_info().data_type, m_variable->get_info().data_shape);
variable_context.set_variable_value(m_variable, make_shared<VariableValue>(host_tensor));
auto tensor = Tensor(m_variable->get_info().data_type, m_variable->get_info().data_shape.to_shape());
variable_context.set_variable_value(m_variable, std::make_shared<util::VariableValue>(tensor));
}
const auto var_value = variable_values.find(m_variable)->second;
var_value->set_reset(false);
const auto& buffer = var_value->get_value();
buffer->set_unary(inputs[0]);
outputs[0]->set_unary(inputs[0]);
auto buffer = var_value->get_state();
buffer.set_shape(inputs[0].get_shape());
outputs[0].set_shape(inputs[0].get_shape());
void* input = inputs[0]->get_data_ptr();
outputs[0]->write(input, outputs[0]->get_size_in_bytes());
buffer->write(input, buffer->get_size_in_bytes());
std::memcpy(outputs[0].data(), inputs[0].data(), inputs[0].get_byte_size());
std::memcpy(buffer.data(), inputs[0].data(), inputs[0].get_byte_size());
return true;
}
OPENVINO_SUPPRESS_DEPRECATED_END
bool op::v6::Assign::has_evaluate() const {
bool Assign::has_evaluate() const {
OV_OP_SCOPE(v1_Assign_has_evaluate);
return true;
}
bool op::v6::Assign::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
bool Assign::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
return false;
}
} // namespace v6
} // namespace op
} // namespace ov