[PYTHON] Expose get_rt_info method for input/output nodes (#9211)
* expose get_rt_info method for input/output nodes * expose rt_info * add tests with rf_info update * remvoe redundant import * rename PyRTMap to RTMap
This commit is contained in:
parent
5fada94504
commit
8a1fd76124
@ -60,6 +60,7 @@ from openvino.pyopenvino import Layout
|
||||
from openvino.pyopenvino import ConstOutput
|
||||
from openvino.pyopenvino import util
|
||||
from openvino.pyopenvino import layout_helpers
|
||||
from openvino.pyopenvino import RTMap
|
||||
|
||||
from openvino.runtime.ie_api import Core
|
||||
from openvino.runtime.ie_api import CompiledModel
|
||||
|
@ -11,6 +11,10 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using PyRTMap = ov::Node::RTMap;
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(PyRTMap);
|
||||
|
||||
void regclass_graph_Input(py::module m) {
|
||||
py::class_<ov::Input<ov::Node>, std::shared_ptr<ov::Input<ov::Node>>> input(m, "Input", py::dynamic_attr());
|
||||
input.doc() = "openvino.runtime.Input wraps ov::Input<Node>";
|
||||
@ -75,4 +79,20 @@ void regclass_graph_Input(py::module m) {
|
||||
get_source_output : Output
|
||||
Output that is connected to the input.
|
||||
)");
|
||||
|
||||
input.def("get_rt_info",
|
||||
(ov::RTMap & (ov::Input<ov::Node>::*)()) & ov::Input<ov::Node>::get_rt_info,
|
||||
py::return_value_policy::reference_internal,
|
||||
R"(
|
||||
Returns RTMap which is a dictionary of user defined runtime info.
|
||||
|
||||
Returns
|
||||
----------
|
||||
get_rt_info : RTMap
|
||||
A dictionary of user defined data.
|
||||
)");
|
||||
input.def_property_readonly("rt_info", (ov::RTMap & (ov::Input<ov::Node>::*)()) & ov::Input<ov::Node>::get_rt_info);
|
||||
input.def_property_readonly("rt_info",
|
||||
(const ov::RTMap& (ov::Input<ov::Node>::*)() const) & ov::Input<ov::Node>::get_rt_info,
|
||||
py::return_value_policy::reference_internal);
|
||||
}
|
||||
|
@ -11,6 +11,10 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using PyRTMap = ov::Node::RTMap;
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(PyRTMap);
|
||||
|
||||
template <typename VT>
|
||||
void regclass_graph_Output(py::module m, std::string typestring)
|
||||
{
|
||||
@ -115,6 +119,18 @@ void regclass_graph_Output(py::module m, std::string typestring)
|
||||
get_tensor : descriptor.Tensor
|
||||
Tensor of the output.
|
||||
)");
|
||||
output.def("get_rt_info",
|
||||
(ov::RTMap & (ov::Output<VT>::*)()) & ov::Output<VT>::get_rt_info,
|
||||
py::return_value_policy::reference_internal,
|
||||
R"(
|
||||
Returns RTMap which is a dictionary of user defined runtime info.
|
||||
|
||||
Returns
|
||||
----------
|
||||
get_rt_info : RTMap
|
||||
A dictionary of user defined data.
|
||||
)");
|
||||
|
||||
|
||||
output.def_property_readonly("node", &ov::Output<VT>::get_node);
|
||||
output.def_property_readonly("index", &ov::Output<VT>::get_index);
|
||||
@ -125,4 +141,12 @@ void regclass_graph_Output(py::module m, std::string typestring)
|
||||
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape);
|
||||
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
|
||||
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
|
||||
output.def_property_readonly("rt_info",
|
||||
(ov::RTMap&(ov::Output<VT>::*)()) &
|
||||
ov::Output<VT>::get_rt_info,
|
||||
py::return_value_policy::reference_internal);
|
||||
output.def_property_readonly("rt_info",
|
||||
(const ov::RTMap&(ov::Output<VT>::*)() const) &
|
||||
ov::Output<VT>::get_rt_info,
|
||||
py::return_value_policy::reference_internal);
|
||||
}
|
||||
|
@ -26,8 +26,8 @@ using PyRTMap = ov::RTMap;
|
||||
PYBIND11_MAKE_OPAQUE(PyRTMap);
|
||||
|
||||
void regclass_graph_PyRTMap(py::module m) {
|
||||
auto py_map = py::class_<PyRTMap>(m, "PyRTMap");
|
||||
py_map.doc() = "openvino.runtime.PyRTMap makes bindings for std::map<std::string, "
|
||||
auto py_map = py::class_<PyRTMap>(m, "RTMap");
|
||||
py_map.doc() = "openvino.runtime.RTMap makes bindings for std::map<std::string, "
|
||||
"ov::Any, which can later be used as ov::Node::RTMap";
|
||||
|
||||
py_map.def("__setitem__", [](PyRTMap& m, const std::string& k, const std::string v) {
|
||||
|
@ -0,0 +1,123 @@
|
||||
# Copyright (C) 2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
|
||||
from ..conftest import model_path
|
||||
from openvino.runtime import Input, Shape, PartialShape, Type, Parameter, \
|
||||
RTMap
|
||||
|
||||
from openvino.runtime import Core
|
||||
|
||||
is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
|
||||
test_net_xml, test_net_bin = model_path(is_myriad)
|
||||
|
||||
|
||||
def model_path(is_myriad=False):
|
||||
path_to_repo = os.environ["MODELS_PATH"]
|
||||
if not is_myriad:
|
||||
test_xml = os.path.join(path_to_repo, "models", "test_model", "test_model_fp32.xml")
|
||||
test_bin = os.path.join(path_to_repo, "models", "test_model", "test_model_fp32.bin")
|
||||
else:
|
||||
test_xml = os.path.join(path_to_repo, "models", "test_model", "test_model_fp16.xml")
|
||||
test_bin = os.path.join(path_to_repo, "models", "test_model", "test_model_fp16.bin")
|
||||
return (test_xml, test_bin)
|
||||
|
||||
|
||||
def test_input_type(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
assert isinstance(input_node, Input)
|
||||
|
||||
|
||||
def test_const_output_docs(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
exptected_string = "openvino.runtime.Input wraps ov::Input<Node>"
|
||||
assert input_node.__doc__ == exptected_string
|
||||
|
||||
|
||||
def test_input_get_index(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
assert input_node.get_index() == 0
|
||||
|
||||
|
||||
def test_input_element_type(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
assert input_node.get_element_type() == Type.f32
|
||||
|
||||
|
||||
def test_input_get_shape(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
assert str(input_node.get_shape()) == str(Shape([1, 10]))
|
||||
|
||||
|
||||
def test_input_get_partial_shape(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
expected_partial_shape = PartialShape([1, 10])
|
||||
assert input_node.get_partial_shape() == expected_partial_shape
|
||||
|
||||
|
||||
def test_input_get_source_output(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
name = input_node.get_source_output().get_node().get_friendly_name()
|
||||
assert name == "fc_out"
|
||||
|
||||
|
||||
def test_input_get_rt_info(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
rt_info = input_node.get_rt_info()
|
||||
assert isinstance(rt_info, RTMap)
|
||||
|
||||
|
||||
def test_input_rt_info(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
rt_info = input_node.rt_info
|
||||
assert isinstance(rt_info, RTMap)
|
||||
|
||||
|
||||
def test_input_update_rt_info(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input = exec_net.output(0)
|
||||
input_node = input.get_node().inputs()[0]
|
||||
rt = input_node.get_rt_info()
|
||||
rt["test12345"] = "test"
|
||||
for k, v in input_node.get_rt_info().items():
|
||||
assert k == "test12345"
|
||||
assert isinstance(v, Parameter)
|
@ -4,7 +4,9 @@
|
||||
import os
|
||||
|
||||
from ..conftest import model_path
|
||||
from openvino.runtime import ConstOutput, Shape, PartialShape, Type
|
||||
import openvino.runtime.opset8 as ops
|
||||
from openvino.runtime import ConstOutput, Shape, PartialShape, Type, \
|
||||
Output, Parameter, RTMap
|
||||
|
||||
from openvino.runtime import Core
|
||||
|
||||
@ -100,3 +102,32 @@ def test_const_output_get_names(device):
|
||||
assert node.names == expected_names
|
||||
assert node.get_any_name() == input_name
|
||||
assert node.any_name == input_name
|
||||
|
||||
|
||||
def test_const_get_rf_info(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
output_node = exec_net.output(0)
|
||||
rt_info = output_node.get_rt_info()
|
||||
assert isinstance(rt_info, RTMap)
|
||||
|
||||
|
||||
def test_const_output_runtime_info(device):
|
||||
core = Core()
|
||||
func = core.read_model(model=test_net_xml, weights=test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
input_name = "data"
|
||||
output_node = exec_net.input(input_name)
|
||||
rt_info = output_node.rt_info
|
||||
assert isinstance(rt_info, RTMap)
|
||||
|
||||
|
||||
def test_update_rt_info(device):
|
||||
relu = ops.relu(5)
|
||||
output_node = Output._from_node(relu)
|
||||
rt = output_node.get_rt_info()
|
||||
rt["test12345"] = "test"
|
||||
for k, v in output_node.get_rt_info().items():
|
||||
assert k == "test12345"
|
||||
assert isinstance(v, Parameter)
|
||||
|
Loading…
Reference in New Issue
Block a user