Add getting/setting friendly name for Node wrapper Py API (#1286)

This commit is contained in:
Jan Iwaszkiewicz
2020-07-20 10:31:49 +02:00
committed by GitHub
parent 8d6238a3d7
commit d4c9af91d8
3 changed files with 25 additions and 6 deletions

View File

@@ -20,9 +20,9 @@ from ngraph.impl import Node
from ngraph.utils.types import NodeInput, as_node, as_nodes
def _set_node_name(node: Node, **kwargs: Any) -> Node:
def _set_node_friendly_name(node: Node, **kwargs: Any) -> Node:
if "name" in kwargs:
node.name = kwargs["name"]
node.friendly_name = kwargs["name"]
return node
@@ -32,7 +32,7 @@ def nameable_op(node_factory_function: Callable) -> Callable:
@wraps(node_factory_function)
def wrapper(*args: Any, **kwargs: Any) -> Node:
node = node_factory_function(*args, **kwargs)
node = _set_node_name(node, **kwargs)
node = _set_node_friendly_name(node, **kwargs)
return node
return wrapper
@@ -45,7 +45,7 @@ def unary_op(node_factory_function: Callable) -> Callable:
def wrapper(input_value: NodeInput, *args: Any, **kwargs: Any) -> Node:
input_node = as_node(input_value)
node = node_factory_function(input_node, *args, **kwargs)
node = _set_node_name(node, **kwargs)
node = _set_node_friendly_name(node, **kwargs)
return node
return wrapper
@@ -58,7 +58,7 @@ def binary_op(node_factory_function: Callable) -> Callable:
def wrapper(left: NodeInput, right: NodeInput, *args: Any, **kwargs: Any) -> Node:
left, right = as_nodes(left, right)
node = node_factory_function(left, right, *args, **kwargs)
node = _set_node_name(node, **kwargs)
node = _set_node_friendly_name(node, **kwargs)
return node
return wrapper

View File

@@ -77,6 +77,8 @@ void regclass_pyngraph_Node(py::module m)
node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape);
node.def("get_type_name", &ngraph::Node::get_type_name);
node.def("get_unique_name", &ngraph::Node::get_name);
node.def("get_friendly_name", &ngraph::Node::get_friendly_name);
node.def("set_friendly_name", &ngraph::Node::set_friendly_name);
node.def("input", (ngraph::Input<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::input);
node.def("inputs",
(std::vector<ngraph::Input<ngraph::Node>>(ngraph::Node::*)()) & ngraph::Node::inputs);
@@ -86,8 +88,10 @@ void regclass_pyngraph_Node(py::module m)
(std::vector<ngraph::Output<ngraph::Node>>(ngraph::Node::*)()) &
ngraph::Node::outputs);
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
node.def_property_readonly("shape", &ngraph::Node::get_shape);
node.def_property_readonly("name", &ngraph::Node::get_name);
node.def_property(
"friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
util::DictAttributeSerializer dict_serializer(self);

View File

@@ -273,6 +273,21 @@ def test_result():
assert np.allclose(result, node)
def test_node_friendly_name():
dummy_node = ng.parameter(shape=[1], name="dummy_name")
assert(dummy_node.name == "Parameter_0")
assert(dummy_node.friendly_name == "dummy_name")
dummy_node.set_friendly_name("changed_name")
assert(dummy_node.get_friendly_name() == "changed_name")
dummy_node.friendly_name = "new_name"
assert(dummy_node.get_friendly_name() == "new_name")
def test_node_output():
input_array = np.array([0, 1, 2, 3, 4, 5])
splits = 3