Reimplement needs_extensions checker as a config-init handler

This commit is contained in:
Takeshi KOMIYA 2018-01-27 22:07:27 +09:00
parent 4647fcee45
commit f49a7c9024
2 changed files with 18 additions and 9 deletions

View File

@ -33,7 +33,6 @@ from sphinx.errors import (
ApplicationError, ConfigError, ExtensionError, VersionRequirementError ApplicationError, ConfigError, ExtensionError, VersionRequirementError
) )
from sphinx.events import EventManager from sphinx.events import EventManager
from sphinx.extension import verify_required_extensions
from sphinx.locale import __ from sphinx.locale import __
from sphinx.registry import SphinxComponentRegistry from sphinx.registry import SphinxComponentRegistry
from sphinx.util import import_object from sphinx.util import import_object
@ -86,6 +85,7 @@ builtin_extensions = (
'sphinx.directives.code', 'sphinx.directives.code',
'sphinx.directives.other', 'sphinx.directives.other',
'sphinx.directives.patches', 'sphinx.directives.patches',
'sphinx.extension',
'sphinx.io', 'sphinx.io',
'sphinx.parsers', 'sphinx.parsers',
'sphinx.registry', 'sphinx.registry',
@ -238,9 +238,6 @@ class Sphinx(object):
self.config.init_values() self.config.init_values()
self.emit('config-inited', self.config) self.emit('config-inited', self.config)
# check extension versions if requested
verify_required_extensions(self, self.config.needs_extensions)
# check primary_domain if requested # check primary_domain if requested
primary_domain = self.config.primary_domain primary_domain = self.config.primary_domain
if primary_domain and not self.registry.has_domain(primary_domain): if primary_domain and not self.registry.has_domain(primary_domain):

View File

@ -17,8 +17,9 @@ from sphinx.util import logging
if False: if False:
# For type annotation # For type annotation
from typing import Dict # NOQA from typing import Any, Dict # NOQA
from sphinx.application import Sphinx # NOQA from sphinx.application import Sphinx # NOQA
from sphinx.config import Config # NOQA
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,13 +42,13 @@ class Extension(object):
self.parallel_write_safe = kwargs.pop('parallel_write_safe', True) self.parallel_write_safe = kwargs.pop('parallel_write_safe', True)
def verify_required_extensions(app, requirements): def verify_needs_extensions(app, config):
# type: (Sphinx, Dict[unicode, unicode]) -> None # type: (Sphinx, Config) -> None
"""Verify the required Sphinx extensions are loaded.""" """Verify the required Sphinx extensions are loaded."""
if requirements is None: if config.needs_extensions is None:
return return
for extname, reqversion in iteritems(requirements): for extname, reqversion in iteritems(config.needs_extensions):
extension = app.extensions.get(extname) extension = app.extensions.get(extname)
if extension is None: if extension is None:
logger.warning(__('The %s extension is required by needs_extensions settings,' logger.warning(__('The %s extension is required by needs_extensions settings,'
@ -59,3 +60,14 @@ def verify_required_extensions(app, requirements):
'version %s and therefore cannot be built with ' 'version %s and therefore cannot be built with '
'the loaded version (%s).') % 'the loaded version (%s).') %
(extname, reqversion, extension.version)) (extname, reqversion, extension.version))
def setup(app):
# type: (Sphinx) -> Dict[unicode, Any]
app.connect('config-inited', verify_needs_extensions)
return {
'version': 'builtin',
'parallel_read_safe': True,
'parallel_write_safe': True,
}