[PyOV] Fix getting all names in OVDict (#16665)
* [PyOV] Fix getting all names in OVDict * Add docs and adjust tests * Fix linter issues * Adjust typing and add test for incorrect key type --------- Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
parent
d732024ccb
commit
92eb62fe63
@ -12,7 +12,8 @@ except ImportError:
|
||||
from singledispatchmethod import singledispatchmethod # type: ignore[no-redef]
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Union, Dict, List, Iterator, KeysView, ItemsView, ValuesView
|
||||
from typing import Dict, Set, Tuple, Union, Iterator, Optional
|
||||
from typing import KeysView, ItemsView, ValuesView
|
||||
|
||||
from openvino._pyopenvino import Tensor, ConstOutput
|
||||
from openvino._pyopenvino import InferRequest as InferRequestBase
|
||||
@ -71,6 +72,7 @@ class OVDict(Mapping):
|
||||
"""
|
||||
def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None:
|
||||
self._dict = _dict
|
||||
self._names: Optional[Dict[ConstOutput, Set[str]]] = None
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return self._dict.__iter__()
|
||||
@ -81,12 +83,19 @@ class OVDict(Mapping):
|
||||
def __repr__(self) -> str:
|
||||
return self._dict.__repr__()
|
||||
|
||||
def __get_names(self) -> Dict[ConstOutput, Set[str]]:
|
||||
"""Return names of every output key.
|
||||
|
||||
Insert empty set if key has no name.
|
||||
"""
|
||||
return {key: key.get_names() for key in self._dict.keys()}
|
||||
|
||||
def __get_key(self, index: int) -> ConstOutput:
|
||||
return list(self._dict.keys())[index]
|
||||
|
||||
@singledispatchmethod
|
||||
def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
|
||||
raise TypeError("Unknown key type!")
|
||||
raise TypeError(f"Unknown key type: {type(key)}")
|
||||
|
||||
@__getitem_impl.register
|
||||
def _(self, key: ConstOutput) -> np.ndarray:
|
||||
@ -101,10 +110,12 @@ class OVDict(Mapping):
|
||||
|
||||
@__getitem_impl.register
|
||||
def _(self, key: str) -> np.ndarray:
|
||||
try:
|
||||
return self._dict[self.__get_key(self.names().index(key))]
|
||||
except ValueError:
|
||||
raise KeyError(key)
|
||||
if self._names is None:
|
||||
self._names = self.__get_names()
|
||||
for port, port_names in self._names.items():
|
||||
if key in port_names:
|
||||
return self._dict[port]
|
||||
raise KeyError(key)
|
||||
|
||||
def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
|
||||
return self.__getitem_impl(key)
|
||||
@ -118,12 +129,14 @@ class OVDict(Mapping):
|
||||
def items(self) -> ItemsView[ConstOutput, np.ndarray]:
|
||||
return self._dict.items()
|
||||
|
||||
def names(self) -> List[str]:
|
||||
"""Return a name of every output key.
|
||||
def names(self) -> Tuple[Set[str], ...]:
|
||||
"""Return names of every output key.
|
||||
|
||||
Throws RuntimeError if any of ConstOutput keys has no name.
|
||||
Insert empty set if key has no name.
|
||||
"""
|
||||
return [key.get_any_name() for key in self._dict.keys()]
|
||||
if self._names is None:
|
||||
self._names = self.__get_names()
|
||||
return tuple(self._names.values())
|
||||
|
||||
def to_dict(self) -> Dict[ConstOutput, np.ndarray]:
|
||||
"""Return underlaying native dictionary.
|
||||
|
@ -99,7 +99,7 @@ def _check_dict(result, obj, output_names=None):
|
||||
assert _check_keys(result.keys(), outs)
|
||||
assert _check_values(result)
|
||||
assert _check_items(result, outs, output_names)
|
||||
assert result.names() == output_names
|
||||
assert all([output_names[i] in result.names()[i] for i in range(0, len(output_names))])
|
||||
|
||||
return True
|
||||
|
||||
@ -124,6 +124,15 @@ def test_ovdict_single_output_basic(device, is_direct):
|
||||
raise TypeError("Unknown `obj` type!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_direct", [True, False])
|
||||
def test_ovdict_wrong_key_type(device, is_direct):
|
||||
result, _ = _get_ovdict(device, multi_output=False, direct_infer=is_direct)
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
_ = result[2.0]
|
||||
assert "Unknown key type: <class 'float'>" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_direct", [True, False])
|
||||
def test_ovdict_single_output_noname(device, is_direct):
|
||||
result, obj = _get_ovdict(
|
||||
@ -140,13 +149,13 @@ def test_ovdict_single_output_noname(device, is_direct):
|
||||
assert isinstance(result[outs[0]], np.ndarray)
|
||||
assert isinstance(result[0], np.ndarray)
|
||||
|
||||
with pytest.raises(RuntimeError) as e0:
|
||||
with pytest.raises(KeyError) as e0:
|
||||
_ = result["some_name"]
|
||||
assert "Attempt to get a name for a Tensor without names" in str(e0.value)
|
||||
assert "some_name" in str(e0.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as e1:
|
||||
_ = result.names()
|
||||
assert "Attempt to get a name for a Tensor without names" in str(e1.value)
|
||||
# Check if returned names are tuple with one empty set
|
||||
assert len(result.names()) == 1
|
||||
assert result.names()[0] == set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_direct", [True, False])
|
||||
|
Loading…
Reference in New Issue
Block a user