Backward compatible bindings for the Node class (#9193)

This commit is contained in:
Tomasz Dołbniak 2021-12-14 15:33:39 +01:00 committed by GitHub
parent 5c0b125554
commit eab49eec8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,6 +29,21 @@ public:
}
};
namespace impl {
py::dict get_attributes(const std::shared_ptr<ngraph::Node>& node) {
util::DictAttributeSerializer dict_serializer(node);
return dict_serializer.get_attributes();
}
void set_attribute(std::shared_ptr<ngraph::Node>& node, const std::string& atr_name, py::object value) {
py::dict attr_dict;
attr_dict[atr_name.c_str()] = value;
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
node->visit_attributes(dict_deserializer);
}
} // namespace impl
namespace py = pybind11;
using PyRTMap = ngraph::Node::RTMap;
@ -289,17 +304,12 @@ void regclass_pyngraph_Node(py::module m) {
node.def_property_readonly("type_info", &ngraph::Node::get_type_info);
node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
node.def("get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
util::DictAttributeSerializer dict_serializer(self);
return dict_serializer.get_attributes();
});
node.def("set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
py::dict attr_dict;
attr_dict[atr_name.c_str()] = value;
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
self->visit_attributes(dict_deserializer);
});
node.def("get_attributes", &impl::get_attributes);
node.def("set_attribute", &impl::set_attribute);
// for backwards compatibility, this is how this method was named until 2021.4
node.def("_get_attributes", &impl::get_attributes);
// for backwards compatibility, this is how this method was named until 2021.4
node.def("_set_attribute", &impl::set_attribute);
node.def("set_arguments", [](const std::shared_ptr<ngraph::Node>& self, const ngraph::OutputVector& arguments) {
return self->set_arguments(arguments);
});