Introduce types for the `setup()` function

This commit is contained in:
Adam Turner 2023-12-29 17:10:39 +00:00
parent 0145f95716
commit 20e804ab90
5 changed files with 31 additions and 9 deletions

View File

@ -7,7 +7,7 @@ import time
import traceback
import types
from os import getenv, path
from typing import TYPE_CHECKING, Any, Callable, Literal, NamedTuple
from typing import TYPE_CHECKING, Any, Literal, NamedTuple
from sphinx.errors import ConfigError, ExtensionError
from sphinx.locale import _, __
@ -27,6 +27,7 @@ if TYPE_CHECKING:
from sphinx.application import Sphinx
from sphinx.environment import BuildEnvironment
from sphinx.util.tags import Tags
from sphinx.util.typing import _ExtensionSetupFunc
logger = logging.getLogger(__name__)
@ -166,7 +167,7 @@ class Config:
self.overrides = dict(overrides) if overrides is not None else {}
self.values = Config.config_values.copy()
self._raw_config = raw_config
self.setup: Callable | None = raw_config.get('setup')
self.setup: _ExtensionSetupFunc | None = raw_config.get('setup')
if 'extensions' in self.overrides:
extensions = self.overrides.pop('extensions')

View File

@ -155,7 +155,7 @@ class BuildEnvironment:
self.config_status_extra: str = ''
self.events: EventManager = app.events
self.project: Project = app.project
self.version: dict[str, str] = app.registry.get_envversion(app)
self.version: dict[str, int] = app.registry.get_envversion(app)
# the method of doctree versioning; see set_versioning_method
self.versioning_condition: bool | Callable | None = None

View File

@ -13,6 +13,7 @@ from sphinx.util import logging
if TYPE_CHECKING:
from sphinx.application import Sphinx
from sphinx.config import Config
from sphinx.util.typing import _ExtensionMetadata
logger = logging.getLogger(__name__)
@ -21,7 +22,7 @@ class Extension:
def __init__(self, name: str, module: Any, **kwargs: Any) -> None:
self.name = name
self.module = module
self.metadata = kwargs
self.metadata: _ExtensionMetadata = kwargs # type: ignore[assignment]
self.version = kwargs.pop('version', 'unknown version')
# The extension supports parallel read or not. The default value

View File

@ -39,7 +39,12 @@ if TYPE_CHECKING:
from sphinx.config import Config
from sphinx.environment import BuildEnvironment
from sphinx.ext.autodoc import Documenter
from sphinx.util.typing import RoleFunction, TitleGetter
from sphinx.util.typing import (
RoleFunction,
TitleGetter,
_ExtensionMetadata,
_ExtensionSetupFunc,
)
logger = logging.getLogger(__name__)
@ -450,11 +455,11 @@ class SphinxComponentRegistry:
raise ExtensionError(__('Could not import extension %s') % extname,
err) from err
setup = getattr(mod, 'setup', None)
setup: _ExtensionSetupFunc | None = getattr(mod, 'setup', None)
if setup is None:
logger.warning(__('extension %r has no setup() function; is it really '
'a Sphinx extension module?'), extname)
metadata: dict[str, Any] = {}
metadata: _ExtensionMetadata = {}
else:
try:
metadata = setup(app)
@ -476,7 +481,7 @@ class SphinxComponentRegistry:
app.extensions[extname] = Extension(extname, mod, **metadata)
def get_envversion(self, app: Sphinx) -> dict[str, str]:
def get_envversion(self, app: Sphinx) -> dict[str, int]:
from sphinx.environment import ENV_VERSION
envversion = {ext.name: ext.metadata['env_version'] for ext in app.extensions.values()
if ext.metadata.get('env_version')}

View File

@ -7,7 +7,7 @@ import typing
from collections.abc import Sequence
from struct import Struct
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypedDict, TypeVar, Union
from docutils import nodes
from docutils.parsers.rst.states import Inliner
@ -15,6 +15,8 @@ from docutils.parsers.rst.states import Inliner
if TYPE_CHECKING:
import enum
from sphinx.application import Sphinx
if sys.version_info >= (3, 10):
from types import UnionType
else:
@ -64,6 +66,19 @@ InventoryItem = tuple[
Inventory = dict[str, dict[str, InventoryItem]]
# return of a setup() function
# https://www.sphinx-doc.org/en/master/extdev/index.html#extension-metadata
class _ExtensionMetadata(TypedDict, total=False):
version: str
env_version: int
parallel_read_safe: bool
parallel_write_safe: bool
if TYPE_CHECKING:
_ExtensionSetupFunc = Callable[[Sphinx], _ExtensionMetadata]
def get_type_hints(
obj: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None,
) -> dict[str, Any]: