[PyOV] Fix getting all names in OVDict (#16665)

* [PyOV] Fix getting all names in OVDict

* Add docs and adjust tests

* Fix linter issues

* Adjust typing and add test for incorrect key type

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Jan Iwaszkiewicz 2023-04-06 14:44:37 +02:00 committed by GitHub
parent d732024ccb
commit 92eb62fe63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 16 deletions

View File

@ -12,7 +12,8 @@ except ImportError:
from singledispatchmethod import singledispatchmethod # type: ignore[no-redef] from singledispatchmethod import singledispatchmethod # type: ignore[no-redef]
from collections.abc import Mapping from collections.abc import Mapping
from typing import Union, Dict, List, Iterator, KeysView, ItemsView, ValuesView from typing import Dict, Set, Tuple, Union, Iterator, Optional
from typing import KeysView, ItemsView, ValuesView
from openvino._pyopenvino import Tensor, ConstOutput from openvino._pyopenvino import Tensor, ConstOutput
from openvino._pyopenvino import InferRequest as InferRequestBase from openvino._pyopenvino import InferRequest as InferRequestBase
@ -71,6 +72,7 @@ class OVDict(Mapping):
""" """
def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None: def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None:
self._dict = _dict self._dict = _dict
self._names: Optional[Dict[ConstOutput, Set[str]]] = None
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
return self._dict.__iter__() return self._dict.__iter__()
@ -81,12 +83,19 @@ class OVDict(Mapping):
def __repr__(self) -> str: def __repr__(self) -> str:
return self._dict.__repr__() return self._dict.__repr__()
def __get_names(self) -> Dict[ConstOutput, Set[str]]:
"""Return names of every output key.
Insert empty set if key has no name.
"""
return {key: key.get_names() for key in self._dict.keys()}
def __get_key(self, index: int) -> ConstOutput: def __get_key(self, index: int) -> ConstOutput:
return list(self._dict.keys())[index] return list(self._dict.keys())[index]
@singledispatchmethod @singledispatchmethod
def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray: def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
raise TypeError("Unknown key type!") raise TypeError(f"Unknown key type: {type(key)}")
@__getitem_impl.register @__getitem_impl.register
def _(self, key: ConstOutput) -> np.ndarray: def _(self, key: ConstOutput) -> np.ndarray:
@ -101,9 +110,11 @@ class OVDict(Mapping):
@__getitem_impl.register @__getitem_impl.register
def _(self, key: str) -> np.ndarray: def _(self, key: str) -> np.ndarray:
try: if self._names is None:
return self._dict[self.__get_key(self.names().index(key))] self._names = self.__get_names()
except ValueError: for port, port_names in self._names.items():
if key in port_names:
return self._dict[port]
raise KeyError(key) raise KeyError(key)
def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray: def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
@ -118,12 +129,14 @@ class OVDict(Mapping):
def items(self) -> ItemsView[ConstOutput, np.ndarray]: def items(self) -> ItemsView[ConstOutput, np.ndarray]:
return self._dict.items() return self._dict.items()
def names(self) -> List[str]: def names(self) -> Tuple[Set[str], ...]:
"""Return a name of every output key. """Return names of every output key.
Throws RuntimeError if any of ConstOutput keys has no name. Insert empty set if key has no name.
""" """
return [key.get_any_name() for key in self._dict.keys()] if self._names is None:
self._names = self.__get_names()
return tuple(self._names.values())
def to_dict(self) -> Dict[ConstOutput, np.ndarray]: def to_dict(self) -> Dict[ConstOutput, np.ndarray]:
"""Return underlaying native dictionary. """Return underlaying native dictionary.

View File

@ -99,7 +99,7 @@ def _check_dict(result, obj, output_names=None):
assert _check_keys(result.keys(), outs) assert _check_keys(result.keys(), outs)
assert _check_values(result) assert _check_values(result)
assert _check_items(result, outs, output_names) assert _check_items(result, outs, output_names)
assert result.names() == output_names assert all([output_names[i] in result.names()[i] for i in range(0, len(output_names))])
return True return True
@ -124,6 +124,15 @@ def test_ovdict_single_output_basic(device, is_direct):
raise TypeError("Unknown `obj` type!") raise TypeError("Unknown `obj` type!")
@pytest.mark.parametrize("is_direct", [True, False])
def test_ovdict_wrong_key_type(device, is_direct):
result, _ = _get_ovdict(device, multi_output=False, direct_infer=is_direct)
with pytest.raises(TypeError) as e:
_ = result[2.0]
assert "Unknown key type: <class 'float'>" in str(e.value)
@pytest.mark.parametrize("is_direct", [True, False]) @pytest.mark.parametrize("is_direct", [True, False])
def test_ovdict_single_output_noname(device, is_direct): def test_ovdict_single_output_noname(device, is_direct):
result, obj = _get_ovdict( result, obj = _get_ovdict(
@ -140,13 +149,13 @@ def test_ovdict_single_output_noname(device, is_direct):
assert isinstance(result[outs[0]], np.ndarray) assert isinstance(result[outs[0]], np.ndarray)
assert isinstance(result[0], np.ndarray) assert isinstance(result[0], np.ndarray)
with pytest.raises(RuntimeError) as e0: with pytest.raises(KeyError) as e0:
_ = result["some_name"] _ = result["some_name"]
assert "Attempt to get a name for a Tensor without names" in str(e0.value) assert "some_name" in str(e0.value)
with pytest.raises(RuntimeError) as e1: # Check if returned names are tuple with one empty set
_ = result.names() assert len(result.names()) == 1
assert "Attempt to get a name for a Tensor without names" in str(e1.value) assert result.names()[0] == set()
@pytest.mark.parametrize("is_direct", [True, False]) @pytest.mark.parametrize("is_direct", [True, False])