Add default constructor for op. (#7368)

* Add default constructor for op.

* Apply correct format.

* Apply correct format for py files.

* Remove underscore prefix, add create overlaod method.

* Resolve use cases for parameter arguments in create fucntion.

* Resolve myoy issue.

* Remove underscore for getter and setter functions.

* Restore check when arguments is None.

* Add tests covering raising errors and get_attributes() function.
This commit is contained in:
Szymon Durawa 2021-09-13 05:49:19 +02:00 committed by GitHub
parent 66bad412a4
commit c50c0d59bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 98 additions and 9 deletions

View File

@ -169,13 +169,13 @@ You can also set attribute values using corresponding setter methods, for exampl
node.set_axis(0)
```
Currently, you can get all attributes of a node using the `_get_attributes` method. Please note that this is an internal API method and may change in future versions of OpenVINO.
Currently, you can get all attributes of a node using the `get_attributes` method.
The following code displays all attributes for all nodes in a function:
```python
for node in function.get_ordered_ops():
attributes = node._get_attributes()
attributes = node.get_attributes()
if(attributes):
print('Operation {} of type {} has attributes:'.format(node.get_friendly_name(), node.get_type_name()))
for attr, value in attributes.items():

View File

@ -10,6 +10,8 @@ from _pyngraph import NodeFactory as _NodeFactory
from ngraph.impl import Node, Output
from ngraph.exceptions import UserInputError
DEFAULT_OPSET = "opset8"
@ -26,7 +28,7 @@ class NodeFactory(object):
def create(
self,
op_type_name: str,
arguments: List[Union[Node, Output]],
arguments: Optional[List[Union[Node, Output]]] = None,
attributes: Optional[Dict[str, Any]] = None,
) -> Node:
"""Create node object from provided description.
@ -39,9 +41,24 @@ class NodeFactory(object):
@return Node object representing requested operator with attributes set.
"""
if arguments is None and attributes is None:
node = self.factory.create(op_type_name)
node._attr_cache = {}
node._attr_cache_valid = False
return node
if arguments is None and attributes is not None:
raise UserInputError(
'Error: cannot create "{}" op without arguments.'.format(
op_type_name
)
)
if attributes is None:
attributes = {}
assert arguments is not None
arguments = self._arguments_as_outputs(arguments)
node = self.factory.create(op_type_name, arguments, attributes)
@ -57,7 +74,7 @@ class NodeFactory(object):
# node.get_some_metric_attr_name()
# node.set_some_metric_attr_name()
# Please see test_dyn_attributes.py for more usage examples.
all_attributes = node._get_attributes()
all_attributes = node.get_attributes()
for attr_name in all_attributes.keys():
setattr(
node,
@ -134,7 +151,7 @@ class NodeFactory(object):
@return The node attribute value.
"""
if not node._attr_cache_valid:
node._attr_cache = node._get_attributes()
node._attr_cache = node.get_attributes()
node._attr_cache_valid = True
return node._attr_cache[attr_name]
@ -146,5 +163,5 @@ class NodeFactory(object):
@param attr_name: The attribute name.
@param value: The new attribute value.
"""
node._set_attribute(attr_name, value)
node.set_attribute(attr_name, value)
node._attr_cache[attr_name] = value

View File

@ -265,15 +265,21 @@ void regclass_pyngraph_Node(py::module m) {
node.def_property_readonly("version", &ngraph::Node::get_version);
node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
node.def("get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
util::DictAttributeSerializer dict_serializer(self);
return dict_serializer.get_attributes();
});
node.def("_set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
node.def("set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
py::dict attr_dict;
attr_dict[atr_name.c_str()] = value;
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
self->visit_attributes(dict_deserializer);
});
node.def("set_arguments", [](const std::shared_ptr<ngraph::Node>& self, const ngraph::OutputVector& arguments) {
return self->set_arguments(arguments);
});
node.def("validate", [](const std::shared_ptr<ngraph::Node>& self) {
return self->constructor_validate_and_infer_types();
});
}

View File

@ -19,6 +19,7 @@
#include "dict_attribute_visitor.hpp"
#include "ngraph/check.hpp"
#include "ngraph/except.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/util/variable.hpp"
@ -51,6 +52,19 @@ public:
return op_node;
}
std::shared_ptr<ngraph::Node> create(const std::string op_type_name) {
std::shared_ptr<ngraph::Node> op_node = std::shared_ptr<ngraph::Node>(m_opset.create(op_type_name));
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
NGRAPH_CHECK(!ngraph::op::is_constant(op_node),
"Currently NodeFactory doesn't support Constant node: ",
op_type_name);
NGRAPH_WARN << "Empty op created! Please assign inputs and attributes and run validate() before op is used.";
return op_node;
}
private:
const ngraph::OpSet& get_opset(std::string opset_ver) {
std::locale loc;
@ -90,7 +104,16 @@ void regclass_pyngraph_NodeFactory(py::module m) {
node_factory.def(py::init());
node_factory.def(py::init<std::string>());
node_factory.def("create", &NodeFactory::create);
node_factory.def("create", [](NodeFactory& self, const std::string name) {
return self.create(name);
});
node_factory.def("create",
[](NodeFactory& self,
const std::string name,
const ngraph::OutputVector& arguments,
const py::dict& attributes) {
return self.create(name, arguments, attributes);
});
node_factory.def("__repr__", [](const NodeFactory& self) {
return "<NodeFactory>";

View File

@ -3,6 +3,7 @@
import numpy as np
import ngraph as ng
from ngraph.exceptions import UserInputError
from ngraph.utils.node_factory import NodeFactory
from _pyngraph import NodeFactory as _NodeFactory
@ -45,7 +46,49 @@ def test_node_factory_topk():
node = factory.create(
"TopK", arguments, {"axis": 1, "mode": "max", "sort": "value"}
)
attributes = node.get_attributes()
assert node.get_type_name() == "TopK"
assert node.get_output_size() == 2
assert list(node.get_output_shape(0)) == [2, 3]
assert attributes["axis"] == 1
assert attributes["mode"] == "max"
assert attributes["sort"] == "value"
def test_node_factory_empty_topk():
factory = NodeFactory("opset1")
node = factory.create("TopK")
assert node.get_type_name() == "TopK"
def test_node_factory_empty_topk_with_args_and_attrs():
dtype = np.int32
data = ng.parameter([2, 10], dtype=dtype, name="A")
k = ng.constant(3, dtype=dtype, name="B")
factory = NodeFactory("opset1")
arguments = NodeFactory._arguments_as_outputs([data, k])
node = factory.create("TopK", None, None)
node.set_arguments(arguments)
node.set_attribute("axis", 1)
node.set_attribute("mode", "max")
node.set_attribute("sort", "value")
node.validate()
assert node.get_type_name() == "TopK"
assert node.get_output_size() == 2
assert list(node.get_output_shape(0)) == [2, 3]
def test_node_factory_validate_missing_arguments():
factory = NodeFactory("opset1")
try:
factory.create(
"TopK", None, {"axis": 1, "mode": "max", "sort": "value"}
)
except UserInputError:
pass
else:
raise AssertionError("Validation of missing arguments has unexpectedly passed.")