[PyOV] Fix getting all names in OVDict (#16665)

* [PyOV] Fix getting all names in OVDict

* Add docs and adjust tests

* Fix linter issues

* Adjust typing and add test for incorrect key type

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Jan Iwaszkiewicz 2023-04-06 14:44:37 +02:00 committed by GitHub
parent d732024ccb
commit 92eb62fe63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 16 deletions

View File

@ -12,7 +12,8 @@ except ImportError:
from singledispatchmethod import singledispatchmethod # type: ignore[no-redef]
from collections.abc import Mapping
from typing import Union, Dict, List, Iterator, KeysView, ItemsView, ValuesView
from typing import Dict, Set, Tuple, Union, Iterator, Optional
from typing import KeysView, ItemsView, ValuesView
from openvino._pyopenvino import Tensor, ConstOutput
from openvino._pyopenvino import InferRequest as InferRequestBase
@ -71,6 +72,7 @@ class OVDict(Mapping):
"""
def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None:
self._dict = _dict
self._names: Optional[Dict[ConstOutput, Set[str]]] = None
def __iter__(self) -> Iterator:
return self._dict.__iter__()
@ -81,12 +83,19 @@ class OVDict(Mapping):
def __repr__(self) -> str:
return self._dict.__repr__()
def __get_names(self) -> Dict[ConstOutput, Set[str]]:
"""Return names of every output key.
Insert empty set if key has no name.
"""
return {key: key.get_names() for key in self._dict.keys()}
def __get_key(self, index: int) -> ConstOutput:
return list(self._dict.keys())[index]
@singledispatchmethod
def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
raise TypeError("Unknown key type!")
raise TypeError(f"Unknown key type: {type(key)}")
@__getitem_impl.register
def _(self, key: ConstOutput) -> np.ndarray:
@ -101,10 +110,12 @@ class OVDict(Mapping):
@__getitem_impl.register
def _(self, key: str) -> np.ndarray:
try:
return self._dict[self.__get_key(self.names().index(key))]
except ValueError:
raise KeyError(key)
if self._names is None:
self._names = self.__get_names()
for port, port_names in self._names.items():
if key in port_names:
return self._dict[port]
raise KeyError(key)
def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
return self.__getitem_impl(key)
@ -118,12 +129,14 @@ class OVDict(Mapping):
def items(self) -> ItemsView[ConstOutput, np.ndarray]:
return self._dict.items()
def names(self) -> List[str]:
"""Return a name of every output key.
def names(self) -> Tuple[Set[str], ...]:
"""Return names of every output key.
Throws RuntimeError if any of ConstOutput keys has no name.
Insert empty set if key has no name.
"""
return [key.get_any_name() for key in self._dict.keys()]
if self._names is None:
self._names = self.__get_names()
return tuple(self._names.values())
def to_dict(self) -> Dict[ConstOutput, np.ndarray]:
"""Return underlaying native dictionary.

View File

@ -99,7 +99,7 @@ def _check_dict(result, obj, output_names=None):
assert _check_keys(result.keys(), outs)
assert _check_values(result)
assert _check_items(result, outs, output_names)
assert result.names() == output_names
assert all([output_names[i] in result.names()[i] for i in range(0, len(output_names))])
return True
@ -124,6 +124,15 @@ def test_ovdict_single_output_basic(device, is_direct):
raise TypeError("Unknown `obj` type!")
@pytest.mark.parametrize("is_direct", [True, False])
def test_ovdict_wrong_key_type(device, is_direct):
result, _ = _get_ovdict(device, multi_output=False, direct_infer=is_direct)
with pytest.raises(TypeError) as e:
_ = result[2.0]
assert "Unknown key type: <class 'float'>" in str(e.value)
@pytest.mark.parametrize("is_direct", [True, False])
def test_ovdict_single_output_noname(device, is_direct):
result, obj = _get_ovdict(
@ -140,13 +149,13 @@ def test_ovdict_single_output_noname(device, is_direct):
assert isinstance(result[outs[0]], np.ndarray)
assert isinstance(result[0], np.ndarray)
with pytest.raises(RuntimeError) as e0:
with pytest.raises(KeyError) as e0:
_ = result["some_name"]
assert "Attempt to get a name for a Tensor without names" in str(e0.value)
assert "some_name" in str(e0.value)
with pytest.raises(RuntimeError) as e1:
_ = result.names()
assert "Attempt to get a name for a Tensor without names" in str(e1.value)
# Check if returned names are tuple with one empty set
assert len(result.names()) == 1
assert result.names()[0] == set()
@pytest.mark.parametrize("is_direct", [True, False])