[PYTHON API] fix tensor naming (#8918)

* bind any_names

* Update src/bindings/python/src/pyopenvino/graph/node_output.hpp

Co-authored-by: Jan Iwaszkiewicz <jan.iwaszkiewicz@intel.com>

* add properties and tests

Co-authored-by: Jan Iwaszkiewicz <jan.iwaszkiewicz@intel.com>
Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>
This commit is contained in:
Alexey Lebedev 2021-12-01 14:57:24 +03:00 committed by GitHub
parent 5a52a8e4a3
commit 246e628c79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 53 additions and 2 deletions

View File

@ -38,7 +38,7 @@ def normalize_inputs(py_dict: dict, py_types: dict) -> dict:
def get_input_types(obj: Union[InferRequestBase, ExecutableNetworkBase]) -> dict:
"""Get all precisions from object inputs."""
return {i.get_node().get_friendly_name(): i.get_node().get_element_type() for i in obj.inputs}
return {i.get_any_name(): i.get_element_type() for i in obj.inputs}
class InferRequest(InferRequestBase):

View File

@ -41,6 +41,27 @@ void regclass_graph_Output(py::module m, std::string typestring)
get_index : int
Index value as integer.
)");
output.def("get_any_name",
&ov::Output<VT>::get_any_name,
R"(
One of the tensor names associated with this output.
Note: first name in lexicographical order.
Returns
----------
get_any_name : str
Tensor name as string.
)");
output.def("get_names",
&ov::Output<VT>::get_names,
R"(
The tensor names associated with this output.
Returns
----------
get_names : set
Set of tensor names.
)");
output.def("get_element_type",
&ov::Output<VT>::get_element_type,
R"(
@ -94,4 +115,14 @@ void regclass_graph_Output(py::module m, std::string typestring)
get_tensor : descriptor.Tensor
Tensor of the output.
)");
output.def_property_readonly("node", &ov::Output<VT>::get_node);
output.def_property_readonly("index", &ov::Output<VT>::get_index);
output.def_property_readonly("any_name", &ov::Output<VT>::get_any_name);
output.def_property_readonly("names", &ov::Output<VT>::get_names);
output.def_property_readonly("element_type", &ov::Output<VT>::get_element_type);
output.def_property_readonly("shape", &ov::Output<VT>::get_shape);
output.def_property_readonly("partial_shape", &ov::Output<VT>::get_partial_shape);
output.def_property_readonly("target_inputs", &ov::Output<VT>::get_target_inputs);
output.def_property_readonly("tensor", &ov::Output<VT>::get_tensor);
}

View File

@ -33,7 +33,8 @@ def test_get_profiling_info(device):
exec_net = core.compile_model(func, device)
img = read_image()
request = exec_net.create_infer_request()
request.infer({0: img})
tensor_name = exec_net.input("data").any_name
request.infer({tensor_name: img})
assert request.latency > 0
prof_info = request.get_profiling_info()
soft_max_node = next(node for node in prof_info if node.node_name == "fc_out")

View File

@ -46,6 +46,7 @@ def test_const_output_get_index(device):
exec_net = core.compile_model(func, device)
node = exec_net.input("data")
assert node.get_index() == 0
assert node.index == 0
def test_const_output_get_element_type(device):
@ -54,6 +55,7 @@ def test_const_output_get_element_type(device):
exec_net = core.compile_model(func, device)
node = exec_net.input("data")
assert node.get_element_type() == Type.f32
assert node.element_type == Type.f32
def test_const_output_get_shape(device):
@ -63,6 +65,7 @@ def test_const_output_get_shape(device):
node = exec_net.input("data")
expected_shape = Shape([1, 3, 32, 32])
assert str(node.get_shape()) == str(expected_shape)
assert str(node.shape) == str(expected_shape)
def test_const_output_get_partial_shape(device):
@ -72,6 +75,7 @@ def test_const_output_get_partial_shape(device):
node = exec_net.input("data")
expected_partial_shape = PartialShape([1, 3, 32, 32])
assert node.get_partial_shape() == expected_partial_shape
assert node.partial_shape == expected_partial_shape
def test_const_output_get_target_inputs(device):
@ -81,3 +85,18 @@ def test_const_output_get_target_inputs(device):
outputs = exec_net.outputs
for node in outputs:
assert isinstance(node.get_target_inputs(), set)
assert isinstance(node.target_inputs, set)
def test_const_output_get_names(device):
core = Core()
func = core.read_model(model=test_net_xml, weights=test_net_bin)
exec_net = core.compile_model(func, device)
input_name = "data"
node = exec_net.input(input_name)
expected_names = set()
expected_names.add(input_name)
assert node.get_names() == expected_names
assert node.names == expected_names
assert node.get_any_name() == input_name
assert node.any_name == input_name