diff --git a/src/bindings/python/src/openvino/runtime/utils/data_helpers/wrappers.py b/src/bindings/python/src/openvino/runtime/utils/data_helpers/wrappers.py index e2849b8d5e0..1bf23a7cad4 100644 --- a/src/bindings/python/src/openvino/runtime/utils/data_helpers/wrappers.py +++ b/src/bindings/python/src/openvino/runtime/utils/data_helpers/wrappers.py @@ -12,7 +12,8 @@ except ImportError: from singledispatchmethod import singledispatchmethod # type: ignore[no-redef] from collections.abc import Mapping -from typing import Union, Dict, List, Iterator, KeysView, ItemsView, ValuesView +from typing import Dict, Set, Tuple, Union, Iterator, Optional +from typing import KeysView, ItemsView, ValuesView from openvino._pyopenvino import Tensor, ConstOutput from openvino._pyopenvino import InferRequest as InferRequestBase @@ -71,6 +72,7 @@ class OVDict(Mapping): """ def __init__(self, _dict: Dict[ConstOutput, np.ndarray]) -> None: self._dict = _dict + self._names: Optional[Dict[ConstOutput, Set[str]]] = None def __iter__(self) -> Iterator: return self._dict.__iter__() @@ -81,12 +83,19 @@ class OVDict(Mapping): def __repr__(self) -> str: return self._dict.__repr__() + def __get_names(self) -> Dict[ConstOutput, Set[str]]: + """Return names of every output key. + + Insert empty set if key has no name. + """ + return {key: key.get_names() for key in self._dict.keys()} + def __get_key(self, index: int) -> ConstOutput: return list(self._dict.keys())[index] @singledispatchmethod def __getitem_impl(self, key: Union[ConstOutput, int, str]) -> np.ndarray: - raise TypeError("Unknown key type!") + raise TypeError(f"Unknown key type: {type(key)}") @__getitem_impl.register def _(self, key: ConstOutput) -> np.ndarray: @@ -101,10 +110,12 @@ class OVDict(Mapping): @__getitem_impl.register def _(self, key: str) -> np.ndarray: - try: - return self._dict[self.__get_key(self.names().index(key))] - except ValueError: - raise KeyError(key) + if self._names is None: + self._names = self.__get_names() + for port, port_names in self._names.items(): + if key in port_names: + return self._dict[port] + raise KeyError(key) def __getitem__(self, key: Union[ConstOutput, int, str]) -> np.ndarray: return self.__getitem_impl(key) @@ -118,12 +129,14 @@ class OVDict(Mapping): def items(self) -> ItemsView[ConstOutput, np.ndarray]: return self._dict.items() - def names(self) -> List[str]: - """Return a name of every output key. + def names(self) -> Tuple[Set[str], ...]: + """Return names of every output key. - Throws RuntimeError if any of ConstOutput keys has no name. + Insert empty set if key has no name. """ - return [key.get_any_name() for key in self._dict.keys()] + if self._names is None: + self._names = self.__get_names() + return tuple(self._names.values()) def to_dict(self) -> Dict[ConstOutput, np.ndarray]: """Return underlaying native dictionary. diff --git a/src/bindings/python/tests/test_runtime/test_ovdict.py b/src/bindings/python/tests/test_runtime/test_ovdict.py index e8c76a6d8d3..460cf68d73f 100644 --- a/src/bindings/python/tests/test_runtime/test_ovdict.py +++ b/src/bindings/python/tests/test_runtime/test_ovdict.py @@ -99,7 +99,7 @@ def _check_dict(result, obj, output_names=None): assert _check_keys(result.keys(), outs) assert _check_values(result) assert _check_items(result, outs, output_names) - assert result.names() == output_names + assert all([output_names[i] in result.names()[i] for i in range(0, len(output_names))]) return True @@ -124,6 +124,15 @@ def test_ovdict_single_output_basic(device, is_direct): raise TypeError("Unknown `obj` type!") +@pytest.mark.parametrize("is_direct", [True, False]) +def test_ovdict_wrong_key_type(device, is_direct): + result, _ = _get_ovdict(device, multi_output=False, direct_infer=is_direct) + + with pytest.raises(TypeError) as e: + _ = result[2.0] + assert "Unknown key type: " in str(e.value) + + @pytest.mark.parametrize("is_direct", [True, False]) def test_ovdict_single_output_noname(device, is_direct): result, obj = _get_ovdict( @@ -140,13 +149,13 @@ def test_ovdict_single_output_noname(device, is_direct): assert isinstance(result[outs[0]], np.ndarray) assert isinstance(result[0], np.ndarray) - with pytest.raises(RuntimeError) as e0: + with pytest.raises(KeyError) as e0: _ = result["some_name"] - assert "Attempt to get a name for a Tensor without names" in str(e0.value) + assert "some_name" in str(e0.value) - with pytest.raises(RuntimeError) as e1: - _ = result.names() - assert "Attempt to get a name for a Tensor without names" in str(e1.value) + # Check if returned names are tuple with one empty set + assert len(result.names()) == 1 + assert result.names()[0] == set() @pytest.mark.parametrize("is_direct", [True, False])