Dynamic attribute getters and setters. (#964)

This commit is contained in:
Adam Osewski 2020-06-26 16:35:00 +02:00 committed by GitHub
parent 5aa9ffbfe3
commit d0be6b1d2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 865 additions and 220 deletions

View File

@ -182,6 +182,7 @@ sources = [
"pyngraph/axis_vector.cpp",
"pyngraph/coordinate.cpp",
"pyngraph/coordinate_diff.cpp",
"pyngraph/dict_attribute_visitor.cpp",
"pyngraph/dimension.cpp",
"pyngraph/function.cpp",
"pyngraph/node.cpp",

View File

@ -1,3 +1,4 @@
from functools import partial
from typing import Any, Dict, List, Optional
from _pyngraph import NodeFactory as _NodeFactory
@ -21,6 +22,8 @@ class NodeFactory(object):
) -> Node:
"""Create node object from provided description.
The user does not have to provide all node's attributes, but only required ones.
:param op_type_name: The operator type name.
:param arguments: The operator arguments.
:param attributes: The operator attributes.
@ -30,4 +33,90 @@ class NodeFactory(object):
if attributes is None:
attributes = {}
node = self.factory.create(op_type_name, arguments, attributes)
# Currently we don't support any attribute getters & setters for TensorIterator node.
if node.get_type_name() == "TensorIterator":
return node
# Set getters and setters for each node's attribute.
# node.get_attribute_name()
# node.set_attribute_name()
# For compound (with more than one level of nesting) attributes of form ie.:
# node.class_member_name.some_metric.attr_name:
# node.get_some_metric_attr_name()
# node.set_some_metric_attr_name()
# Please see test_dyn_attributes.py for more usage examples.
all_attributes = node._get_attributes()
for attr_name in all_attributes.keys():
setattr(node,
self._normalize_attr_name_getter(attr_name),
partial(NodeFactory._get_node_attr_value, node, attr_name))
setattr(node,
self._normalize_attr_name_setter(attr_name),
partial(NodeFactory._set_node_attr_value, node, attr_name))
# Setup helper members for caching attribute values.
# The cache would be lazily populated at first access attempt.
setattr(node, "_attr_cache", dict())
setattr(node, "_attr_cache_valid", bool(False))
return node
@staticmethod
def _normalize_attr_name(attr_name: str, prefix: str) -> str:
"""Normalizes attribute name.
:param attr_name: The attribute name.
:param prefix: The prefix to attach to attribute name.
:returns: The modified attribute name.
"""
# Trim first part of the name if there is only one level of attribute hierarchy.
if attr_name.count(".") == 1:
attr_name = attr_name[attr_name.find(".") + 1:]
return prefix + attr_name.replace(".", "_")
@classmethod
def _normalize_attr_name_getter(cls, attr_name: str) -> str:
"""Normalizes atr name to be suitable for getter function name.
:param attr_name: The attribute name to normalize
:returns: The appropriate getter function name.
"""
return cls._normalize_attr_name(attr_name, "get_")
@classmethod
def _normalize_attr_name_setter(cls, attr_name: str) -> str:
"""Normalizes atr name to be suitable for setter function name.
:param attr_name: The attribute name to normalize
:returns: The appropriate setter function name.
"""
return cls._normalize_attr_name(attr_name, "set_")
@staticmethod
def _get_node_attr_value(node: Node, attr_name: str) -> Any:
"""Gets provided node attribute value.
:param node: The node we retrieve attribute value from.
:param attr_name: The attribute name.
:returns: The node attribute value.
"""
if not node._attr_cache_valid:
node._attr_cache = node._get_attributes()
node._attr_cache_valid = True
return node._attr_cache[attr_name]
@staticmethod
def _set_node_attr_value(node: Node, attr_name: str, value: Any) -> None:
"""Sets the node attribute value.
:param node: The node we change attribute value for.
:param attr_name: The attribute name.
:param value: The new attribute value.
"""
node._set_attribute(attr_name, value)
node._attr_cache[attr_name] = value

View File

@ -0,0 +1,351 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// These are not used here, but needed in order to not violate ODR, since
// these are included in other translation units, and specialize some types.
// Related: https://github.com/pybind/pybind11/issues/1055
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "dict_attribute_visitor.hpp"
namespace py = pybind11;
util::DictAttributeDeserializer::DictAttributeDeserializer(const py::dict& attributes)
: m_attributes(attributes)
{
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<void>& adapter)
{
if (m_attributes.contains(name))
{
NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<bool>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<bool>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::string>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::string>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int8_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int8_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int16_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int16_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int32_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int32_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int64_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int64_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint8_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint16_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint32_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint64_t>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<float>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<float>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<double>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<double>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
}
}
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<float>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
}
}
void util::DictAttributeDeserializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<double>>& adapter)
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
}
}
util::DictAttributeSerializer::DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node)
{
node->visit_attributes(*this);
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<void>& adapter)
{
if (m_attributes.contains(name))
{
NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
}
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<bool>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::string>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int8_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int16_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int32_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<int64_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint8_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint16_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint32_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<uint64_t>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<float>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<double>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(
const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<float>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}
void util::DictAttributeSerializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<double>>& adapter)
{
m_attributes[name.c_str()] = adapter.get();
}

View File

@ -0,0 +1,158 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstdint>
#include <string>
#include <vector>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace util
{
class DictAttributeDeserializer : public ngraph::AttributeVisitor
{
public:
DictAttributeDeserializer(const py::dict& attributes);
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<void>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<bool>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::string>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int8_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int16_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int32_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int64_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint8_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint16_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint32_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint64_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<float>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<double>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
protected:
const py::dict& m_attributes;
};
class DictAttributeSerializer : public ngraph::AttributeVisitor
{
public:
DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node);
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<void>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<bool>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::string>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int8_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int16_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int32_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int64_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint8_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint16_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint32_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint64_t>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<float>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<double>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
template <typename T>
T get_attribute(const std::string& name)
{
NGRAPH_CHECK(m_attributes.contains(name),
"Couldn't find attribute \"",
name,
"\" in serialized node attribute dictionary.");
return m_attributes[name.c_str()].cast<T>();
}
py::dict get_attributes() const { return m_attributes; }
protected:
py::dict m_attributes;
};
}

View File

@ -14,21 +14,21 @@
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/node.hpp" // ngraph::Node
#include "ngraph/op/add.hpp" // ngraph::op::Add
#include "ngraph/op/divide.hpp" // ngraph::op::Divide
#include "ngraph/op/multiply.hpp" // ngraph::op::Multiply
#include "ngraph/op/subtract.hpp" // ngraph::op::Subtract
#include "dict_attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "pyngraph/node.hpp"
namespace py = pybind11;
void regclass_pyngraph_Node(py::module m)
{
py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node");
py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node", py::dynamic_attr());
node.doc() = "ngraph.impl.Node wraps ngraph::Node";
node.def("__add__",
[](const std::shared_ptr<ngraph::Node>& a, const std::shared_ptr<ngraph::Node> b) {
@ -79,4 +79,18 @@ void regclass_pyngraph_Node(py::module m)
node.def("get_unique_name", &ngraph::Node::get_name);
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
node.def_property_readonly("shape", &ngraph::Node::get_shape);
node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
util::DictAttributeSerializer dict_serializer(self);
return dict_serializer.get_attributes();
});
node.def(
"_set_attribute",
[](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
py::dict attr_dict;
attr_dict[atr_name.c_str()] = value;
util::DictAttributeDeserializer dict_deserializer(attr_dict);
self->visit_attributes(dict_deserializer);
});
}

View File

@ -26,14 +26,11 @@
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "ngraph/attribute_visitor.hpp"
#include "dict_attribute_visitor.hpp"
#include "ngraph/check.hpp"
#include "ngraph/enum_names.hpp"
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/opsets/opset.hpp"
#include "ngraph/util.hpp"
#include "node_factory.hpp"
#include "tensor_iterator_builder.hpp"
@ -41,212 +38,6 @@ namespace py = pybind11;
namespace
{
class DictAttributeDeserializer : public ngraph::AttributeVisitor
{
public:
DictAttributeDeserializer(const py::dict& attributes)
: m_attributes(attributes)
{
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<void>& adapter) override
{
if (m_attributes.contains(name))
{
NGRAPH_CHECK(
false, "No AttributeVisitor support for accessing attribute named: ", name);
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<bool>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<bool>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::string>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::string>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int8_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int8_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int16_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int16_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int32_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int32_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<int64_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<int64_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint8_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint16_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint32_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<uint64_t>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<float>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<float>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<double>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<double>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<float>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
}
}
virtual void on_adapter(const std::string& name,
ngraph::ValueAccessor<std::vector<double>>& adapter) override
{
if (m_attributes.contains(name))
{
adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
}
}
protected:
const py::dict& m_attributes;
};
class NodeFactory
{
public:
@ -270,12 +61,12 @@ namespace
if (op_type_name == "TensorIterator")
{
// TODO: how to differentiate opsets?
// XXX: How to differentiate opsets?
return util::TensorIteratorBuilder(arguments, attributes)
.configure(std::static_pointer_cast<ngraph::op::TensorIterator>(op_node));
}
DictAttributeDeserializer visitor(attributes);
util::DictAttributeDeserializer visitor(attributes);
op_node->set_arguments(arguments);
op_node->visit_attributes(visitor);

View File

@ -508,7 +508,7 @@ def test_roi_pooling():
node = ng.roi_pooling(inputs, coords, [6, 6], 0.0625, "Max")
assert node.get_type_name() == "ROIPooling"
assert node.get_output_size() == 1
assert node.get_output_size() == [6, 6]
assert list(node.get_output_shape(0)) == [150, 3, 6, 6]
assert node.get_output_element_type(0) == Type.f32

View File

@ -0,0 +1,241 @@
# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import pytest
import ngraph as ng
@pytest.fixture()
def _proposal_node():
attributes = {
"attrs.base_size": np.uint16(1),
"attrs.pre_nms_topn": np.uint16(20),
"attrs.post_nms_topn": np.uint16(64),
"attrs.nms_thresh": np.float64(0.34),
"attrs.feat_stride": np.uint16(16),
"attrs.min_size": np.uint16(32),
"attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=np.float64),
"attrs.scale": np.array([2, 3, 3, 4], dtype=np.float64),
}
batch_size = 7
class_probs = ng.parameter([batch_size, 12, 34, 62], np.float64, "class_probs")
class_logits = ng.parameter([batch_size, 24, 34, 62], np.float64, "class_logits")
image_shape = ng.parameter([3], np.float64, "image_shape")
return ng.proposal(class_probs, class_logits, image_shape, attributes)
def test_dynamic_attributes_softmax():
axis = 2
data = ng.parameter([1, 2, 3, 4], np.float32, "data_in")
node = ng.softmax(data, axis)
assert node.get_axis() == axis
node.set_axis(3)
assert node.get_axis() == 3
@pytest.mark.parametrize(
"int_dtype, fp_dtype",
[
(np.int8, np.float32),
(np.int16, np.float32),
(np.int32, np.float32),
(np.int64, np.float32),
(np.uint8, np.float32),
(np.uint16, np.float32),
(np.uint32, np.float32),
(np.uint64, np.float32),
(np.int32, np.float16),
(np.int32, np.float64),
],
)
def test_dynamic_get_attribute_value(int_dtype, fp_dtype):
attributes = {
"attrs.num_classes": int_dtype(85),
"attrs.background_label_id": int_dtype(13),
"attrs.top_k": int_dtype(16),
"attrs.variance_encoded_in_target": True,
"attrs.keep_top_k": np.array([64, 32, 16, 8], dtype=int_dtype),
"attrs.code_type": "pytorch.some_parameter_name",
"attrs.share_location": False,
"attrs.nms_threshold": fp_dtype(0.645),
"attrs.confidence_threshold": fp_dtype(0.111),
"attrs.clip_after_nms": True,
"attrs.clip_before_nms": False,
"attrs.decrease_label_id": True,
"attrs.normalized": True,
"attrs.input_height": int_dtype(86),
"attrs.input_width": int_dtype(79),
"attrs.objectness_score": fp_dtype(0.77),
}
box_logits = ng.parameter([4, 1, 5, 5], fp_dtype, "box_logits")
class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "class_preds")
proposals = ng.parameter([2, 1, 4, 5], fp_dtype, "proposals")
aux_class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_class_preds")
aux_box_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_box_preds")
node = ng.detection_output(
box_logits, class_preds, proposals, attributes, aux_class_preds, aux_box_preds
)
assert node.get_num_classes() == int_dtype(85)
assert node.get_background_label_id() == int_dtype(13)
assert node.get_top_k() == int_dtype(16)
assert node.get_variance_encoded_in_target() == True
assert np.all(np.equal(node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype)))
assert node.get_code_type() == "pytorch.some_parameter_name"
assert node.get_share_location() == False
assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645))
assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111))
assert node.get_clip_after_nms() == True
assert node.get_clip_before_nms() == False
assert node.get_decrease_label_id() == True
assert node.get_normalized() == True
assert node.get_input_height() == int_dtype(86)
assert node.get_input_width() == int_dtype(79)
assert np.isclose(node.get_objectness_score(), fp_dtype(0.77))
assert node.get_num_classes() == int_dtype(85)
@pytest.mark.parametrize(
"int_dtype, fp_dtype",
[
(np.uint8, np.float32),
(np.uint16, np.float32),
(np.uint32, np.float32),
(np.uint64, np.float32),
(np.uint32, np.float16),
(np.uint32, np.float64),
],
)
def test_dynamic_set_attribute_value(int_dtype, fp_dtype):
attributes = {
"attrs.base_size": int_dtype(1),
"attrs.pre_nms_topn": int_dtype(20),
"attrs.post_nms_topn": int_dtype(64),
"attrs.nms_thresh": fp_dtype(0.34),
"attrs.feat_stride": int_dtype(16),
"attrs.min_size": int_dtype(32),
"attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=fp_dtype),
"attrs.scale": np.array([2, 3, 3, 4], dtype=fp_dtype),
}
batch_size = 7
class_probs = ng.parameter([batch_size, 12, 34, 62], fp_dtype, "class_probs")
class_logits = ng.parameter([batch_size, 24, 34, 62], fp_dtype, "class_logits")
image_shape = ng.parameter([3], fp_dtype, "image_shape")
node = ng.proposal(class_probs, class_logits, image_shape, attributes)
node.set_base_size(int_dtype(15))
node.set_pre_nms_topn(int_dtype(7))
node.set_post_nms_topn(int_dtype(33))
node.set_nms_thresh(fp_dtype(1.55))
node.set_feat_stride(int_dtype(8))
node.set_min_size(int_dtype(123))
node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
node.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
node.set_clip_before_nms(True)
node.set_clip_after_nms(True)
node.set_normalize(True)
node.set_box_size_scale(fp_dtype(1.34))
node.set_box_coordinate_scale(fp_dtype(0.88))
node.set_framework("OpenVINO")
assert node.get_base_size() == int_dtype(15)
assert node.get_pre_nms_topn() == int_dtype(7)
assert node.get_post_nms_topn() == int_dtype(33)
assert np.isclose(node.get_nms_thresh(), fp_dtype(1.55))
assert node.get_feat_stride() == int_dtype(8)
assert node.get_min_size() == int_dtype(123)
assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
assert node.get_clip_before_nms() == True
assert node.get_clip_after_nms() == True
assert node.get_normalize() == True
assert np.isclose(node.get_box_size_scale(), fp_dtype(1.34))
assert np.isclose(node.get_box_coordinate_scale(), fp_dtype(0.88))
assert node.get_framework() == "OpenVINO"
def test_dynamic_attr_cache(_proposal_node):
node = _proposal_node
assert not node._attr_cache_valid
node.set_nms_thresh(1.3453678102)
assert not node._attr_cache_valid
assert np.isclose(node.get_nms_thresh(), np.float64(1.3453678102))
assert node._attr_cache_valid
def test_dynamic_attr_transitivity(_proposal_node):
node = _proposal_node
node2 = node
node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
assert np.allclose(node2.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
node2.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
assert np.allclose(node2.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
def test_dynamic_attributes_simple():
batch_size = 1
input_size = 16
hidden_size = 128
X_shape = [batch_size, input_size]
H_t_shape = [batch_size, hidden_size]
W_shape = [3 * hidden_size, input_size]
R_shape = [3 * hidden_size, hidden_size]
B_shape = [4 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=np.float32)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=np.float32)
parameter_W = ng.parameter(W_shape, name="W", dtype=np.float32)
parameter_R = ng.parameter(R_shape, name="R", dtype=np.float32)
parameter_B = ng.parameter(B_shape, name="B", dtype=np.float32)
activations = ["tanh", "relu"]
activations_alpha = [1.0, 2.0]
activations_beta = [1.0, 2.0]
clip = 0.5
linear_before_reset = True
node = ng.gru_cell(
parameter_X,
parameter_H_t,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
activations,
activations_alpha,
activations_beta,
clip,
linear_before_reset,
)
assert node.get_hidden_size() == hidden_size
assert all(map(lambda x, y: x == y, node.get_activations(), activations))
assert all(np.equal(node.get_activations_alpha(), activations_alpha))
assert all(np.equal(node.get_activations_beta(), activations_beta))
assert node.get_linear_before_reset() == linear_before_reset
assert np.isclose(node.get_clip(), clip)