Merge pull request #4299 from tk0miya/refactor_MockImporter

autodoc: refactor MockImporter
This commit is contained in:
Takeshi KOMIYA 2017-12-14 22:13:16 +09:00 committed by GitHub
commit 9d44cb5952
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 44 deletions

View File

@ -24,7 +24,8 @@ from docutils.parsers.rst import Directive
from docutils.statemachine import ViewList from docutils.statemachine import ViewList
import sphinx import sphinx
from sphinx.ext.autodoc.importer import _MockImporter, import_module from sphinx.ext.autodoc.importer import mock, import_module
from sphinx.ext.autodoc.importer import _MockImporter # to keep compatibility # NOQA
from sphinx.ext.autodoc.inspector import format_annotation, formatargspec # to keep compatibility # NOQA from sphinx.ext.autodoc.inspector import format_annotation, formatargspec # to keep compatibility # NOQA
from sphinx.util import rpartition, force_decode from sphinx.util import rpartition, force_decode
from sphinx.locale import _ from sphinx.locale import _
@ -388,7 +389,7 @@ class Documenter(object):
self.modname, '.'.join(self.objpath)) self.modname, '.'.join(self.objpath))
# always enable mock import hook # always enable mock import hook
# it will do nothing if autodoc_mock_imports is empty # it will do nothing if autodoc_mock_imports is empty
import_hook = _MockImporter(self.env.config.autodoc_mock_imports) with mock(self.env.config.autodoc_mock_imports):
try: try:
logger.debug('[autodoc] import %s', self.modname) logger.debug('[autodoc] import %s', self.modname)
obj = import_module(self.modname, self.env.config.autodoc_warningiserror) obj = import_module(self.modname, self.env.config.autodoc_warningiserror)
@ -427,8 +428,6 @@ class Documenter(object):
self.directive.warn(errmsg) self.directive.warn(errmsg)
self.env.note_reread() self.env.note_reread()
return False return False
finally:
import_hook.disable()
def get_real_modname(self): def get_real_modname(self):
# type: () -> str # type: () -> str

View File

@ -10,15 +10,16 @@
""" """
import sys import sys
import traceback
import warnings import warnings
import traceback
import contextlib
from types import FunctionType, MethodType, ModuleType from types import FunctionType, MethodType, ModuleType
from sphinx.util import logging from sphinx.util import logging
if False: if False:
# For type annotation # For type annotation
from typing import Any, List, Set # NOQA from typing import Any, Generator, List, Set # NOQA
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -77,7 +78,6 @@ class _MockModule(ModuleType):
class _MockImporter(object): class _MockImporter(object):
def __init__(self, names): def __init__(self, names):
# type: (List[str]) -> None # type: (List[str]) -> None
self.base_packages = set() # type: Set[str] self.base_packages = set() # type: Set[str]
@ -120,6 +120,16 @@ class _MockImporter(object):
return module return module
@contextlib.contextmanager
def mock(names):
# type: (List[str]) -> Generator
try:
importer = _MockImporter(names)
yield
finally:
importer.disable()
def import_module(modname, warningiserror=False): def import_module(modname, warningiserror=False):
""" """
Call __import__(modname), convert exceptions to ImportError Call __import__(modname), convert exceptions to ImportError