Add default constructor for op. (#7368)
* Add default constructor for op. * Apply correct format. * Apply correct format for py files. * Remove underscore prefix, add create overlaod method. * Resolve use cases for parameter arguments in create fucntion. * Resolve myoy issue. * Remove underscore for getter and setter functions. * Restore check when arguments is None. * Add tests covering raising errors and get_attributes() function.
This commit is contained in:
parent
66bad412a4
commit
c50c0d59bb
@ -169,13 +169,13 @@ You can also set attribute values using corresponding setter methods, for exampl
|
|||||||
node.set_axis(0)
|
node.set_axis(0)
|
||||||
```
|
```
|
||||||
|
|
||||||
Currently, you can get all attributes of a node using the `_get_attributes` method. Please note that this is an internal API method and may change in future versions of OpenVINO.
|
Currently, you can get all attributes of a node using the `get_attributes` method.
|
||||||
|
|
||||||
The following code displays all attributes for all nodes in a function:
|
The following code displays all attributes for all nodes in a function:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
for node in function.get_ordered_ops():
|
for node in function.get_ordered_ops():
|
||||||
attributes = node._get_attributes()
|
attributes = node.get_attributes()
|
||||||
if(attributes):
|
if(attributes):
|
||||||
print('Operation {} of type {} has attributes:'.format(node.get_friendly_name(), node.get_type_name()))
|
print('Operation {} of type {} has attributes:'.format(node.get_friendly_name(), node.get_type_name()))
|
||||||
for attr, value in attributes.items():
|
for attr, value in attributes.items():
|
||||||
|
@ -10,6 +10,8 @@ from _pyngraph import NodeFactory as _NodeFactory
|
|||||||
|
|
||||||
from ngraph.impl import Node, Output
|
from ngraph.impl import Node, Output
|
||||||
|
|
||||||
|
from ngraph.exceptions import UserInputError
|
||||||
|
|
||||||
DEFAULT_OPSET = "opset8"
|
DEFAULT_OPSET = "opset8"
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +28,7 @@ class NodeFactory(object):
|
|||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
op_type_name: str,
|
op_type_name: str,
|
||||||
arguments: List[Union[Node, Output]],
|
arguments: Optional[List[Union[Node, Output]]] = None,
|
||||||
attributes: Optional[Dict[str, Any]] = None,
|
attributes: Optional[Dict[str, Any]] = None,
|
||||||
) -> Node:
|
) -> Node:
|
||||||
"""Create node object from provided description.
|
"""Create node object from provided description.
|
||||||
@ -39,9 +41,24 @@ class NodeFactory(object):
|
|||||||
|
|
||||||
@return Node object representing requested operator with attributes set.
|
@return Node object representing requested operator with attributes set.
|
||||||
"""
|
"""
|
||||||
|
if arguments is None and attributes is None:
|
||||||
|
node = self.factory.create(op_type_name)
|
||||||
|
node._attr_cache = {}
|
||||||
|
node._attr_cache_valid = False
|
||||||
|
return node
|
||||||
|
|
||||||
|
if arguments is None and attributes is not None:
|
||||||
|
raise UserInputError(
|
||||||
|
'Error: cannot create "{}" op without arguments.'.format(
|
||||||
|
op_type_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if attributes is None:
|
if attributes is None:
|
||||||
attributes = {}
|
attributes = {}
|
||||||
|
|
||||||
|
assert arguments is not None
|
||||||
|
|
||||||
arguments = self._arguments_as_outputs(arguments)
|
arguments = self._arguments_as_outputs(arguments)
|
||||||
node = self.factory.create(op_type_name, arguments, attributes)
|
node = self.factory.create(op_type_name, arguments, attributes)
|
||||||
|
|
||||||
@ -57,7 +74,7 @@ class NodeFactory(object):
|
|||||||
# node.get_some_metric_attr_name()
|
# node.get_some_metric_attr_name()
|
||||||
# node.set_some_metric_attr_name()
|
# node.set_some_metric_attr_name()
|
||||||
# Please see test_dyn_attributes.py for more usage examples.
|
# Please see test_dyn_attributes.py for more usage examples.
|
||||||
all_attributes = node._get_attributes()
|
all_attributes = node.get_attributes()
|
||||||
for attr_name in all_attributes.keys():
|
for attr_name in all_attributes.keys():
|
||||||
setattr(
|
setattr(
|
||||||
node,
|
node,
|
||||||
@ -134,7 +151,7 @@ class NodeFactory(object):
|
|||||||
@return The node attribute value.
|
@return The node attribute value.
|
||||||
"""
|
"""
|
||||||
if not node._attr_cache_valid:
|
if not node._attr_cache_valid:
|
||||||
node._attr_cache = node._get_attributes()
|
node._attr_cache = node.get_attributes()
|
||||||
node._attr_cache_valid = True
|
node._attr_cache_valid = True
|
||||||
return node._attr_cache[attr_name]
|
return node._attr_cache[attr_name]
|
||||||
|
|
||||||
@ -146,5 +163,5 @@ class NodeFactory(object):
|
|||||||
@param attr_name: The attribute name.
|
@param attr_name: The attribute name.
|
||||||
@param value: The new attribute value.
|
@param value: The new attribute value.
|
||||||
"""
|
"""
|
||||||
node._set_attribute(attr_name, value)
|
node.set_attribute(attr_name, value)
|
||||||
node._attr_cache[attr_name] = value
|
node._attr_cache[attr_name] = value
|
||||||
|
@ -265,15 +265,21 @@ void regclass_pyngraph_Node(py::module m) {
|
|||||||
node.def_property_readonly("version", &ngraph::Node::get_version);
|
node.def_property_readonly("version", &ngraph::Node::get_version);
|
||||||
node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
node.def_property("friendly_name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
||||||
|
|
||||||
node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
|
node.def("get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
|
||||||
util::DictAttributeSerializer dict_serializer(self);
|
util::DictAttributeSerializer dict_serializer(self);
|
||||||
return dict_serializer.get_attributes();
|
return dict_serializer.get_attributes();
|
||||||
});
|
});
|
||||||
node.def("_set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
|
node.def("set_attribute", [](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
|
||||||
py::dict attr_dict;
|
py::dict attr_dict;
|
||||||
attr_dict[atr_name.c_str()] = value;
|
attr_dict[atr_name.c_str()] = value;
|
||||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
|
||||||
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
|
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
|
||||||
self->visit_attributes(dict_deserializer);
|
self->visit_attributes(dict_deserializer);
|
||||||
});
|
});
|
||||||
|
node.def("set_arguments", [](const std::shared_ptr<ngraph::Node>& self, const ngraph::OutputVector& arguments) {
|
||||||
|
return self->set_arguments(arguments);
|
||||||
|
});
|
||||||
|
node.def("validate", [](const std::shared_ptr<ngraph::Node>& self) {
|
||||||
|
return self->constructor_validate_and_infer_types();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
#include "dict_attribute_visitor.hpp"
|
#include "dict_attribute_visitor.hpp"
|
||||||
#include "ngraph/check.hpp"
|
#include "ngraph/check.hpp"
|
||||||
#include "ngraph/except.hpp"
|
#include "ngraph/except.hpp"
|
||||||
|
#include "ngraph/log.hpp"
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/op/util/op_types.hpp"
|
#include "ngraph/op/util/op_types.hpp"
|
||||||
#include "ngraph/op/util/variable.hpp"
|
#include "ngraph/op/util/variable.hpp"
|
||||||
@ -51,6 +52,19 @@ public:
|
|||||||
return op_node;
|
return op_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> create(const std::string op_type_name) {
|
||||||
|
std::shared_ptr<ngraph::Node> op_node = std::shared_ptr<ngraph::Node>(m_opset.create(op_type_name));
|
||||||
|
|
||||||
|
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
|
||||||
|
NGRAPH_CHECK(!ngraph::op::is_constant(op_node),
|
||||||
|
"Currently NodeFactory doesn't support Constant node: ",
|
||||||
|
op_type_name);
|
||||||
|
|
||||||
|
NGRAPH_WARN << "Empty op created! Please assign inputs and attributes and run validate() before op is used.";
|
||||||
|
|
||||||
|
return op_node;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const ngraph::OpSet& get_opset(std::string opset_ver) {
|
const ngraph::OpSet& get_opset(std::string opset_ver) {
|
||||||
std::locale loc;
|
std::locale loc;
|
||||||
@ -90,7 +104,16 @@ void regclass_pyngraph_NodeFactory(py::module m) {
|
|||||||
node_factory.def(py::init());
|
node_factory.def(py::init());
|
||||||
node_factory.def(py::init<std::string>());
|
node_factory.def(py::init<std::string>());
|
||||||
|
|
||||||
node_factory.def("create", &NodeFactory::create);
|
node_factory.def("create", [](NodeFactory& self, const std::string name) {
|
||||||
|
return self.create(name);
|
||||||
|
});
|
||||||
|
node_factory.def("create",
|
||||||
|
[](NodeFactory& self,
|
||||||
|
const std::string name,
|
||||||
|
const ngraph::OutputVector& arguments,
|
||||||
|
const py::dict& attributes) {
|
||||||
|
return self.create(name, arguments, attributes);
|
||||||
|
});
|
||||||
|
|
||||||
node_factory.def("__repr__", [](const NodeFactory& self) {
|
node_factory.def("__repr__", [](const NodeFactory& self) {
|
||||||
return "<NodeFactory>";
|
return "<NodeFactory>";
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ngraph as ng
|
import ngraph as ng
|
||||||
|
from ngraph.exceptions import UserInputError
|
||||||
from ngraph.utils.node_factory import NodeFactory
|
from ngraph.utils.node_factory import NodeFactory
|
||||||
from _pyngraph import NodeFactory as _NodeFactory
|
from _pyngraph import NodeFactory as _NodeFactory
|
||||||
|
|
||||||
@ -45,7 +46,49 @@ def test_node_factory_topk():
|
|||||||
node = factory.create(
|
node = factory.create(
|
||||||
"TopK", arguments, {"axis": 1, "mode": "max", "sort": "value"}
|
"TopK", arguments, {"axis": 1, "mode": "max", "sort": "value"}
|
||||||
)
|
)
|
||||||
|
attributes = node.get_attributes()
|
||||||
|
|
||||||
assert node.get_type_name() == "TopK"
|
assert node.get_type_name() == "TopK"
|
||||||
assert node.get_output_size() == 2
|
assert node.get_output_size() == 2
|
||||||
assert list(node.get_output_shape(0)) == [2, 3]
|
assert list(node.get_output_shape(0)) == [2, 3]
|
||||||
|
assert attributes["axis"] == 1
|
||||||
|
assert attributes["mode"] == "max"
|
||||||
|
assert attributes["sort"] == "value"
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_factory_empty_topk():
|
||||||
|
factory = NodeFactory("opset1")
|
||||||
|
node = factory.create("TopK")
|
||||||
|
|
||||||
|
assert node.get_type_name() == "TopK"
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_factory_empty_topk_with_args_and_attrs():
|
||||||
|
dtype = np.int32
|
||||||
|
data = ng.parameter([2, 10], dtype=dtype, name="A")
|
||||||
|
k = ng.constant(3, dtype=dtype, name="B")
|
||||||
|
factory = NodeFactory("opset1")
|
||||||
|
arguments = NodeFactory._arguments_as_outputs([data, k])
|
||||||
|
node = factory.create("TopK", None, None)
|
||||||
|
node.set_arguments(arguments)
|
||||||
|
node.set_attribute("axis", 1)
|
||||||
|
node.set_attribute("mode", "max")
|
||||||
|
node.set_attribute("sort", "value")
|
||||||
|
node.validate()
|
||||||
|
|
||||||
|
assert node.get_type_name() == "TopK"
|
||||||
|
assert node.get_output_size() == 2
|
||||||
|
assert list(node.get_output_shape(0)) == [2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_factory_validate_missing_arguments():
|
||||||
|
factory = NodeFactory("opset1")
|
||||||
|
|
||||||
|
try:
|
||||||
|
factory.create(
|
||||||
|
"TopK", None, {"axis": 1, "mode": "max", "sort": "value"}
|
||||||
|
)
|
||||||
|
except UserInputError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise AssertionError("Validation of missing arguments has unexpectedly passed.")
|
||||||
|
Loading…
Reference in New Issue
Block a user