[PYTHON] Expose get_rt_info method for input/output nodes (#9211)

* expose get_rt_info method for input/output nodes

* expose rt_info

* add tests with rf_info update

* remvoe redundant import

* rename PyRTMap to RTMap
This commit is contained in:
Bartek Szmelczynski 2021-12-22 08:55:57 +01:00 committed by GitHub
parent 5fada94504
commit 8a1fd76124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 202 additions and 3 deletions

View File

@ -60,6 +60,7 @@ from openvino.pyopenvino import Layout
from openvino.pyopenvino import ConstOutput
from openvino.pyopenvino import util
from openvino.pyopenvino import layout_helpers
from openvino.pyopenvino import RTMap
from openvino.runtime.ie_api import Core
from openvino.runtime.ie_api import CompiledModel

View File

@ -11,6 +11,10 @@
namespace py = pybind11;
using PyRTMap = ov::Node::RTMap;
PYBIND11_MAKE_OPAQUE(PyRTMap);
void regclass_graph_Input(py::module m) {
py::class_<ov::Input<ov::Node>, std::shared_ptr<ov::Input<ov::Node>>> input(m, "Input", py::dynamic_attr());
input.doc() = "openvino.runtime.Input wraps ov::Input<Node>";
@ -75,4 +79,20 @@ void regclass_graph_Input(py::module m) {
get_source_output : Output
Output that is connected to the input.
)");
input.def("get_rt_info",
(ov::RTMap & (ov::Input<ov::Node>::*)()) & ov::Input<ov::Node>::get_rt_info,
py::return_value_policy::reference_internal,
R"(
Returns RTMap which is a dictionary of user defined runtime info.
Returns
----------
get_rt_info : RTMap
A dictionary of user defined data.
)");
input.def_property_readonly("rt_info", (ov::RTMap & (ov::Input<ov::Node>::*)()) & ov::Input<ov::Node>::get_rt_info);
input.def_property_readonly("rt_info",
(const ov::RTMap& (ov::Input<ov::Node>::*)() const) & ov::Input<ov::Node>::get_rt_info,
py::return_value_policy::reference_internal);
}

View File

@ -11,6 +11,10 @@
namespace py = pybind11;
using PyRTMap = ov::Node::RTMap;
PYBIND11_MAKE_OPAQUE(PyRTMap);
template <typename VT>
void regclass_graph_Output(py::module m, std::string typestring)
{
@ -115,6 +119,18 @@ void regclass_graph_Output(py::module m, std::string typestring)
get_tensor : descriptor.Tensor
Tensor of the output.
)");
output.def("get_rt_info",
(ov::RTMap & (ov::Output<VT>::*)()) & ov::Output<VT>::get_rt_info,
py::return_value_policy::reference_internal,
R"(
Returns RTMap which is a dictionary of user defined runtime info.
Returns
----------
get_rt_info : RTMap
A dictionary of user defined data.
)");
output.def_property_readonly("node", &ov::Output<VT>::get_node);
output.def_property_readonly("index", &ov::Output<VT>::get_index);
@ -125,4 +141,12 @@ void regclass_graph_Output(py::module m, std::string typestring)
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape);
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
output.def_property_readonly("rt_info",
(ov::RTMap&(ov::Output<VT>::*)()) &
ov::Output<VT>::get_rt_info,
py::return_value_policy::reference_internal);
output.def_property_readonly("rt_info",
(const ov::RTMap&(ov::Output<VT>::*)() const) &
ov::Output<VT>::get_rt_info,
py::return_value_policy::reference_internal);
}

View File

@ -26,8 +26,8 @@ using PyRTMap = ov::RTMap;
PYBIND11_MAKE_OPAQUE(PyRTMap);
void regclass_graph_PyRTMap(py::module m) {
auto py_map = py::class_<PyRTMap>(m, "PyRTMap");
py_map.doc() = "openvino.runtime.PyRTMap makes bindings for std::map<std::string, "
auto py_map = py::class_<PyRTMap>(m, "RTMap");
py_map.doc() = "openvino.runtime.RTMap makes bindings for std::map<std::string, "
"ov::Any, which can later be used as ov::Node::RTMap";
py_map.def("__setitem__", [](PyRTMap& m, const std::string& k, const std::string v) {

View File

@ -0,0 +1,123 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
from ..conftest import model_path
from openvino.runtime import Input, Shape, PartialShape, Type, Parameter, \
RTMap
from openvino.runtime import Core
is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
test_net_xml, test_net_bin = model_path(is_myriad)
def model_path(is_myriad=False):
path_to_repo = os.environ["MODELS_PATH"]
if not is_myriad:
test_xml = os.path.join(path_to_repo, "models", "test_model", "test_model_fp32.xml")
test_bin = os.path.join(path_to_repo, "models", "test_model", "test_model_fp32.bin")
else:
test_xml = os.path.join(path_to_repo, "models", "test_model", "test_model_fp16.xml")
test_bin = os.path.join(path_to_repo, "models", "test_model", "test_model_fp16.bin")
return (test_xml, test_bin)
def test_input_type(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
assert isinstance(input_node, Input)
def test_const_output_docs(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
exptected_string = "openvino.runtime.Input wraps ov::Input<Node>"
assert input_node.__doc__ == exptected_string
def test_input_get_index(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
assert input_node.get_index() == 0
def test_input_element_type(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
assert input_node.get_element_type() == Type.f32
def test_input_get_shape(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
assert str(input_node.get_shape()) == str(Shape([1, 10]))
def test_input_get_partial_shape(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
expected_partial_shape = PartialShape([1, 10])
assert input_node.get_partial_shape() == expected_partial_shape
def test_input_get_source_output(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
name = input_node.get_source_output().get_node().get_friendly_name()
assert name == "fc_out"
def test_input_get_rt_info(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
rt_info = input_node.get_rt_info()
assert isinstance(rt_info, RTMap)
def test_input_rt_info(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
rt_info = input_node.rt_info
assert isinstance(rt_info, RTMap)
def test_input_update_rt_info(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input = exec_net.output(0)
input_node = input.get_node().inputs()[0]
rt = input_node.get_rt_info()
rt["test12345"] = "test"
for k, v in input_node.get_rt_info().items():
assert k == "test12345"
assert isinstance(v, Parameter)

View File

@ -4,7 +4,9 @@
import os
from ..conftest import model_path
from openvino.runtime import ConstOutput, Shape, PartialShape, Type
import openvino.runtime.opset8 as ops
from openvino.runtime import ConstOutput, Shape, PartialShape, Type, \
Output, Parameter, RTMap
from openvino.runtime import Core
@ -100,3 +102,32 @@ def test_const_output_get_names(device):
assert node.names == expected_names
assert node.get_any_name() == input_name
assert node.any_name == input_name
def test_const_get_rf_info(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
output_node = exec_net.output(0)
rt_info = output_node.get_rt_info()
assert isinstance(rt_info, RTMap)
def test_const_output_runtime_info(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input_name = "data"
output_node = exec_net.input(input_name)
rt_info = output_node.rt_info
assert isinstance(rt_info, RTMap)
def test_update_rt_info(device):
relu = ops.relu(5)
output_node = Output._from_node(relu)
rt = output_node.get_rt_info()
rt["test12345"] = "test"
for k, v in output_node.get_rt_info().items():
assert k == "test12345"
assert isinstance(v, Parameter)