[PyOV] Fix getting all names in OVDict (#16665)
* [PyOV] Fix getting all names in OVDict * Add docs and adjust tests * Fix linter issues * Adjust typing and add test for incorrect key type --------- Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
parent
d732024ccb
commit
92eb62fe63
@ -12,7 +12,8 @@ except ImportError:
|
|||||||
from singledispatchmethod import singledispatchmethod # type: ignore[no-redef]
|
from singledispatchmethod import singledispatchmethod # type: ignore[no-redef]
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Union, Dict, List, Iterator, KeysView, ItemsView, ValuesView
|
from typing import Dict, Set, Tuple, Union, Iterator, Optional
|
||||||
|
from typing import KeysView, ItemsView, ValuesView
|
||||||
|
|
||||||
from openvino._pyopenvino import Tensor, ConstOutput
|
from openvino._pyopenvino import Tensor, ConstOutput
|
||||||
from openvino._pyopenvino import InferRequest as InferRequestBase
|
from openvino._pyopenvino import InferRequest as InferRequestBase
|
||||||
@ -71,6 +72,7 @@ class OVDict(Mapping):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None:
|
def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None:
|
||||||
self._dict = _dict
|
self._dict = _dict
|
||||||
|
self._names: Optional[Dict[ConstOutput, Set[str]]] = None
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
def __iter__(self) -> Iterator:
|
||||||
return self._dict.__iter__()
|
return self._dict.__iter__()
|
||||||
@ -81,12 +83,19 @@ class OVDict(Mapping):
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return self._dict.__repr__()
|
return self._dict.__repr__()
|
||||||
|
|
||||||
|
def __get_names(self) -> Dict[ConstOutput, Set[str]]:
|
||||||
|
"""Return names of every output key.
|
||||||
|
|
||||||
|
Insert empty set if key has no name.
|
||||||
|
"""
|
||||||
|
return {key: key.get_names() for key in self._dict.keys()}
|
||||||
|
|
||||||
def __get_key(self, index: int) -> ConstOutput:
|
def __get_key(self, index: int) -> ConstOutput:
|
||||||
return list(self._dict.keys())[index]
|
return list(self._dict.keys())[index]
|
||||||
|
|
||||||
@singledispatchmethod
|
@singledispatchmethod
|
||||||
def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
|
def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
|
||||||
raise TypeError("Unknown key type!")
|
raise TypeError(f"Unknown key type: {type(key)}")
|
||||||
|
|
||||||
@__getitem_impl.register
|
@__getitem_impl.register
|
||||||
def _(self, key: ConstOutput) -> np.ndarray:
|
def _(self, key: ConstOutput) -> np.ndarray:
|
||||||
@ -101,9 +110,11 @@ class OVDict(Mapping):
|
|||||||
|
|
||||||
@__getitem_impl.register
|
@__getitem_impl.register
|
||||||
def _(self, key: str) -> np.ndarray:
|
def _(self, key: str) -> np.ndarray:
|
||||||
try:
|
if self._names is None:
|
||||||
return self._dict[self.__get_key(self.names().index(key))]
|
self._names = self.__get_names()
|
||||||
except ValueError:
|
for port, port_names in self._names.items():
|
||||||
|
if key in port_names:
|
||||||
|
return self._dict[port]
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
|
|
||||||
def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
|
def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray:
|
||||||
@ -118,12 +129,14 @@ class OVDict(Mapping):
|
|||||||
def items(self) -> ItemsView[ConstOutput, np.ndarray]:
|
def items(self) -> ItemsView[ConstOutput, np.ndarray]:
|
||||||
return self._dict.items()
|
return self._dict.items()
|
||||||
|
|
||||||
def names(self) -> List[str]:
|
def names(self) -> Tuple[Set[str], ...]:
|
||||||
"""Return a name of every output key.
|
"""Return names of every output key.
|
||||||
|
|
||||||
Throws RuntimeError if any of ConstOutput keys has no name.
|
Insert empty set if key has no name.
|
||||||
"""
|
"""
|
||||||
return [key.get_any_name() for key in self._dict.keys()]
|
if self._names is None:
|
||||||
|
self._names = self.__get_names()
|
||||||
|
return tuple(self._names.values())
|
||||||
|
|
||||||
def to_dict(self) -> Dict[ConstOutput, np.ndarray]:
|
def to_dict(self) -> Dict[ConstOutput, np.ndarray]:
|
||||||
"""Return underlaying native dictionary.
|
"""Return underlaying native dictionary.
|
||||||
|
@ -99,7 +99,7 @@ def _check_dict(result, obj, output_names=None):
|
|||||||
assert _check_keys(result.keys(), outs)
|
assert _check_keys(result.keys(), outs)
|
||||||
assert _check_values(result)
|
assert _check_values(result)
|
||||||
assert _check_items(result, outs, output_names)
|
assert _check_items(result, outs, output_names)
|
||||||
assert result.names() == output_names
|
assert all([output_names[i] in result.names()[i] for i in range(0, len(output_names))])
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -124,6 +124,15 @@ def test_ovdict_single_output_basic(device, is_direct):
|
|||||||
raise TypeError("Unknown `obj` type!")
|
raise TypeError("Unknown `obj` type!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("is_direct", [True, False])
|
||||||
|
def test_ovdict_wrong_key_type(device, is_direct):
|
||||||
|
result, _ = _get_ovdict(device, multi_output=False, direct_infer=is_direct)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as e:
|
||||||
|
_ = result[2.0]
|
||||||
|
assert "Unknown key type: <class 'float'>" in str(e.value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_direct", [True, False])
|
@pytest.mark.parametrize("is_direct", [True, False])
|
||||||
def test_ovdict_single_output_noname(device, is_direct):
|
def test_ovdict_single_output_noname(device, is_direct):
|
||||||
result, obj = _get_ovdict(
|
result, obj = _get_ovdict(
|
||||||
@ -140,13 +149,13 @@ def test_ovdict_single_output_noname(device, is_direct):
|
|||||||
assert isinstance(result[outs[0]], np.ndarray)
|
assert isinstance(result[outs[0]], np.ndarray)
|
||||||
assert isinstance(result[0], np.ndarray)
|
assert isinstance(result[0], np.ndarray)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as e0:
|
with pytest.raises(KeyError) as e0:
|
||||||
_ = result["some_name"]
|
_ = result["some_name"]
|
||||||
assert "Attempt to get a name for a Tensor without names" in str(e0.value)
|
assert "some_name" in str(e0.value)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as e1:
|
# Check if returned names are tuple with one empty set
|
||||||
_ = result.names()
|
assert len(result.names()) == 1
|
||||||
assert "Attempt to get a name for a Tensor without names" in str(e1.value)
|
assert result.names()[0] == set()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("is_direct", [True, False])
|
@pytest.mark.parametrize("is_direct", [True, False])
|
||||||
|
Loading…
Reference in New Issue
Block a user