[core] Migrate the Assign operator to new API (#19664)

* Migrate the Assign operator to new API

* Use memcpy instead of tensor copy_to
This commit is contained in:
Pawel Raasz 2023-09-12 13:10:23 +02:00 committed by GitHub
parent adf7a24ec0
commit 4af1fd087c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 43 deletions

View File

@ -63,11 +63,9 @@ public:
OPENVINO_ASSERT(m_variable, "Variable is not initialized. Variable_id is unavailable"); OPENVINO_ASSERT(m_variable, "Variable is not initialized. Variable_id is unavailable");
return m_variable->get_info().variable_id; return m_variable->get_info().variable_id;
} }
OPENVINO_SUPPRESS_DEPRECATED_START bool evaluate(TensorVector& outputs,
bool evaluate(const HostTensorVector& outputs, const TensorVector& inputs,
const HostTensorVector& inputs,
const EvaluationContext& evaluation_context) const override; const EvaluationContext& evaluation_context) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool has_evaluate() const override; bool has_evaluate() const override;
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
}; };

View File

@ -2,26 +2,24 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "ngraph/op/assign.hpp" #include "openvino/op/assign.hpp"
#include <assign_shape_inference.hpp>
#include "assign_shape_inference.hpp"
#include "itt.hpp" #include "itt.hpp"
#include "ngraph/op/read_value.hpp" #include "openvino/op/read_value.hpp"
#include "ngraph/op/util/variable.hpp" #include "openvino/op/util/variable.hpp"
#include "ngraph/op/util/variable_context.hpp" #include "openvino/op/util/variable_context.hpp"
#include "ngraph/ops.hpp"
using namespace std; namespace ov {
using namespace ngraph; namespace op {
namespace v3 {
op::v3::Assign::Assign(const Output<Node>& new_value, const std::string& variable_id) Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
: AssignBase({new_value}), : AssignBase({new_value}),
m_variable_id(variable_id) { m_variable_id(variable_id) {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void op::v3::Assign::validate_and_infer_types() { void Assign::validate_and_infer_types() {
OV_OP_SCOPE(v3_Assign_validate_and_infer_types); OV_OP_SCOPE(v3_Assign_validate_and_infer_types);
auto value = input_value(0); auto value = input_value(0);
auto arg_t = get_input_element_type(0); auto arg_t = get_input_element_type(0);
@ -33,93 +31,95 @@ void op::v3::Assign::validate_and_infer_types() {
} }
auto nodes = topological_sort(start_nodes); auto nodes = topological_sort(start_nodes);
for (const auto& node : nodes) { for (const auto& node : nodes) {
if (auto read_value = ov::as_type_ptr<op::v3::ReadValue>(node)) { if (auto read_value = ov::as_type_ptr<v3::ReadValue>(node)) {
if (read_value->get_variable_id() == m_variable_id) if (read_value->get_variable_id() == m_variable_id)
m_variable = read_value->get_variable(); m_variable = read_value->get_variable();
} }
} }
NODE_VALIDATION_CHECK(this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id); NODE_VALIDATION_CHECK(this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id);
} }
std::vector<ov::PartialShape> input_shapes = {input_shape};
const auto input_shapes = std::vector<ov::PartialShape>{input_shape};
const auto output_shapes = shape_infer(this, input_shapes); const auto output_shapes = shape_infer(this, input_shapes);
set_output_type(0, arg_t, output_shapes[0]); set_output_type(0, arg_t, output_shapes[0]);
} }
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const { std::shared_ptr<Node> Assign::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v3_Assign_clone_with_new_inputs); OV_OP_SCOPE(v3_Assign_clone_with_new_inputs);
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<op::v3::Assign>(new_args.at(0), m_variable_id); return std::make_shared<Assign>(new_args.at(0), m_variable_id);
} }
bool op::v3::Assign::visit_attributes(AttributeVisitor& visitor) { bool Assign::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v3_Assign_visit_attributes); OV_OP_SCOPE(v3_Assign_visit_attributes);
visitor.on_attribute("variable_id", m_variable_id); visitor.on_attribute("variable_id", m_variable_id);
return true; return true;
} }
} // namespace v3
op::v6::Assign::Assign(const Output<Node>& new_value, const std::shared_ptr<Variable>& variable) namespace v6 {
Assign::Assign(const Output<Node>& new_value, const std::shared_ptr<util::Variable>& variable)
: AssignBase({new_value}) { : AssignBase({new_value}) {
m_variable = variable; m_variable = variable;
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void op::v6::Assign::validate_and_infer_types() { void Assign::validate_and_infer_types() {
OV_OP_SCOPE(v6_Assign_validate_and_infer_types); OV_OP_SCOPE(v6_Assign_validate_and_infer_types);
m_variable->update({get_input_partial_shape(0), get_input_element_type(0), m_variable->get_info().variable_id}); m_variable->update({get_input_partial_shape(0), get_input_element_type(0), m_variable->get_info().variable_id});
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
shared_ptr<Node> op::v6::Assign::clone_with_new_inputs(const OutputVector& new_args) const { std::shared_ptr<Node> Assign::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v6_Assign_clone_with_new_inputs); OV_OP_SCOPE(v6_Assign_clone_with_new_inputs);
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<op::v6::Assign>(new_args.at(0), m_variable); return std::make_shared<Assign>(new_args.at(0), m_variable);
} }
bool op::v6::Assign::visit_attributes(AttributeVisitor& visitor) { bool Assign::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v6_Assign_visit_attributes); OV_OP_SCOPE(v6_Assign_visit_attributes);
visitor.on_attribute("variable_id", m_variable); visitor.on_attribute("variable_id", m_variable);
return true; return true;
} }
OPENVINO_SUPPRESS_DEPRECATED_START bool Assign::evaluate(TensorVector& outputs,
bool op::v6::Assign::evaluate(const HostTensorVector& outputs, const TensorVector& inputs,
const HostTensorVector& inputs,
const EvaluationContext& evaluation_context) const { const EvaluationContext& evaluation_context) const {
OV_OP_SCOPE(v6_Assign_evaluate); OV_OP_SCOPE(v6_Assign_evaluate);
const auto& found_context = evaluation_context.find("VariableContext"); const auto& found_context = evaluation_context.find("VariableContext");
NODE_VALIDATION_CHECK(this, found_context != evaluation_context.end(), "VariableContext not found."); NODE_VALIDATION_CHECK(this, found_context != evaluation_context.end(), "VariableContext not found.");
auto& variable_context = const_cast<VariableContext&>(found_context->second.as<VariableContext>()); auto& variable_context = const_cast<util::VariableContext&>(found_context->second.as<util::VariableContext>());
const auto& variable_values = variable_context.get_variable_values(); const auto& variable_values = variable_context.get_variable_values();
// automatically allocate memory if not provided by user // automatically allocate memory if not provided by user
if (variable_values.find(m_variable) == variable_values.end()) { if (variable_values.find(m_variable) == variable_values.end()) {
auto host_tensor = auto tensor = Tensor(m_variable->get_info().data_type, m_variable->get_info().data_shape.to_shape());
std::make_shared<ngraph::HostTensor>(m_variable->get_info().data_type, m_variable->get_info().data_shape); variable_context.set_variable_value(m_variable, std::make_shared<util::VariableValue>(tensor));
variable_context.set_variable_value(m_variable, make_shared<VariableValue>(host_tensor));
} }
const auto var_value = variable_values.find(m_variable)->second; const auto var_value = variable_values.find(m_variable)->second;
var_value->set_reset(false); var_value->set_reset(false);
const auto& buffer = var_value->get_value(); auto buffer = var_value->get_state();
buffer->set_unary(inputs[0]); buffer.set_shape(inputs[0].get_shape());
outputs[0]->set_unary(inputs[0]); outputs[0].set_shape(inputs[0].get_shape());
void* input = inputs[0]->get_data_ptr(); std::memcpy(outputs[0].data(), inputs[0].data(), inputs[0].get_byte_size());
outputs[0]->write(input, outputs[0]->get_size_in_bytes()); std::memcpy(buffer.data(), inputs[0].data(), inputs[0].get_byte_size());
buffer->write(input, buffer->get_size_in_bytes());
return true; return true;
} }
OPENVINO_SUPPRESS_DEPRECATED_END
bool op::v6::Assign::has_evaluate() const { bool Assign::has_evaluate() const {
OV_OP_SCOPE(v1_Assign_has_evaluate); OV_OP_SCOPE(v1_Assign_has_evaluate);
return true; return true;
} }
bool op::v6::Assign::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { bool Assign::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
return false; return false;
} }
} // namespace v6
} // namespace op
} // namespace ov