Add default constructor for op. (#7368)
* Add default constructor for op. * Apply correct format. * Apply correct format for py files. * Remove underscore prefix, add create overlaod method. * Resolve use cases for parameter arguments in create fucntion. * Resolve myoy issue. * Remove underscore for getter and setter functions. * Restore check when arguments is None. * Add tests covering raising errors and get_attributes() function.
This commit is contained in:
parent
66bad412a4
commit
c50c0d59bb
@ -169,13 +169,13 @@ You can also set attribute values using corresponding setter methods, for exampl
|
||||
node.set_axis(0)
|
||||
```
|
||||
|
||||
Currently, you can get all attributes of a node using the `_get_attributes` method. Please note that this is an internal API method and may change in future versions of OpenVINO.
|
||||
Currently, you can get all attributes of a node using the `get_attributes` method.
|
||||
|
||||
The following code displays all attributes for all nodes in a function:
|
||||
|
||||
```python
|
||||
for node in function.get_ordered_ops():
|
||||
attributes = node._get_attributes()
|
||||
attributes = node.get_attributes()
|
||||
if(attributes):
|
||||
print('Operation {} of type {} has attributes:'.format(node.get_friendly_name(), node.get_type_name()))
|
||||
for attr, value in attributes.items():
|
||||
|
@ -10,6 +10,8 @@ from _pyngraph import NodeFactory as _NodeFactory
|
||||
|
||||
from ngraph.impl import Node, Output
|
||||
|
||||
from ngraph.exceptions import UserInputError
|
||||
|
||||
DEFAULT_OPSET = "opset8"
|
||||
|
||||
|
||||
@ -26,7 +28,7 @@ class NodeFactory(object):
|
||||
def create(
|
||||
self,
|
||||
op_type_name: str,
|
||||
arguments: List[Union[Node, Output]],
|
||||
arguments: Optional[List[Union[Node, Output]]] = None,
|
||||
attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> Node:
|
||||
"""Create node object from provided description.
|
||||
@ -39,9 +41,24 @@ class NodeFactory(object):
|
||||
|
||||
@return Node object representing requested operator with attributes set.
|
||||
"""
|
||||
if arguments is None and attributes is None:
|
||||
node = self.factory.create(op_type_name)
|
||||
node._attr_cache = {}
|
||||
node._attr_cache_valid = False
|
||||
return node
|
||||
|
||||
if arguments is None and attributes is not None:
|
||||
raise UserInputError(
|
||||
'Error: cannot create "{}" op without arguments.'.format(
|
||||
op_type_name
|
||||
)
|
||||
)
|
||||
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
|
||||
assert arguments is not None
|
||||
|
||||
arguments = self._arguments_as_outputs(arguments)
|
||||
node = self.factory.create(op_type_name, arguments, attributes)
|
||||
|
||||
@ -57,7 +74,7 @@ class NodeFactory(object):
|
||||
# node.get_some_metric_attr_name()
|
||||
# node.set_some_metric_attr_name()
|
||||
# Please see test_dyn_attributes.py for more usage examples.
|
||||
all_attributes = node._get_attributes()
|
||||
all_attributes = node.get_attributes()
|
||||
for attr_name in all_attributes.keys():
|
||||
setattr(
|
||||
node,
|
||||
@ -134,7 +151,7 @@ class NodeFactory(object):
|
||||
@return The node attribute value.
|
||||
"""
|
||||
if not node._attr_cache_valid:
|
||||
node._attr_cache = node._get_attributes()
|
||||
node._attr_cache = node.get_attributes()
|
||||
node._attr_cache_valid = True
|
||||
return node._attr_cache[attr_name]
|
||||
|
||||
@ -146,5 +163,5 @@ class NodeFactory(object):
|
||||
@param attr_name: The attribute name.
|
||||
@param value: The new attribute value.
|
||||
"""
|
||||
node._set_attribute(attr_name, value)
|
||||
node.set_attribute(attr_name, value)
|
||||
node._attr_cache[attr_name] = value
|
||||
|
@ -265,15 +265,21 @@ void regclass_pyngraph_Node(py::module m) {
|
||||
node.def_property_readonly("version", &ngraph::Node::get_version);
|
||||
node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
||||
|
||||
node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
|
||||
node.def("get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
|
||||
util::DictAttributeSerializer dict_serializer(self);
|
||||
return dict_serializer.get_attributes();
|
||||
});
|
||||
node.def("_set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
|
||||
node.def("set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
|
||||
py::dict attr_dict;
|
||||
attr_dict[atr_name.c_str()] = value;
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
|
||||
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
|
||||
self->visit_attributes(dict_deserializer);
|
||||
});
|
||||
node.def("set_arguments", [](const std::shared_ptr<ngraph::Node>& self, const ngraph::OutputVector& arguments) {
|
||||
return self->set_arguments(arguments);
|
||||
});
|
||||
node.def("validate", [](const std::shared_ptr<ngraph::Node>& self) {
|
||||
return self->constructor_validate_and_infer_types();
|
||||
});
|
||||
}
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "dict_attribute_visitor.hpp"
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/op/util/variable.hpp"
|
||||
@ -51,6 +52,19 @@ public:
|
||||
return op_node;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> create(const std::string op_type_name) {
|
||||
std::shared_ptr<ngraph::Node> op_node = std::shared_ptr<ngraph::Node>(m_opset.create(op_type_name));
|
||||
|
||||
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
|
||||
NGRAPH_CHECK(!ngraph::op::is_constant(op_node),
|
||||
"Currently NodeFactory doesn't support Constant node: ",
|
||||
op_type_name);
|
||||
|
||||
NGRAPH_WARN << "Empty op created! Please assign inputs and attributes and run validate() before op is used.";
|
||||
|
||||
return op_node;
|
||||
}
|
||||
|
||||
private:
|
||||
const ngraph::OpSet& get_opset(std::string opset_ver) {
|
||||
std::locale loc;
|
||||
@ -90,7 +104,16 @@ void regclass_pyngraph_NodeFactory(py::module m) {
|
||||
node_factory.def(py::init());
|
||||
node_factory.def(py::init<std::string>());
|
||||
|
||||
node_factory.def("create", &NodeFactory::create);
|
||||
node_factory.def("create", [](NodeFactory& self, const std::string name) {
|
||||
return self.create(name);
|
||||
});
|
||||
node_factory.def("create",
|
||||
[](NodeFactory& self,
|
||||
const std::string name,
|
||||
const ngraph::OutputVector& arguments,
|
||||
const py::dict& attributes) {
|
||||
return self.create(name, arguments, attributes);
|
||||
});
|
||||
|
||||
node_factory.def("__repr__", [](const NodeFactory& self) {
|
||||
return "<NodeFactory>";
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import ngraph as ng
|
||||
from ngraph.exceptions import UserInputError
|
||||
from ngraph.utils.node_factory import NodeFactory
|
||||
from _pyngraph import NodeFactory as _NodeFactory
|
||||
|
||||
@ -45,7 +46,49 @@ def test_node_factory_topk():
|
||||
node = factory.create(
|
||||
"TopK", arguments, {"axis": 1, "mode": "max", "sort": "value"}
|
||||
)
|
||||
attributes = node.get_attributes()
|
||||
|
||||
assert node.get_type_name() == "TopK"
|
||||
assert node.get_output_size() == 2
|
||||
assert list(node.get_output_shape(0)) == [2, 3]
|
||||
assert attributes["axis"] == 1
|
||||
assert attributes["mode"] == "max"
|
||||
assert attributes["sort"] == "value"
|
||||
|
||||
|
||||
def test_node_factory_empty_topk():
|
||||
factory = NodeFactory("opset1")
|
||||
node = factory.create("TopK")
|
||||
|
||||
assert node.get_type_name() == "TopK"
|
||||
|
||||
|
||||
def test_node_factory_empty_topk_with_args_and_attrs():
|
||||
dtype = np.int32
|
||||
data = ng.parameter([2, 10], dtype=dtype, name="A")
|
||||
k = ng.constant(3, dtype=dtype, name="B")
|
||||
factory = NodeFactory("opset1")
|
||||
arguments = NodeFactory._arguments_as_outputs([data, k])
|
||||
node = factory.create("TopK", None, None)
|
||||
node.set_arguments(arguments)
|
||||
node.set_attribute("axis", 1)
|
||||
node.set_attribute("mode", "max")
|
||||
node.set_attribute("sort", "value")
|
||||
node.validate()
|
||||
|
||||
assert node.get_type_name() == "TopK"
|
||||
assert node.get_output_size() == 2
|
||||
assert list(node.get_output_shape(0)) == [2, 3]
|
||||
|
||||
|
||||
def test_node_factory_validate_missing_arguments():
|
||||
factory = NodeFactory("opset1")
|
||||
|
||||
try:
|
||||
factory.create(
|
||||
"TopK", None, {"axis": 1, "mode": "max", "sort": "value"}
|
||||
)
|
||||
except UserInputError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Validation of missing arguments has unexpectedly passed.")
|
||||
|
Loading…
Reference in New Issue
Block a user