Add getting/setting friendly name for Node wrapper Py API (#1286)
This commit is contained in:
@@ -20,9 +20,9 @@ from ngraph.impl import Node
|
||||
from ngraph.utils.types import NodeInput, as_node, as_nodes
|
||||
|
||||
|
||||
def _set_node_name(node: Node, **kwargs: Any) -> Node:
|
||||
def _set_node_friendly_name(node: Node, **kwargs: Any) -> Node:
|
||||
if "name" in kwargs:
|
||||
node.name = kwargs["name"]
|
||||
node.friendly_name = kwargs["name"]
|
||||
return node
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ def nameable_op(node_factory_function: Callable) -> Callable:
|
||||
@wraps(node_factory_function)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Node:
|
||||
node = node_factory_function(*args, **kwargs)
|
||||
node = _set_node_name(node, **kwargs)
|
||||
node = _set_node_friendly_name(node, **kwargs)
|
||||
return node
|
||||
|
||||
return wrapper
|
||||
@@ -45,7 +45,7 @@ def unary_op(node_factory_function: Callable) -> Callable:
|
||||
def wrapper(input_value: NodeInput, *args: Any, **kwargs: Any) -> Node:
|
||||
input_node = as_node(input_value)
|
||||
node = node_factory_function(input_node, *args, **kwargs)
|
||||
node = _set_node_name(node, **kwargs)
|
||||
node = _set_node_friendly_name(node, **kwargs)
|
||||
return node
|
||||
|
||||
return wrapper
|
||||
@@ -58,7 +58,7 @@ def binary_op(node_factory_function: Callable) -> Callable:
|
||||
def wrapper(left: NodeInput, right: NodeInput, *args: Any, **kwargs: Any) -> Node:
|
||||
left, right = as_nodes(left, right)
|
||||
node = node_factory_function(left, right, *args, **kwargs)
|
||||
node = _set_node_name(node, **kwargs)
|
||||
node = _set_node_friendly_name(node, **kwargs)
|
||||
return node
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -77,6 +77,8 @@ void regclass_pyngraph_Node(py::module m)
|
||||
node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape);
|
||||
node.def("get_type_name", &ngraph::Node::get_type_name);
|
||||
node.def("get_unique_name", &ngraph::Node::get_name);
|
||||
node.def("get_friendly_name", &ngraph::Node::get_friendly_name);
|
||||
node.def("set_friendly_name", &ngraph::Node::set_friendly_name);
|
||||
node.def("input", (ngraph::Input<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::input);
|
||||
node.def("inputs",
|
||||
(std::vector<ngraph::Input<ngraph::Node>>(ngraph::Node::*)()) & ngraph::Node::inputs);
|
||||
@@ -86,8 +88,10 @@ void regclass_pyngraph_Node(py::module m)
|
||||
(std::vector<ngraph::Output<ngraph::Node>>(ngraph::Node::*)()) &
|
||||
ngraph::Node::outputs);
|
||||
|
||||
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
||||
node.def_property_readonly("shape", &ngraph::Node::get_shape);
|
||||
node.def_property_readonly("name", &ngraph::Node::get_name);
|
||||
node.def_property(
|
||||
"friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
||||
|
||||
node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
|
||||
util::DictAttributeSerializer dict_serializer(self);
|
||||
|
||||
@@ -273,6 +273,21 @@ def test_result():
|
||||
assert np.allclose(result, node)
|
||||
|
||||
|
||||
def test_node_friendly_name():
|
||||
dummy_node = ng.parameter(shape=[1], name="dummy_name")
|
||||
|
||||
assert(dummy_node.name == "Parameter_0")
|
||||
assert(dummy_node.friendly_name == "dummy_name")
|
||||
|
||||
dummy_node.set_friendly_name("changed_name")
|
||||
|
||||
assert(dummy_node.get_friendly_name() == "changed_name")
|
||||
|
||||
dummy_node.friendly_name = "new_name"
|
||||
|
||||
assert(dummy_node.get_friendly_name() == "new_name")
|
||||
|
||||
|
||||
def test_node_output():
|
||||
input_array = np.array([0, 1, 2, 3, 4, 5])
|
||||
splits = 3
|
||||
|
||||
Reference in New Issue
Block a user