From acc92ff6151c316967fb23f7246214cb440b3693 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Tue, 23 Apr 2024 06:55:24 +0200 Subject: [PATCH] Use ``TypeGuard`` in ``sphinx.util.inspect`` (#12283) Co-authored-by: Adam Turner <9087854+aa-turner@users.noreply.github.com> --- sphinx/ext/autodoc/mock.py | 4 +- sphinx/util/inspect.py | 77 ++++++++++++++++++++++++++++---------- sphinx/util/typing.py | 18 ++++----- 3 files changed, 69 insertions(+), 30 deletions(-) diff --git a/sphinx/ext/autodoc/mock.py b/sphinx/ext/autodoc/mock.py index f17c3302c..7639c4626 100644 --- a/sphinx/ext/autodoc/mock.py +++ b/sphinx/ext/autodoc/mock.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence from typing import Any + from typing_extensions import TypeGuard + logger = logging.getLogger(__name__) @@ -154,7 +156,7 @@ def mock(modnames: list[str]) -> Iterator[None]: finder.invalidate_caches() -def ismockmodule(subject: Any) -> bool: +def ismockmodule(subject: Any) -> TypeGuard[_MockModule]: """Check if the object is a mocked module.""" return isinstance(subject, _MockModule) diff --git a/sphinx/util/inspect.py b/sphinx/util/inspect.py index dfd1d01ef..da487a05a 100644 --- a/sphinx/util/inspect.py +++ b/sphinx/util/inspect.py @@ -27,7 +27,36 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence from inspect import _ParameterKind from types import MethodType, ModuleType - from typing import Final + from typing import Final, Protocol, Union + + from typing_extensions import TypeAlias, TypeGuard + + class _SupportsGet(Protocol): + def __get__(self, __instance: Any, __owner: type | None = ...) -> Any: ... # NoQA: E704 + + class _SupportsSet(Protocol): + # instance and value are contravariants but we do not need that precision + def __set__(self, __instance: Any, __value: Any) -> None: ... # NoQA: E704 + + class _SupportsDelete(Protocol): + # instance is contravariant but we do not need that precision + def __delete__(self, __instance: Any) -> None: ... # NoQA: E704 + + _RoutineType: TypeAlias = Union[ + types.FunctionType, + types.LambdaType, + types.MethodType, + types.BuiltinFunctionType, + types.BuiltinMethodType, + types.WrapperDescriptorType, + types.MethodDescriptorType, + types.ClassMethodDescriptorType, + ] + _SignatureType: TypeAlias = Union[ + Callable[..., Any], + staticmethod, + classmethod, + ] logger = logging.getLogger(__name__) @@ -90,7 +119,7 @@ def unwrap_all(obj: Any, *, stop: Callable[[Any], bool] | None = None) -> Any: def getall(obj: Any) -> Sequence[str] | None: - """Get the ``__all__`` attribute of an object as sequence. + """Get the ``__all__`` attribute of an object as a sequence. This returns ``None`` if the given ``obj.__all__`` does not exist and raises :exc:`ValueError` if ``obj.__all__`` is not a list or tuple of @@ -184,12 +213,12 @@ def isNewType(obj: Any) -> bool: return __module__ == 'typing' and __qualname__ == 'NewType..new_type' -def isenumclass(x: Any) -> bool: +def isenumclass(x: Any) -> TypeGuard[type[enum.Enum]]: """Check if the object is an :class:`enumeration class `.""" return isclass(x) and issubclass(x, enum.Enum) -def isenumattribute(x: Any) -> bool: +def isenumattribute(x: Any) -> TypeGuard[enum.Enum]: """Check if the object is an enumeration attribute.""" return isinstance(x, enum.Enum) @@ -206,12 +235,16 @@ def unpartial(obj: Any) -> Any: return obj -def ispartial(obj: Any) -> bool: +def ispartial(obj: Any) -> TypeGuard[partial | partialmethod]: """Check if the object is a partial function or method.""" return isinstance(obj, (partial, partialmethod)) -def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: +def isclassmethod( + obj: Any, + cls: Any = None, + name: str | None = None, +) -> TypeGuard[classmethod]: """Check if the object is a :class:`classmethod`.""" if isinstance(obj, classmethod): return True @@ -227,7 +260,11 @@ def isclassmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: return False -def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: +def isstaticmethod( + obj: Any, + cls: Any = None, + name: str | None = None, +) -> TypeGuard[staticmethod]: """Check if the object is a :class:`staticmethod`.""" if isinstance(obj, staticmethod): return True @@ -241,7 +278,7 @@ def isstaticmethod(obj: Any, cls: Any = None, name: str | None = None) -> bool: return False -def isdescriptor(x: Any) -> bool: +def isdescriptor(x: Any) -> TypeGuard[_SupportsGet | _SupportsSet | _SupportsDelete]: """Check if the object is a :external+python:term:`descriptor`.""" return any( callable(safe_getattr(x, item, None)) for item in ('__get__', '__set__', '__delete__') @@ -308,12 +345,12 @@ def is_singledispatch_function(obj: Any) -> bool: ) -def is_singledispatch_method(obj: Any) -> bool: +def is_singledispatch_method(obj: Any) -> TypeGuard[singledispatchmethod]: """Check if the object is a :class:`~functools.singledispatchmethod`.""" return isinstance(obj, singledispatchmethod) -def isfunction(obj: Any) -> bool: +def isfunction(obj: Any) -> TypeGuard[types.FunctionType]: """Check if the object is a user-defined function. Partial objects are unwrapped before checking them. @@ -323,7 +360,7 @@ def isfunction(obj: Any) -> bool: return inspect.isfunction(unpartial(obj)) -def isbuiltin(obj: Any) -> bool: +def isbuiltin(obj: Any) -> TypeGuard[types.BuiltinFunctionType]: """Check if the object is a built-in function or method. Partial objects are unwrapped before checking them. @@ -333,7 +370,7 @@ def isbuiltin(obj: Any) -> bool: return inspect.isbuiltin(unpartial(obj)) -def isroutine(obj: Any) -> bool: +def isroutine(obj: Any) -> TypeGuard[_RoutineType]: """Check if the object is a kind of function or method. Partial objects are unwrapped before checking them. @@ -343,7 +380,7 @@ def isroutine(obj: Any) -> bool: return inspect.isroutine(unpartial(obj)) -def iscoroutinefunction(obj: Any) -> bool: +def iscoroutinefunction(obj: Any) -> TypeGuard[Callable[..., types.CoroutineType]]: """Check if the object is a :external+python:term:`coroutine` function.""" obj = unwrap_all(obj, stop=_is_wrapped_coroutine) return inspect.iscoroutinefunction(obj) @@ -358,12 +395,12 @@ def _is_wrapped_coroutine(obj: Any) -> bool: return hasattr(obj, '__wrapped__') -def isproperty(obj: Any) -> bool: +def isproperty(obj: Any) -> TypeGuard[property | cached_property]: """Check if the object is property (possibly cached).""" return isinstance(obj, (property, cached_property)) -def isgenericalias(obj: Any) -> bool: +def isgenericalias(obj: Any) -> TypeGuard[types.GenericAlias]: """Check if the object is a generic alias.""" return isinstance(obj, (types.GenericAlias, typing._BaseGenericAlias)) # type: ignore[attr-defined] @@ -579,7 +616,7 @@ class TypeAliasNamespace(dict[str, Any]): raise KeyError -def _should_unwrap(subject: Callable[..., Any]) -> bool: +def _should_unwrap(subject: _SignatureType) -> bool: """Check the function should be unwrapped on getting signature.""" __globals__ = getglobals(subject) # contextmanger should be unwrapped @@ -590,7 +627,7 @@ def _should_unwrap(subject: Callable[..., Any]) -> bool: def signature( - subject: Callable[..., Any], + subject: _SignatureType, bound_method: bool = False, type_aliases: Mapping[str, str] | None = None, ) -> Signature: @@ -603,12 +640,12 @@ def signature( try: if _should_unwrap(subject): - signature = inspect.signature(subject) + signature = inspect.signature(subject) # type: ignore[arg-type] else: - signature = inspect.signature(subject, follow_wrapped=True) + signature = inspect.signature(subject, follow_wrapped=True) # type: ignore[arg-type] except ValueError: # follow built-in wrappers up (ex. functools.lru_cache) - signature = inspect.signature(subject) + signature = inspect.signature(subject) # type: ignore[arg-type] parameters = list(signature.parameters.values()) return_annotation = signature.return_annotation diff --git a/sphinx/util/typing.py b/sphinx/util/typing.py index 007adca9f..39056f91b 100644 --- a/sphinx/util/typing.py +++ b/sphinx/util/typing.py @@ -164,7 +164,7 @@ def is_system_TypeVar(typ: Any) -> bool: return modname == 'typing' and isinstance(typ, TypeVar) -def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typing') -> str: +def restify(cls: Any, mode: _RestifyMode = 'fully-qualified-except-typing') -> str: """Convert python class to a reST reference. :param mode: Specify a method how annotations will be stringified. @@ -229,17 +229,17 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin return f':py:class:`{cls.__name__}`' elif (inspect.isgenericalias(cls) and cls.__module__ == 'typing' - and cls.__origin__ is Union): # type: ignore[attr-defined] + and cls.__origin__ is Union): # *cls* is defined in ``typing``, and thus ``__args__`` must exist - return ' | '.join(restify(a, mode) for a in cls.__args__) # type: ignore[attr-defined] + return ' | '.join(restify(a, mode) for a in cls.__args__) elif inspect.isgenericalias(cls): - if isinstance(cls.__origin__, typing._SpecialForm): # type: ignore[attr-defined] - text = restify(cls.__origin__, mode) # type: ignore[attr-defined,arg-type] + if isinstance(cls.__origin__, typing._SpecialForm): + text = restify(cls.__origin__, mode) elif getattr(cls, '_name', None): - cls_name = cls._name # type: ignore[attr-defined] + cls_name = cls._name text = f':py:class:`{modprefix}{cls.__module__}.{cls_name}`' else: - text = restify(cls.__origin__, mode) # type: ignore[attr-defined] + text = restify(cls.__origin__, mode) origin = getattr(cls, '__origin__', None) if not hasattr(cls, '__args__'): # NoQA: SIM114 @@ -247,7 +247,7 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin elif all(is_system_TypeVar(a) for a in cls.__args__): # Suppress arguments if all system defined TypeVars (ex. Dict[KT, VT]) pass - elif cls.__module__ == 'typing' and cls._name == 'Callable': # type: ignore[attr-defined] + elif cls.__module__ == 'typing' and cls._name == 'Callable': args = ', '.join(restify(a, mode) for a in cls.__args__[:-1]) text += fr'\ [[{args}], {restify(cls.__args__[-1], mode)}]' elif cls.__module__ == 'typing' and getattr(origin, '_name', None) == 'Literal': @@ -259,7 +259,7 @@ def restify(cls: type | None, mode: _RestifyMode = 'fully-qualified-except-typin return text elif isinstance(cls, typing._SpecialForm): - return f':py:obj:`~{cls.__module__}.{cls._name}`' + return f':py:obj:`~{cls.__module__}.{cls._name}`' # type: ignore[attr-defined] elif sys.version_info[:2] >= (3, 11) and cls is typing.Any: # handle bpo-46998 return f':py:obj:`~{cls.__module__}.{cls.__name__}`'