Use `TypeGuard in sphinx.util.inspect` (#12283)

Co-authored-by: Adam Turner <9087854+aa-turner@users.noreply.github.com>
This commit is contained in:
Bénédikt Tran 2024-04-23 06:55:24 +02:00 committed by GitHub
parent b6948b8d74
commit acc92ff615
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 69 additions and 30 deletions

View File

@ -17,6 +17,8 @@ if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from typing import Any
from typing_extensions import TypeGuard
logger = logging.getLogger(__name__)
@ -154,7 +156,7 @@ def mock(modnames: list[str]) -> Iterator[None]:
finder.invalidate_caches()
def ismockmodule(subject: Any) -> bool:
def ismockmodule(subject: Any) -> TypeGuard[_MockModule]:
"""Check if the object is a mocked module."""
return isinstance(subject, _MockModule)

View File

@ -27,7 +27,36 @@ if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from inspect import _ParameterKind
from types import MethodType, ModuleType
from typing import Final
from typing import Final, Protocol, Union
from typing_extensions import TypeAlias, TypeGuard
class _SupportsGet(Protocol):
def __get__(self, __instance: Any, __owner: type | None = ...) -> Any: ... # NoQA: E704
class _SupportsSet(Protocol):
# instance and value are contravariants but we do not need that precision
def __set__(self, __instance: Any, __value: Any) -> None: ... # NoQA: E704
class _SupportsDelete(Protocol):
# instance is contravariant but we do not need that precision
def __delete__(self, __instance: Any) -> None: ... # NoQA: E704
_RoutineType: TypeAlias = Union[
types.FunctionType,
types.LambdaType,
types.MethodType,
types.BuiltinFunctionType,
types.BuiltinMethodType,
types.WrapperDescriptorType,
types.MethodDescriptorType,
types.ClassMethodDescriptorType,
]
_SignatureType: TypeAlias = Union[
Callable[..., Any],
staticmethod,
classmethod,
]
logger = logging.getLogger(__name__)
@ -90,7 +119,7 @@ def unwrap_all(obj: Any, *, stop: Callable[[Any], bool] | None = None) -> Any:
def getall(obj: Any) -> Sequence[str] | None:
"""Get the ``__all__`` attribute of an object as sequence.
"""Get the ``__all__`` attribute of an object as a sequence.
This returns ``None`` if the given ``obj.__all__`` does not exist and
raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of
@ -184,12 +213,12 @@ def isNewType(obj: Any) -> bool:
return __module__ == 'typing' and __qualname__ == 'NewType.<locals>.new_type'
def isenumclass(x: Any) -> bool:
def isenumclass(x: Any) -> TypeGuard[type[enum.Enum]]:
"""Check if the object is an :class:`enumeration class <enum.Enum>`."""
return isclass(x) and issubclass(x, enum.Enum)
def isenumattribute(x: Any) -> bool:
def isenumattribute(x: Any) -> TypeGuard[enum.Enum]:
"""Check if the object is an enumeration attribute."""
return isinstance(x, enum.Enum)
@ -206,12 +235,16 @@ def unpartial(obj: Any) -> Any:
return obj
def ispartial(obj: Any) -> bool:
def ispartial(obj: Any) -> TypeGuard[partial | partialmethod]:
"""Check if the object is a partial function or method."""
return isinstance(obj, (partial, partialmethod))
def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
def isclassmethod(
obj: Any,
cls: Any = None,
name: str | None = None,
) -> TypeGuard[classmethod]:
"""Check if the object is a :class:`classmethod`."""
if isinstance(obj, classmethod):
return True
@ -227,7 +260,11 @@ def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
return False
def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
def isstaticmethod(
obj: Any,
cls: Any = None,
name: str | None = None,
) -> TypeGuard[staticmethod]:
"""Check if the object is a :class:`staticmethod`."""
if isinstance(obj, staticmethod):
return True
@ -241,7 +278,7 @@ def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool:
return False
def isdescriptor(x: Any) -> bool:
def isdescriptor(x: Any) -> TypeGuard[_SupportsGet | _SupportsSet | _SupportsDelete]:
"""Check if the object is a :external+python:term:`descriptor`."""
return any(
callable(safe_getattr(x, item, None)) for item in ('__get__', '__set__', '__delete__')
@ -308,12 +345,12 @@ def is_singledispatch_function(obj: Any) -> bool:
)
def is_singledispatch_method(obj: Any) -> bool:
def is_singledispatch_method(obj: Any) -> TypeGuard[singledispatchmethod]:
"""Check if the object is a :class:`~functools.singledispatchmethod`."""
return isinstance(obj, singledispatchmethod)
def isfunction(obj: Any) -> bool:
def isfunction(obj: Any) -> TypeGuard[types.FunctionType]:
"""Check if the object is a user-defined function.
Partial objects are unwrapped before checking them.
@ -323,7 +360,7 @@ def isfunction(obj: Any) -> bool:
return inspect.isfunction(unpartial(obj))
def isbuiltin(obj: Any) -> bool:
def isbuiltin(obj: Any) -> TypeGuard[types.BuiltinFunctionType]:
"""Check if the object is a built-in function or method.
Partial objects are unwrapped before checking them.
@ -333,7 +370,7 @@ def isbuiltin(obj: Any) -> bool:
return inspect.isbuiltin(unpartial(obj))
def isroutine(obj: Any) -> bool:
def isroutine(obj: Any) -> TypeGuard[_RoutineType]:
"""Check if the object is a kind of function or method.
Partial objects are unwrapped before checking them.
@ -343,7 +380,7 @@ def isroutine(obj: Any) -> bool:
return inspect.isroutine(unpartial(obj))
def iscoroutinefunction(obj: Any) -> bool:
def iscoroutinefunction(obj: Any) -> TypeGuard[Callable[..., types.CoroutineType]]:
"""Check if the object is a :external+python:term:`coroutine` function."""
obj = unwrap_all(obj, stop=_is_wrapped_coroutine)
return inspect.iscoroutinefunction(obj)
@ -358,12 +395,12 @@ def _is_wrapped_coroutine(obj: Any) -> bool:
return hasattr(obj, '__wrapped__')
def isproperty(obj: Any) -> bool:
def isproperty(obj: Any) -> TypeGuard[property | cached_property]:
"""Check if the object is property (possibly cached)."""
return isinstance(obj, (property, cached_property))
def isgenericalias(obj: Any) -> bool:
def isgenericalias(obj: Any) -> TypeGuard[types.GenericAlias]:
"""Check if the object is a generic alias."""
return isinstance(obj, (types.GenericAlias, typing._BaseGenericAlias)) # type: ignore[attr-defined]
@ -579,7 +616,7 @@ class TypeAliasNamespace(dict[str, Any]):
raise KeyError
def _should_unwrap(subject: Callable[..., Any]) -> bool:
def _should_unwrap(subject: _SignatureType) -> bool:
"""Check the function should be unwrapped on getting signature."""
__globals__ = getglobals(subject)
# contextmanger should be unwrapped
@ -590,7 +627,7 @@ def _should_unwrap(subject: Callable[..., Any]) -> bool:
def signature(
subject: Callable[..., Any],
subject: _SignatureType,
bound_method: bool = False,
type_aliases: Mapping[str, str] | None = None,
) -> Signature:
@ -603,12 +640,12 @@ def signature(
try:
if _should_unwrap(subject):
signature = inspect.signature(subject)
signature = inspect.signature(subject) # type: ignore[arg-type]
else:
signature = inspect.signature(subject, follow_wrapped=True)
signature = inspect.signature(subject, follow_wrapped=True) # type: ignore[arg-type]
except ValueError:
# follow built-in wrappers up (ex. functools.lru_cache)
signature = inspect.signature(subject)
signature = inspect.signature(subject) # type: ignore[arg-type]
parameters = list(signature.parameters.values())
return_annotation = signature.return_annotation

View File

@ -164,7 +164,7 @@ def is_system_TypeVar(typ: Any) -> bool:
return modname == 'typing' and isinstance(typ, TypeVar)
def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typing') -> str:
def restify(cls: Any, mode: _RestifyMode = 'fully-qualified-except-typing') -> str:
"""Convert python class to a reST reference.
:param mode: Specify a method how annotations will be stringified.
@ -229,17 +229,17 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin
return f':py:class:`{cls.__name__}`'
elif (inspect.isgenericalias(cls)
and cls.__module__ == 'typing'
and cls.__origin__ is Union): # type: ignore[attr-defined]
and cls.__origin__ is Union):
# *cls* is defined in ``typing``, and thus ``__args__`` must exist
return ' | '.join(restify(a, mode) for a in cls.__args__) # type: ignore[attr-defined]
return ' | '.join(restify(a, mode) for a in cls.__args__)
elif inspect.isgenericalias(cls):
if isinstance(cls.__origin__, typing._SpecialForm): # type: ignore[attr-defined]
text = restify(cls.__origin__, mode) # type: ignore[attr-defined,arg-type]
if isinstance(cls.__origin__, typing._SpecialForm):
text = restify(cls.__origin__, mode)
elif getattr(cls, '_name', None):
cls_name = cls._name # type: ignore[attr-defined]
cls_name = cls._name
text = f':py:class:`{modprefix}{cls.__module__}.{cls_name}`'
else:
text = restify(cls.__origin__, mode) # type: ignore[attr-defined]
text = restify(cls.__origin__, mode)
origin = getattr(cls, '__origin__', None)
if not hasattr(cls, '__args__'): # NoQA: SIM114
@ -247,7 +247,7 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin
elif all(is_system_TypeVar(a) for a in cls.__args__):
# Suppress arguments if all system defined TypeVars (ex. Dict[KT, VT])
pass
elif cls.__module__ == 'typing' and cls._name == 'Callable': # type: ignore[attr-defined]
elif cls.__module__ == 'typing' and cls._name == 'Callable':
args = ', '.join(restify(a, mode) for a in cls.__args__[:-1])
text += fr'\ [[{args}], {restify(cls.__args__[-1], mode)}]'
elif cls.__module__ == 'typing' and getattr(origin, '_name', None) == 'Literal':
@ -259,7 +259,7 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin
return text
elif isinstance(cls, typing._SpecialForm):
return f':py:obj:`~{cls.__module__}.{cls._name}`'
return f':py:obj:`~{cls.__module__}.{cls._name}`' # type: ignore[attr-defined]
elif sys.version_info[:2] >= (3, 11) and cls is typing.Any:
# handle bpo-46998
return f':py:obj:`~{cls.__module__}.{cls.__name__}`'