Dynamic attribute getters and setters. (#964)
This commit is contained in:
parent
5aa9ffbfe3
commit
d0be6b1d2f
@ -182,6 +182,7 @@ sources = [
|
||||
"pyngraph/axis_vector.cpp",
|
||||
"pyngraph/coordinate.cpp",
|
||||
"pyngraph/coordinate_diff.cpp",
|
||||
"pyngraph/dict_attribute_visitor.cpp",
|
||||
"pyngraph/dimension.cpp",
|
||||
"pyngraph/function.cpp",
|
||||
"pyngraph/node.cpp",
|
||||
|
@ -1,3 +1,4 @@
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from _pyngraph import NodeFactory as _NodeFactory
|
||||
@ -21,6 +22,8 @@ class NodeFactory(object):
|
||||
) -> Node:
|
||||
"""Create node object from provided description.
|
||||
|
||||
The user does not have to provide all node's attributes, but only required ones.
|
||||
|
||||
:param op_type_name: The operator type name.
|
||||
:param arguments: The operator arguments.
|
||||
:param attributes: The operator attributes.
|
||||
@ -30,4 +33,90 @@ class NodeFactory(object):
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
node = self.factory.create(op_type_name, arguments, attributes)
|
||||
|
||||
# Currently we don't support any attribute getters & setters for TensorIterator node.
|
||||
if node.get_type_name() == "TensorIterator":
|
||||
return node
|
||||
|
||||
# Set getters and setters for each node's attribute.
|
||||
# node.get_attribute_name()
|
||||
# node.set_attribute_name()
|
||||
# For compound (with more than one level of nesting) attributes of form ie.:
|
||||
# node.class_member_name.some_metric.attr_name:
|
||||
# node.get_some_metric_attr_name()
|
||||
# node.set_some_metric_attr_name()
|
||||
# Please see test_dyn_attributes.py for more usage examples.
|
||||
all_attributes = node._get_attributes()
|
||||
for attr_name in all_attributes.keys():
|
||||
setattr(node,
|
||||
self._normalize_attr_name_getter(attr_name),
|
||||
partial(NodeFactory._get_node_attr_value, node, attr_name))
|
||||
setattr(node,
|
||||
self._normalize_attr_name_setter(attr_name),
|
||||
partial(NodeFactory._set_node_attr_value, node, attr_name))
|
||||
|
||||
# Setup helper members for caching attribute values.
|
||||
# The cache would be lazily populated at first access attempt.
|
||||
setattr(node, "_attr_cache", dict())
|
||||
setattr(node, "_attr_cache_valid", bool(False))
|
||||
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def _normalize_attr_name(attr_name: str, prefix: str) -> str:
|
||||
"""Normalizes attribute name.
|
||||
|
||||
:param attr_name: The attribute name.
|
||||
:param prefix: The prefix to attach to attribute name.
|
||||
|
||||
:returns: The modified attribute name.
|
||||
"""
|
||||
# Trim first part of the name if there is only one level of attribute hierarchy.
|
||||
if attr_name.count(".") == 1:
|
||||
attr_name = attr_name[attr_name.find(".") + 1:]
|
||||
return prefix + attr_name.replace(".", "_")
|
||||
|
||||
@classmethod
|
||||
def _normalize_attr_name_getter(cls, attr_name: str) -> str:
|
||||
"""Normalizes atr name to be suitable for getter function name.
|
||||
|
||||
:param attr_name: The attribute name to normalize
|
||||
|
||||
:returns: The appropriate getter function name.
|
||||
"""
|
||||
return cls._normalize_attr_name(attr_name, "get_")
|
||||
|
||||
@classmethod
|
||||
def _normalize_attr_name_setter(cls, attr_name: str) -> str:
|
||||
"""Normalizes atr name to be suitable for setter function name.
|
||||
|
||||
:param attr_name: The attribute name to normalize
|
||||
|
||||
:returns: The appropriate setter function name.
|
||||
"""
|
||||
return cls._normalize_attr_name(attr_name, "set_")
|
||||
|
||||
@staticmethod
|
||||
def _get_node_attr_value(node: Node, attr_name: str) -> Any:
|
||||
"""Gets provided node attribute value.
|
||||
|
||||
:param node: The node we retrieve attribute value from.
|
||||
:param attr_name: The attribute name.
|
||||
|
||||
:returns: The node attribute value.
|
||||
"""
|
||||
if not node._attr_cache_valid:
|
||||
node._attr_cache = node._get_attributes()
|
||||
node._attr_cache_valid = True
|
||||
return node._attr_cache[attr_name]
|
||||
|
||||
@staticmethod
|
||||
def _set_node_attr_value(node: Node, attr_name: str, value: Any) -> None:
|
||||
"""Sets the node attribute value.
|
||||
|
||||
:param node: The node we change attribute value for.
|
||||
:param attr_name: The attribute name.
|
||||
:param value: The new attribute value.
|
||||
"""
|
||||
node._set_attribute(attr_name, value)
|
||||
node._attr_cache[attr_name] = value
|
||||
|
351
ngraph/python/src/pyngraph/dict_attribute_visitor.cpp
Normal file
351
ngraph/python/src/pyngraph/dict_attribute_visitor.cpp
Normal file
@ -0,0 +1,351 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
// These are not used here, but needed in order to not violate ODR, since
|
||||
// these are included in other translation units, and specialize some types.
|
||||
// Related: https://github.com/pybind/pybind11/issues/1055
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "dict_attribute_visitor.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
util::DictAttributeDeserializer::DictAttributeDeserializer(const py::dict& attributes)
|
||||
: m_attributes(attributes)
|
||||
{
|
||||
}
|
||||
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<void>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<bool>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<bool>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::string>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int8_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int8_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int16_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int16_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int32_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int32_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int64_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int64_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint8_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint16_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint32_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint64_t>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<float>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<float>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<double>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<double>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<double>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
|
||||
}
|
||||
}
|
||||
|
||||
util::DictAttributeSerializer::DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
node->visit_attributes(*this);
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<void>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
|
||||
}
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<bool>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int8_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int16_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int32_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int64_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint8_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint16_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint32_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint64_t>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<float>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<double>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<uint16_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<uint32_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
||||
void util::DictAttributeSerializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter)
|
||||
{
|
||||
m_attributes[name.c_str()] = adapter.get();
|
||||
}
|
158
ngraph/python/src/pyngraph/dict_attribute_visitor.hpp
Normal file
158
ngraph/python/src/pyngraph/dict_attribute_visitor.hpp
Normal file
@ -0,0 +1,158 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace util
|
||||
{
|
||||
class DictAttributeDeserializer : public ngraph::AttributeVisitor
|
||||
{
|
||||
public:
|
||||
DictAttributeDeserializer(const py::dict& attributes);
|
||||
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<void>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<bool>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<float>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<double>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
|
||||
|
||||
protected:
|
||||
const py::dict& m_attributes;
|
||||
};
|
||||
|
||||
class DictAttributeSerializer : public ngraph::AttributeVisitor
|
||||
{
|
||||
public:
|
||||
DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node);
|
||||
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<void>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<bool>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<float>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<double>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
|
||||
|
||||
template <typename T>
|
||||
T get_attribute(const std::string& name)
|
||||
{
|
||||
NGRAPH_CHECK(m_attributes.contains(name),
|
||||
"Couldn't find attribute \"",
|
||||
name,
|
||||
"\" in serialized node attribute dictionary.");
|
||||
return m_attributes[name.c_str()].cast<T>();
|
||||
}
|
||||
|
||||
py::dict get_attributes() const { return m_attributes; }
|
||||
protected:
|
||||
py::dict m_attributes;
|
||||
};
|
||||
}
|
@ -14,21 +14,21 @@
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "ngraph/node.hpp" // ngraph::Node
|
||||
#include "ngraph/op/add.hpp" // ngraph::op::Add
|
||||
#include "ngraph/op/divide.hpp" // ngraph::op::Divide
|
||||
#include "ngraph/op/multiply.hpp" // ngraph::op::Multiply
|
||||
#include "ngraph/op/subtract.hpp" // ngraph::op::Subtract
|
||||
#include "dict_attribute_visitor.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "pyngraph/node.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_pyngraph_Node(py::module m)
|
||||
{
|
||||
py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node");
|
||||
py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node", py::dynamic_attr());
|
||||
node.doc() = "ngraph.impl.Node wraps ngraph::Node";
|
||||
node.def("__add__",
|
||||
[](const std::shared_ptr<ngraph::Node>& a, const std::shared_ptr<ngraph::Node> b) {
|
||||
@ -79,4 +79,18 @@ void regclass_pyngraph_Node(py::module m)
|
||||
node.def("get_unique_name", &ngraph::Node::get_name);
|
||||
|
||||
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
||||
node.def_property_readonly("shape", &ngraph::Node::get_shape);
|
||||
|
||||
node.def("_get_attributes", [](const std::shared_ptr<ngraph::Node>& self) {
|
||||
util::DictAttributeSerializer dict_serializer(self);
|
||||
return dict_serializer.get_attributes();
|
||||
});
|
||||
node.def(
|
||||
"_set_attribute",
|
||||
[](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
|
||||
py::dict attr_dict;
|
||||
attr_dict[atr_name.c_str()] = value;
|
||||
util::DictAttributeDeserializer dict_deserializer(attr_dict);
|
||||
self->visit_attributes(dict_deserializer);
|
||||
});
|
||||
}
|
||||
|
@ -26,14 +26,11 @@
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "dict_attribute_visitor.hpp"
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/enum_names.hpp"
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/opsets/opset.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "node_factory.hpp"
|
||||
#include "tensor_iterator_builder.hpp"
|
||||
|
||||
@ -41,212 +38,6 @@ namespace py = pybind11;
|
||||
|
||||
namespace
|
||||
{
|
||||
class DictAttributeDeserializer : public ngraph::AttributeVisitor
|
||||
{
|
||||
public:
|
||||
DictAttributeDeserializer(const py::dict& attributes)
|
||||
: m_attributes(attributes)
|
||||
{
|
||||
}
|
||||
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<void>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
NGRAPH_CHECK(
|
||||
false, "No AttributeVisitor support for accessing attribute named: ", name);
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<bool>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<bool>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::string>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int8_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int8_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int16_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int16_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int32_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int32_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int64_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<int64_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint8_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint8_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint16_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint16_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint32_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint32_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint64_t>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<uint64_t>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<float>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<float>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<double>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<double>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<std::string>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int8_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int16_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int32_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<int64_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint8_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint16_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint32_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<uint64_t>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<float>>());
|
||||
}
|
||||
}
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter) override
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
adapter.set(m_attributes[name.c_str()].cast<std::vector<double>>());
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
const py::dict& m_attributes;
|
||||
};
|
||||
|
||||
class NodeFactory
|
||||
{
|
||||
public:
|
||||
@ -270,12 +61,12 @@ namespace
|
||||
|
||||
if (op_type_name == "TensorIterator")
|
||||
{
|
||||
// TODO: how to differentiate opsets?
|
||||
// XXX: How to differentiate opsets?
|
||||
return util::TensorIteratorBuilder(arguments, attributes)
|
||||
.configure(std::static_pointer_cast<ngraph::op::TensorIterator>(op_node));
|
||||
}
|
||||
|
||||
DictAttributeDeserializer visitor(attributes);
|
||||
util::DictAttributeDeserializer visitor(attributes);
|
||||
|
||||
op_node->set_arguments(arguments);
|
||||
op_node->visit_attributes(visitor);
|
||||
|
@ -508,7 +508,7 @@ def test_roi_pooling():
|
||||
node = ng.roi_pooling(inputs, coords, [6, 6], 0.0625, "Max")
|
||||
|
||||
assert node.get_type_name() == "ROIPooling"
|
||||
assert node.get_output_size() == 1
|
||||
assert node.get_output_size() == [6, 6]
|
||||
assert list(node.get_output_shape(0)) == [150, 3, 6, 6]
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
|
||||
|
241
ngraph/python/test/ngraph/test_dyn_attributes.py
Normal file
241
ngraph/python/test/ngraph/test_dyn_attributes.py
Normal file
@ -0,0 +1,241 @@
|
||||
# ******************************************************************************
|
||||
# Copyright 2017-2020 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ******************************************************************************
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import ngraph as ng
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _proposal_node():
|
||||
attributes = {
|
||||
"attrs.base_size": np.uint16(1),
|
||||
"attrs.pre_nms_topn": np.uint16(20),
|
||||
"attrs.post_nms_topn": np.uint16(64),
|
||||
"attrs.nms_thresh": np.float64(0.34),
|
||||
"attrs.feat_stride": np.uint16(16),
|
||||
"attrs.min_size": np.uint16(32),
|
||||
"attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=np.float64),
|
||||
"attrs.scale": np.array([2, 3, 3, 4], dtype=np.float64),
|
||||
}
|
||||
batch_size = 7
|
||||
|
||||
class_probs = ng.parameter([batch_size, 12, 34, 62], np.float64, "class_probs")
|
||||
class_logits = ng.parameter([batch_size, 24, 34, 62], np.float64, "class_logits")
|
||||
image_shape = ng.parameter([3], np.float64, "image_shape")
|
||||
return ng.proposal(class_probs, class_logits, image_shape, attributes)
|
||||
|
||||
|
||||
def test_dynamic_attributes_softmax():
|
||||
axis = 2
|
||||
data = ng.parameter([1, 2, 3, 4], np.float32, "data_in")
|
||||
node = ng.softmax(data, axis)
|
||||
|
||||
assert node.get_axis() == axis
|
||||
node.set_axis(3)
|
||||
assert node.get_axis() == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"int_dtype, fp_dtype",
|
||||
[
|
||||
(np.int8, np.float32),
|
||||
(np.int16, np.float32),
|
||||
(np.int32, np.float32),
|
||||
(np.int64, np.float32),
|
||||
(np.uint8, np.float32),
|
||||
(np.uint16, np.float32),
|
||||
(np.uint32, np.float32),
|
||||
(np.uint64, np.float32),
|
||||
(np.int32, np.float16),
|
||||
(np.int32, np.float64),
|
||||
],
|
||||
)
|
||||
def test_dynamic_get_attribute_value(int_dtype, fp_dtype):
|
||||
attributes = {
|
||||
"attrs.num_classes": int_dtype(85),
|
||||
"attrs.background_label_id": int_dtype(13),
|
||||
"attrs.top_k": int_dtype(16),
|
||||
"attrs.variance_encoded_in_target": True,
|
||||
"attrs.keep_top_k": np.array([64, 32, 16, 8], dtype=int_dtype),
|
||||
"attrs.code_type": "pytorch.some_parameter_name",
|
||||
"attrs.share_location": False,
|
||||
"attrs.nms_threshold": fp_dtype(0.645),
|
||||
"attrs.confidence_threshold": fp_dtype(0.111),
|
||||
"attrs.clip_after_nms": True,
|
||||
"attrs.clip_before_nms": False,
|
||||
"attrs.decrease_label_id": True,
|
||||
"attrs.normalized": True,
|
||||
"attrs.input_height": int_dtype(86),
|
||||
"attrs.input_width": int_dtype(79),
|
||||
"attrs.objectness_score": fp_dtype(0.77),
|
||||
}
|
||||
|
||||
box_logits = ng.parameter([4, 1, 5, 5], fp_dtype, "box_logits")
|
||||
class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "class_preds")
|
||||
proposals = ng.parameter([2, 1, 4, 5], fp_dtype, "proposals")
|
||||
aux_class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_class_preds")
|
||||
aux_box_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_box_preds")
|
||||
|
||||
node = ng.detection_output(
|
||||
box_logits, class_preds, proposals, attributes, aux_class_preds, aux_box_preds
|
||||
)
|
||||
|
||||
assert node.get_num_classes() == int_dtype(85)
|
||||
assert node.get_background_label_id() == int_dtype(13)
|
||||
assert node.get_top_k() == int_dtype(16)
|
||||
assert node.get_variance_encoded_in_target() == True
|
||||
assert np.all(np.equal(node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype)))
|
||||
assert node.get_code_type() == "pytorch.some_parameter_name"
|
||||
assert node.get_share_location() == False
|
||||
assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645))
|
||||
assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111))
|
||||
assert node.get_clip_after_nms() == True
|
||||
assert node.get_clip_before_nms() == False
|
||||
assert node.get_decrease_label_id() == True
|
||||
assert node.get_normalized() == True
|
||||
assert node.get_input_height() == int_dtype(86)
|
||||
assert node.get_input_width() == int_dtype(79)
|
||||
assert np.isclose(node.get_objectness_score(), fp_dtype(0.77))
|
||||
assert node.get_num_classes() == int_dtype(85)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"int_dtype, fp_dtype",
|
||||
[
|
||||
(np.uint8, np.float32),
|
||||
(np.uint16, np.float32),
|
||||
(np.uint32, np.float32),
|
||||
(np.uint64, np.float32),
|
||||
(np.uint32, np.float16),
|
||||
(np.uint32, np.float64),
|
||||
],
|
||||
)
|
||||
def test_dynamic_set_attribute_value(int_dtype, fp_dtype):
|
||||
attributes = {
|
||||
"attrs.base_size": int_dtype(1),
|
||||
"attrs.pre_nms_topn": int_dtype(20),
|
||||
"attrs.post_nms_topn": int_dtype(64),
|
||||
"attrs.nms_thresh": fp_dtype(0.34),
|
||||
"attrs.feat_stride": int_dtype(16),
|
||||
"attrs.min_size": int_dtype(32),
|
||||
"attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=fp_dtype),
|
||||
"attrs.scale": np.array([2, 3, 3, 4], dtype=fp_dtype),
|
||||
}
|
||||
batch_size = 7
|
||||
|
||||
class_probs = ng.parameter([batch_size, 12, 34, 62], fp_dtype, "class_probs")
|
||||
class_logits = ng.parameter([batch_size, 24, 34, 62], fp_dtype, "class_logits")
|
||||
image_shape = ng.parameter([3], fp_dtype, "image_shape")
|
||||
node = ng.proposal(class_probs, class_logits, image_shape, attributes)
|
||||
|
||||
node.set_base_size(int_dtype(15))
|
||||
node.set_pre_nms_topn(int_dtype(7))
|
||||
node.set_post_nms_topn(int_dtype(33))
|
||||
node.set_nms_thresh(fp_dtype(1.55))
|
||||
node.set_feat_stride(int_dtype(8))
|
||||
node.set_min_size(int_dtype(123))
|
||||
node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
|
||||
node.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
|
||||
node.set_clip_before_nms(True)
|
||||
node.set_clip_after_nms(True)
|
||||
node.set_normalize(True)
|
||||
node.set_box_size_scale(fp_dtype(1.34))
|
||||
node.set_box_coordinate_scale(fp_dtype(0.88))
|
||||
node.set_framework("OpenVINO")
|
||||
|
||||
assert node.get_base_size() == int_dtype(15)
|
||||
assert node.get_pre_nms_topn() == int_dtype(7)
|
||||
assert node.get_post_nms_topn() == int_dtype(33)
|
||||
assert np.isclose(node.get_nms_thresh(), fp_dtype(1.55))
|
||||
assert node.get_feat_stride() == int_dtype(8)
|
||||
assert node.get_min_size() == int_dtype(123)
|
||||
assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype))
|
||||
assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype))
|
||||
assert node.get_clip_before_nms() == True
|
||||
assert node.get_clip_after_nms() == True
|
||||
assert node.get_normalize() == True
|
||||
assert np.isclose(node.get_box_size_scale(), fp_dtype(1.34))
|
||||
assert np.isclose(node.get_box_coordinate_scale(), fp_dtype(0.88))
|
||||
assert node.get_framework() == "OpenVINO"
|
||||
|
||||
|
||||
def test_dynamic_attr_cache(_proposal_node):
|
||||
node = _proposal_node
|
||||
|
||||
assert not node._attr_cache_valid
|
||||
node.set_nms_thresh(1.3453678102)
|
||||
assert not node._attr_cache_valid
|
||||
assert np.isclose(node.get_nms_thresh(), np.float64(1.3453678102))
|
||||
assert node._attr_cache_valid
|
||||
|
||||
|
||||
def test_dynamic_attr_transitivity(_proposal_node):
|
||||
node = _proposal_node
|
||||
node2 = node
|
||||
|
||||
node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
|
||||
assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
|
||||
assert np.allclose(node2.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64))
|
||||
|
||||
node2.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
|
||||
assert np.allclose(node2.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
|
||||
assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64))
|
||||
|
||||
|
||||
def test_dynamic_attributes_simple():
|
||||
batch_size = 1
|
||||
input_size = 16
|
||||
hidden_size = 128
|
||||
|
||||
X_shape = [batch_size, input_size]
|
||||
H_t_shape = [batch_size, hidden_size]
|
||||
W_shape = [3 * hidden_size, input_size]
|
||||
R_shape = [3 * hidden_size, hidden_size]
|
||||
B_shape = [4 * hidden_size]
|
||||
|
||||
parameter_X = ng.parameter(X_shape, name="X", dtype=np.float32)
|
||||
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=np.float32)
|
||||
parameter_W = ng.parameter(W_shape, name="W", dtype=np.float32)
|
||||
parameter_R = ng.parameter(R_shape, name="R", dtype=np.float32)
|
||||
parameter_B = ng.parameter(B_shape, name="B", dtype=np.float32)
|
||||
|
||||
activations = ["tanh", "relu"]
|
||||
activations_alpha = [1.0, 2.0]
|
||||
activations_beta = [1.0, 2.0]
|
||||
clip = 0.5
|
||||
linear_before_reset = True
|
||||
|
||||
node = ng.gru_cell(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_W,
|
||||
parameter_R,
|
||||
parameter_B,
|
||||
hidden_size,
|
||||
activations,
|
||||
activations_alpha,
|
||||
activations_beta,
|
||||
clip,
|
||||
linear_before_reset,
|
||||
)
|
||||
|
||||
assert node.get_hidden_size() == hidden_size
|
||||
assert all(map(lambda x, y: x == y, node.get_activations(), activations))
|
||||
assert all(np.equal(node.get_activations_alpha(), activations_alpha))
|
||||
assert all(np.equal(node.get_activations_beta(), activations_beta))
|
||||
assert node.get_linear_before_reset() == linear_before_reset
|
||||
assert np.isclose(node.get_clip(), clip)
|
Loading…
Reference in New Issue
Block a user