diff --git a/docs/nGraph_DG/nGraph_Python_API.md b/docs/nGraph_DG/nGraph_Python_API.md index 3b778747c06..da6357b6bd8 100644 --- a/docs/nGraph_DG/nGraph_Python_API.md +++ b/docs/nGraph_DG/nGraph_Python_API.md @@ -169,13 +169,13 @@ You can also set attribute values using corresponding setter methods, for exampl node.set_axis(0) ``` -Currently, you can get all attributes of a node using the `_get_attributes` method. Please note that this is an internal API method and may change in future versions of OpenVINO. +Currently, you can get all attributes of a node using the `get_attributes` method. The following code displays all attributes for all nodes in a function: ```python for node in function.get_ordered_ops(): - attributes = node._get_attributes() + attributes = node.get_attributes() if(attributes): print('Operation {} of type {} has attributes:'.format(node.get_friendly_name(), node.get_type_name())) for attr, value in attributes.items(): diff --git a/runtime/bindings/python/src/compatibility/ngraph/utils/node_factory.py b/runtime/bindings/python/src/compatibility/ngraph/utils/node_factory.py index 0809a55b4c5..ffb0c3d861c 100644 --- a/runtime/bindings/python/src/compatibility/ngraph/utils/node_factory.py +++ b/runtime/bindings/python/src/compatibility/ngraph/utils/node_factory.py @@ -10,6 +10,8 @@ from _pyngraph import NodeFactory as _NodeFactory from ngraph.impl import Node, Output +from ngraph.exceptions import UserInputError + DEFAULT_OPSET = "opset8" @@ -26,7 +28,7 @@ class NodeFactory(object): def create( self, op_type_name: str, - arguments: List[Union[Node, Output]], + arguments: Optional[List[Union[Node, Output]]] = None, attributes: Optional[Dict[str, Any]] = None, ) -> Node: """Create node object from provided description. @@ -39,9 +41,24 @@ class NodeFactory(object): @return Node object representing requested operator with attributes set. """ + if arguments is None and attributes is None: + node = self.factory.create(op_type_name) + node._attr_cache = {} + node._attr_cache_valid = False + return node + + if arguments is None and attributes is not None: + raise UserInputError( + 'Error: cannot create "{}" op without arguments.'.format( + op_type_name + ) + ) + if attributes is None: attributes = {} + assert arguments is not None + arguments = self._arguments_as_outputs(arguments) node = self.factory.create(op_type_name, arguments, attributes) @@ -57,7 +74,7 @@ class NodeFactory(object): # node.get_some_metric_attr_name() # node.set_some_metric_attr_name() # Please see test_dyn_attributes.py for more usage examples. - all_attributes = node._get_attributes() + all_attributes = node.get_attributes() for attr_name in all_attributes.keys(): setattr( node, @@ -134,7 +151,7 @@ class NodeFactory(object): @return The node attribute value. """ if not node._attr_cache_valid: - node._attr_cache = node._get_attributes() + node._attr_cache = node.get_attributes() node._attr_cache_valid = True return node._attr_cache[attr_name] @@ -146,5 +163,5 @@ class NodeFactory(object): @param attr_name: The attribute name. @param value: The new attribute value. """ - node._set_attribute(attr_name, value) + node.set_attribute(attr_name, value) node._attr_cache[attr_name] = value diff --git a/runtime/bindings/python/src/compatibility/pyngraph/node.cpp b/runtime/bindings/python/src/compatibility/pyngraph/node.cpp index 59ba2da7abf..3c7436d2dc8 100644 --- a/runtime/bindings/python/src/compatibility/pyngraph/node.cpp +++ b/runtime/bindings/python/src/compatibility/pyngraph/node.cpp @@ -265,15 +265,21 @@ void regclass_pyngraph_Node(py::module m) { node.def_property_readonly("version", &ngraph::Node::get_version); node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name); - node.def("_get_attributes", [](const std::shared_ptr& self) { + node.def("get_attributes", [](const std::shared_ptr& self) { util::DictAttributeSerializer dict_serializer(self); return dict_serializer.get_attributes(); }); - node.def("_set_attribute", [](std::shared_ptr& self, const std::string& atr_name, py::object value) { + node.def("set_attribute", [](std::shared_ptr& self, const std::string& atr_name, py::object value) { py::dict attr_dict; attr_dict[atr_name.c_str()] = value; std::unordered_map> variables; util::DictAttributeDeserializer dict_deserializer(attr_dict, variables); self->visit_attributes(dict_deserializer); }); + node.def("set_arguments", [](const std::shared_ptr& self, const ngraph::OutputVector& arguments) { + return self->set_arguments(arguments); + }); + node.def("validate", [](const std::shared_ptr& self) { + return self->constructor_validate_and_infer_types(); + }); } diff --git a/runtime/bindings/python/src/compatibility/pyngraph/node_factory.cpp b/runtime/bindings/python/src/compatibility/pyngraph/node_factory.cpp index bcc4a63c0bf..d4ad3d7bacb 100644 --- a/runtime/bindings/python/src/compatibility/pyngraph/node_factory.cpp +++ b/runtime/bindings/python/src/compatibility/pyngraph/node_factory.cpp @@ -19,6 +19,7 @@ #include "dict_attribute_visitor.hpp" #include "ngraph/check.hpp" #include "ngraph/except.hpp" +#include "ngraph/log.hpp" #include "ngraph/node.hpp" #include "ngraph/op/util/op_types.hpp" #include "ngraph/op/util/variable.hpp" @@ -51,6 +52,19 @@ public: return op_node; } + std::shared_ptr create(const std::string op_type_name) { + std::shared_ptr op_node = std::shared_ptr(m_opset.create(op_type_name)); + + NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name); + NGRAPH_CHECK(!ngraph::op::is_constant(op_node), + "Currently NodeFactory doesn't support Constant node: ", + op_type_name); + + NGRAPH_WARN << "Empty op created! Please assign inputs and attributes and run validate() before op is used."; + + return op_node; + } + private: const ngraph::OpSet& get_opset(std::string opset_ver) { std::locale loc; @@ -90,7 +104,16 @@ void regclass_pyngraph_NodeFactory(py::module m) { node_factory.def(py::init()); node_factory.def(py::init()); - node_factory.def("create", &NodeFactory::create); + node_factory.def("create", [](NodeFactory& self, const std::string name) { + return self.create(name); + }); + node_factory.def("create", + [](NodeFactory& self, + const std::string name, + const ngraph::OutputVector& arguments, + const py::dict& attributes) { + return self.create(name, arguments, attributes); + }); node_factory.def("__repr__", [](const NodeFactory& self) { return ""; diff --git a/runtime/bindings/python/tests/test_ngraph/test_node_factory.py b/runtime/bindings/python/tests/test_ngraph/test_node_factory.py index e47ad8ee75d..14fe3d62d04 100644 --- a/runtime/bindings/python/tests/test_ngraph/test_node_factory.py +++ b/runtime/bindings/python/tests/test_ngraph/test_node_factory.py @@ -3,6 +3,7 @@ import numpy as np import ngraph as ng +from ngraph.exceptions import UserInputError from ngraph.utils.node_factory import NodeFactory from _pyngraph import NodeFactory as _NodeFactory @@ -45,7 +46,49 @@ def test_node_factory_topk(): node = factory.create( "TopK", arguments, {"axis": 1, "mode": "max", "sort": "value"} ) + attributes = node.get_attributes() assert node.get_type_name() == "TopK" assert node.get_output_size() == 2 assert list(node.get_output_shape(0)) == [2, 3] + assert attributes["axis"] == 1 + assert attributes["mode"] == "max" + assert attributes["sort"] == "value" + + +def test_node_factory_empty_topk(): + factory = NodeFactory("opset1") + node = factory.create("TopK") + + assert node.get_type_name() == "TopK" + + +def test_node_factory_empty_topk_with_args_and_attrs(): + dtype = np.int32 + data = ng.parameter([2, 10], dtype=dtype, name="A") + k = ng.constant(3, dtype=dtype, name="B") + factory = NodeFactory("opset1") + arguments = NodeFactory._arguments_as_outputs([data, k]) + node = factory.create("TopK", None, None) + node.set_arguments(arguments) + node.set_attribute("axis", 1) + node.set_attribute("mode", "max") + node.set_attribute("sort", "value") + node.validate() + + assert node.get_type_name() == "TopK" + assert node.get_output_size() == 2 + assert list(node.get_output_shape(0)) == [2, 3] + + +def test_node_factory_validate_missing_arguments(): + factory = NodeFactory("opset1") + + try: + factory.create( + "TopK", None, {"axis": 1, "mode": "max", "sort": "value"} + ) + except UserInputError: + pass + else: + raise AssertionError("Validation of missing arguments has unexpectedly passed.")