Backward compatible bindings for the Node class (#9193)
This commit is contained in:
parent
5c0b125554
commit
eab49eec8b
@ -29,6 +29,21 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace impl {
|
||||||
|
py::dict get_attributes(const std::shared_ptr<ngraph::Node>& node) {
|
||||||
|
util::DictAttributeSerializer dict_serializer(node);
|
||||||
|
return dict_serializer.get_attributes();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_attribute(std::shared_ptr<ngraph::Node>& node, const std::string& atr_name, py::object value) {
|
||||||
|
py::dict attr_dict;
|
||||||
|
attr_dict[atr_name.c_str()] = value;
|
||||||
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
|
||||||
|
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
|
||||||
|
node->visit_attributes(dict_deserializer);
|
||||||
|
}
|
||||||
|
} // namespace impl
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
using PyRTMap = ngraph::Node::RTMap;
|
using PyRTMap = ngraph::Node::RTMap;
|
||||||
@ -289,17 +304,12 @@ void regclass_pyngraph_Node(py::module m) {
|
|||||||
node.def_property_readonly("type_info", &ngraph::Node::get_type_info);
|
node.def_property_readonly("type_info", &ngraph::Node::get_type_info);
|
||||||
node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
||||||
|
|
||||||
node.def("get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
|
node.def("get_attributes", &impl::get_attributes);
|
||||||
util::DictAttributeSerializer dict_serializer(self);
|
node.def("set_attribute", &impl::set_attribute);
|
||||||
return dict_serializer.get_attributes();
|
// for backwards compatibility, this is how this method was named until 2021.4
|
||||||
});
|
node.def("_get_attributes", &impl::get_attributes);
|
||||||
node.def("set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
|
// for backwards compatibility, this is how this method was named until 2021.4
|
||||||
py::dict attr_dict;
|
node.def("_set_attribute", &impl::set_attribute);
|
||||||
attr_dict[atr_name.c_str()] = value;
|
|
||||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
|
|
||||||
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
|
|
||||||
self->visit_attributes(dict_deserializer);
|
|
||||||
});
|
|
||||||
node.def("set_arguments", [](const std::shared_ptr<ngraph::Node>& self, const ngraph::OutputVector& arguments) {
|
node.def("set_arguments", [](const std::shared_ptr<ngraph::Node>& self, const ngraph::OutputVector& arguments) {
|
||||||
return self->set_arguments(arguments);
|
return self->set_arguments(arguments);
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user