[nGraph] RTInfo refactor (#1806)

This commit is contained in:
Jan Iwaszkiewicz 2020-08-27 15:41:21 +02:00 committed by GitHub
parent 038d5d8f22
commit c147f03d5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 115 additions and 63 deletions

View File

@ -212,6 +212,7 @@ sources = [
"pyngraph/types/regmodule_pyngraph_types.cpp",
"pyngraph/util.cpp",
"pyngraph/variant.cpp",
"pyngraph/rt_map.cpp",
]
packages = [

View File

@ -23,7 +23,6 @@ try:
except DistributionNotFound:
__version__ = "0.0.0.dev0"
import ngraph.utils.rt_map
from ngraph.impl import Node
from ngraph.helpers import function_from_cnn

View File

@ -45,7 +45,5 @@ from _pyngraph import CoordinateDiff
from _pyngraph import AxisSet
from _pyngraph import AxisVector
from _pyngraph import Coordinate
from _pyngraph import VariantString
from _pyngraph import VariantInt
from _pyngraph import util

View File

@ -1,52 +0,0 @@
# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""Overrides pybind PyRTMap class methods."""
from typing import Any
import _pyngraph
from _pyngraph import Variant
from ngraph.impl import VariantInt, VariantString
from ngraph.exceptions import UserInputError
def _convert_to_variant(item: Any) -> Variant:
"""Convert value to Variant class, otherwise throw error."""
if isinstance(item, Variant):
return item
variant_mapping = {
int: VariantInt,
str: VariantString,
}
new_type = variant_mapping.get(type(item), None)
if new_type is None:
raise UserInputError("Cannot map value to any of registered Variant classes", str(item))
return new_type(item)
binding_copy = _pyngraph.PyRTMap.__setitem__
def _setitem(self: _pyngraph.PyRTMap, arg0: str, arg1: Any) -> None:
"""Override setting values in dictionary."""
binding_copy(self, arg0, _convert_to_variant(arg1))
_pyngraph.PyRTMap.__setitem__ = _setitem

View File

@ -26,6 +26,7 @@
#include "ngraph/op/subtract.hpp"
#include "ngraph/variant.hpp"
#include "pyngraph/node.hpp"
#include "pyngraph/rt_map.hpp"
#include "pyngraph/variant.hpp"
namespace py = pybind11;
@ -36,11 +37,6 @@ PYBIND11_MAKE_OPAQUE(PyRTMap);
void regclass_pyngraph_Node(py::module m)
{
auto py_map = py::bind_map<PyRTMap>(m, "PyRTMap");
py_map.doc() =
"ngraph.impl.PyRTMap makes bindings for std::map<std::string, "
"std::shared_ptr<ngraph::Variant>>, which can later be used as ngraph::Node::RTMap";
py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node", py::dynamic_attr());
node.doc() = "ngraph.impl.Node wraps ngraph::Node";
node.def("__add__",

View File

@ -34,6 +34,7 @@
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
#include "pyngraph/partial_shape.hpp"
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
#include "pyngraph/rt_map.hpp"
#include "pyngraph/shape.hpp"
#include "pyngraph/strides.hpp"
#include "pyngraph/types/regmodule_pyngraph_types.hpp"
@ -45,6 +46,7 @@ namespace py = pybind11;
PYBIND11_MODULE(_pyngraph, m)
{
m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph";
regclass_pyngraph_PyRTMap(m);
regclass_pyngraph_Node(m);
regclass_pyngraph_Input(m);
regclass_pyngraph_Output(m);

View File

@ -0,0 +1,62 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include "dict_attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/variant.hpp"
#include "pyngraph/node.hpp"
#include "pyngraph/rt_map.hpp"
#include "pyngraph/variant.hpp"
namespace py = pybind11;
using PyRTMap = std::map<std::string, std::shared_ptr<ngraph::Variant>>;
PYBIND11_MAKE_OPAQUE(PyRTMap);
template <typename T>
void _set_with_variant(PyRTMap& m, const std::string& k, const T v)
{
auto new_v = std::make_shared<ngraph::VariantWrapper<T>>(ngraph::VariantWrapper<T>(v));
auto it = m.find(k);
if (it != m.end())
it->second = new_v;
else
m.emplace(k, new_v);
}
void regclass_pyngraph_PyRTMap(py::module m)
{
auto py_map = py::bind_map<PyRTMap>(m, "PyRTMap");
py_map.doc() =
"ngraph.impl.PyRTMap makes bindings for std::map<std::string, "
"std::shared_ptr<ngraph::Variant>>, which can later be used as ngraph::Node::RTMap";
py_map.def("__setitem__", [](PyRTMap& m, const std::string& k, const std::string v) {
_set_with_variant(m, k, v);
});
py_map.def("__setitem__", [](PyRTMap& m, const std::string& k, const int64_t v) {
_set_with_variant(m, k, v);
});
}

View File

@ -0,0 +1,23 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_PyRTMap(py::module m);

View File

@ -39,6 +39,27 @@ extern void regclass_pyngraph_VariantWrapper(py::module m, std::string typestrin
"ngraph.impl.Variant[typestring] wraps ngraph::VariantWrapper<typestring>";
variant_wrapper.def(py::init<const VT&>());
variant_wrapper.def("__eq__",
[](const ngraph::VariantWrapper<VT>& a,
const ngraph::VariantWrapper<VT>& b) { return a.get() == b.get(); },
py::is_operator());
variant_wrapper.def("__eq__",
[](const ngraph::VariantWrapper<std::string>& a, const std::string& b) {
return a.get() == b;
},
py::is_operator());
variant_wrapper.def(
"__eq__",
[](const ngraph::VariantWrapper<int64_t>& a, const int64_t& b) { return a.get() == b; },
py::is_operator());
variant_wrapper.def("__repr__", [](const ngraph::VariantWrapper<VT> self) {
std::stringstream ret;
ret << self.get();
return ret.str();
});
variant_wrapper.def("get",
(VT & (ngraph::VariantWrapper<VT>::*)()) & ngraph::VariantWrapper<VT>::get);
variant_wrapper.def("set", &ngraph::VariantWrapper<VT>::set);

View File

@ -18,9 +18,11 @@ import json
import numpy as np
import pytest
from _pyngraph import VariantInt, VariantString
import ngraph as ng
from ngraph.exceptions import UserInputError
from ngraph.impl import Function, PartialShape, Shape, Type, VariantInt, VariantString
from ngraph.impl import Function, PartialShape, Shape, Type
from ngraph.impl.op import Parameter
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
@ -408,7 +410,7 @@ def test_variants():
def test_runtime_info():
test_shape = PartialShape([1, 3, 22, 22])
test_shape = PartialShape([1, 1, 1, 1])
test_type = Type.f32
test_param = Parameter(test_type, test_shape)
relu_node = ng.relu(test_param)
@ -417,7 +419,7 @@ def test_runtime_info():
relu_node.set_friendly_name("testReLU")
runtime_info_after = relu_node.get_rt_info()
assert runtime_info == runtime_info_after
assert runtime_info_after["affinity"] == "test_affinity"
params = [test_param]
results = [relu_node]