[nGraph] RTInfo refactor (#1806)
This commit is contained in:
parent
038d5d8f22
commit
c147f03d5f
@ -212,6 +212,7 @@ sources = [
|
||||
"pyngraph/types/regmodule_pyngraph_types.cpp",
|
||||
"pyngraph/util.cpp",
|
||||
"pyngraph/variant.cpp",
|
||||
"pyngraph/rt_map.cpp",
|
||||
]
|
||||
|
||||
packages = [
|
||||
|
@ -23,7 +23,6 @@ try:
|
||||
except DistributionNotFound:
|
||||
__version__ = "0.0.0.dev0"
|
||||
|
||||
import ngraph.utils.rt_map
|
||||
from ngraph.impl import Node
|
||||
from ngraph.helpers import function_from_cnn
|
||||
|
||||
|
@ -45,7 +45,5 @@ from _pyngraph import CoordinateDiff
|
||||
from _pyngraph import AxisSet
|
||||
from _pyngraph import AxisVector
|
||||
from _pyngraph import Coordinate
|
||||
from _pyngraph import VariantString
|
||||
from _pyngraph import VariantInt
|
||||
|
||||
from _pyngraph import util
|
||||
|
@ -1,52 +0,0 @@
|
||||
# ******************************************************************************
|
||||
# Copyright 2017-2020 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ******************************************************************************
|
||||
"""Overrides pybind PyRTMap class methods."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import _pyngraph
|
||||
|
||||
from _pyngraph import Variant
|
||||
from ngraph.impl import VariantInt, VariantString
|
||||
from ngraph.exceptions import UserInputError
|
||||
|
||||
|
||||
def _convert_to_variant(item: Any) -> Variant:
|
||||
"""Convert value to Variant class, otherwise throw error."""
|
||||
if isinstance(item, Variant):
|
||||
return item
|
||||
variant_mapping = {
|
||||
int: VariantInt,
|
||||
str: VariantString,
|
||||
}
|
||||
|
||||
new_type = variant_mapping.get(type(item), None)
|
||||
|
||||
if new_type is None:
|
||||
raise UserInputError("Cannot map value to any of registered Variant classes", str(item))
|
||||
|
||||
return new_type(item)
|
||||
|
||||
|
||||
binding_copy = _pyngraph.PyRTMap.__setitem__
|
||||
|
||||
|
||||
def _setitem(self: _pyngraph.PyRTMap, arg0: str, arg1: Any) -> None:
|
||||
"""Override setting values in dictionary."""
|
||||
binding_copy(self, arg0, _convert_to_variant(arg1))
|
||||
|
||||
|
||||
_pyngraph.PyRTMap.__setitem__ = _setitem
|
@ -26,6 +26,7 @@
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/variant.hpp"
|
||||
#include "pyngraph/node.hpp"
|
||||
#include "pyngraph/rt_map.hpp"
|
||||
#include "pyngraph/variant.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
@ -36,11 +37,6 @@ PYBIND11_MAKE_OPAQUE(PyRTMap);
|
||||
|
||||
void regclass_pyngraph_Node(py::module m)
|
||||
{
|
||||
auto py_map = py::bind_map<PyRTMap>(m, "PyRTMap");
|
||||
py_map.doc() =
|
||||
"ngraph.impl.PyRTMap makes bindings for std::map<std::string, "
|
||||
"std::shared_ptr<ngraph::Variant>>, which can later be used as ngraph::Node::RTMap";
|
||||
|
||||
py::class_<ngraph::Node, std::shared_ptr<ngraph::Node>> node(m, "Node", py::dynamic_attr());
|
||||
node.doc() = "ngraph.impl.Node wraps ngraph::Node";
|
||||
node.def("__add__",
|
||||
|
@ -34,6 +34,7 @@
|
||||
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
|
||||
#include "pyngraph/partial_shape.hpp"
|
||||
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
|
||||
#include "pyngraph/rt_map.hpp"
|
||||
#include "pyngraph/shape.hpp"
|
||||
#include "pyngraph/strides.hpp"
|
||||
#include "pyngraph/types/regmodule_pyngraph_types.hpp"
|
||||
@ -45,6 +46,7 @@ namespace py = pybind11;
|
||||
PYBIND11_MODULE(_pyngraph, m)
|
||||
{
|
||||
m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph";
|
||||
regclass_pyngraph_PyRTMap(m);
|
||||
regclass_pyngraph_Node(m);
|
||||
regclass_pyngraph_Input(m);
|
||||
regclass_pyngraph_Output(m);
|
||||
|
62
ngraph/python/src/pyngraph/rt_map.cpp
Normal file
62
ngraph/python/src/pyngraph/rt_map.cpp
Normal file
@ -0,0 +1,62 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include "dict_attribute_visitor.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/variant.hpp"
|
||||
#include "pyngraph/node.hpp"
|
||||
#include "pyngraph/rt_map.hpp"
|
||||
#include "pyngraph/variant.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using PyRTMap = std::map<std::string, std::shared_ptr<ngraph::Variant>>;
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(PyRTMap);
|
||||
|
||||
template <typename T>
|
||||
void _set_with_variant(PyRTMap& m, const std::string& k, const T v)
|
||||
{
|
||||
auto new_v = std::make_shared<ngraph::VariantWrapper<T>>(ngraph::VariantWrapper<T>(v));
|
||||
auto it = m.find(k);
|
||||
if (it != m.end())
|
||||
it->second = new_v;
|
||||
else
|
||||
m.emplace(k, new_v);
|
||||
}
|
||||
|
||||
void regclass_pyngraph_PyRTMap(py::module m)
|
||||
{
|
||||
auto py_map = py::bind_map<PyRTMap>(m, "PyRTMap");
|
||||
py_map.doc() =
|
||||
"ngraph.impl.PyRTMap makes bindings for std::map<std::string, "
|
||||
"std::shared_ptr<ngraph::Variant>>, which can later be used as ngraph::Node::RTMap";
|
||||
|
||||
py_map.def("__setitem__", [](PyRTMap& m, const std::string& k, const std::string v) {
|
||||
_set_with_variant(m, k, v);
|
||||
});
|
||||
py_map.def("__setitem__", [](PyRTMap& m, const std::string& k, const int64_t v) {
|
||||
_set_with_variant(m, k, v);
|
||||
});
|
||||
}
|
23
ngraph/python/src/pyngraph/rt_map.hpp
Normal file
23
ngraph/python/src/pyngraph/rt_map.hpp
Normal file
@ -0,0 +1,23 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_pyngraph_PyRTMap(py::module m);
|
@ -39,6 +39,27 @@ extern void regclass_pyngraph_VariantWrapper(py::module m, std::string typestrin
|
||||
"ngraph.impl.Variant[typestring] wraps ngraph::VariantWrapper<typestring>";
|
||||
|
||||
variant_wrapper.def(py::init<const VT&>());
|
||||
|
||||
variant_wrapper.def("__eq__",
|
||||
[](const ngraph::VariantWrapper<VT>& a,
|
||||
const ngraph::VariantWrapper<VT>& b) { return a.get() == b.get(); },
|
||||
py::is_operator());
|
||||
variant_wrapper.def("__eq__",
|
||||
[](const ngraph::VariantWrapper<std::string>& a, const std::string& b) {
|
||||
return a.get() == b;
|
||||
},
|
||||
py::is_operator());
|
||||
variant_wrapper.def(
|
||||
"__eq__",
|
||||
[](const ngraph::VariantWrapper<int64_t>& a, const int64_t& b) { return a.get() == b; },
|
||||
py::is_operator());
|
||||
|
||||
variant_wrapper.def("__repr__", [](const ngraph::VariantWrapper<VT> self) {
|
||||
std::stringstream ret;
|
||||
ret << self.get();
|
||||
return ret.str();
|
||||
});
|
||||
|
||||
variant_wrapper.def("get",
|
||||
(VT & (ngraph::VariantWrapper<VT>::*)()) & ngraph::VariantWrapper<VT>::get);
|
||||
variant_wrapper.def("set", &ngraph::VariantWrapper<VT>::set);
|
||||
|
@ -18,9 +18,11 @@ import json
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from _pyngraph import VariantInt, VariantString
|
||||
|
||||
import ngraph as ng
|
||||
from ngraph.exceptions import UserInputError
|
||||
from ngraph.impl import Function, PartialShape, Shape, Type, VariantInt, VariantString
|
||||
from ngraph.impl import Function, PartialShape, Shape, Type
|
||||
from ngraph.impl.op import Parameter
|
||||
from tests.runtime import get_runtime
|
||||
from tests.test_ngraph.util import run_op_node
|
||||
@ -408,7 +410,7 @@ def test_variants():
|
||||
|
||||
|
||||
def test_runtime_info():
|
||||
test_shape = PartialShape([1, 3, 22, 22])
|
||||
test_shape = PartialShape([1, 1, 1, 1])
|
||||
test_type = Type.f32
|
||||
test_param = Parameter(test_type, test_shape)
|
||||
relu_node = ng.relu(test_param)
|
||||
@ -417,7 +419,7 @@ def test_runtime_info():
|
||||
relu_node.set_friendly_name("testReLU")
|
||||
runtime_info_after = relu_node.get_rt_info()
|
||||
|
||||
assert runtime_info == runtime_info_after
|
||||
assert runtime_info_after["affinity"] == "test_affinity"
|
||||
|
||||
params = [test_param]
|
||||
results = [relu_node]
|
||||
|
Loading…
Reference in New Issue
Block a user