From 3a0a9f659e436f1f38d12a784cb01379e20f139f Mon Sep 17 00:00:00 2001 From: Takeshi KOMIYA Date: Sat, 16 Feb 2019 23:40:45 +0900 Subject: [PATCH] refactor: import_object() allows to import nested objects (ex. nested class) --- sphinx/util/__init__.py | 34 +++++++++++++++++++--------------- tests/test_util.py | 24 ++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/sphinx/util/__init__.py b/sphinx/util/__init__.py index bfe99778a..9c78630ff 100644 --- a/sphinx/util/__init__.py +++ b/sphinx/util/__init__.py @@ -603,22 +603,26 @@ class PeekableIterator: def import_object(objname, source=None): # type: (str, str) -> Any + """Import python object by qualname.""" try: - module, name = objname.rsplit('.', 1) - except ValueError as err: - raise ExtensionError('Invalid full object name %s' % objname + - (source and ' (needed for %s)' % source or ''), - err) - try: - return getattr(__import__(module, None, None, [name]), name) - except ImportError as err: - raise ExtensionError('Could not import %s' % module + - (source and ' (needed for %s)' % source or ''), - err) - except AttributeError as err: - raise ExtensionError('Could not find %s' % objname + - (source and ' (needed for %s)' % source or ''), - err) + objpath = objname.split('.') + modname = objpath.pop(0) + obj = __import__(modname) + for name in objpath: + modname += '.' + name + try: + obj = getattr(obj, name) + except AttributeError: + __import__(modname) + obj = getattr(obj, name) + + return obj + except (AttributeError, ImportError) as exc: + if source: + raise ExtensionError('Could not import %s (needed for %s)' % + (objname, source), exc) + else: + raise ExtensionError('Could not import %s' % objname, exc) def encode_uri(uri): diff --git a/tests/test_util.py b/tests/test_util.py index 0926096f4..bcac4b654 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -15,11 +15,11 @@ import pytest from mock import patch import sphinx -from sphinx.errors import PycodeError +from sphinx.errors import ExtensionError, PycodeError from sphinx.testing.util import strip_escseq from sphinx.util import ( SkipProgressMessage, display_chunk, encode_uri, ensuredir, get_module_source, - parselinenos, progress_message, status_iterator, xmlname_checker + import_object, parselinenos, progress_message, status_iterator, xmlname_checker ) from sphinx.util import logging @@ -71,6 +71,26 @@ def test_get_module_source(): get_module_source('itertools') +def test_import_object(): + module = import_object('sphinx') + assert module.__name__ == 'sphinx' + + module = import_object('sphinx.application') + assert module.__name__ == 'sphinx.application' + + obj = import_object('sphinx.application.Sphinx') + assert obj.__name__ == 'Sphinx' + + with pytest.raises(ExtensionError) as exc: + import_object('sphinx.unknown_module') + assert exc.value.args[0] == 'Could not import sphinx.unknown_module' + + with pytest.raises(ExtensionError) as exc: + import_object('sphinx.unknown_module', 'my extension') + assert exc.value.args[0] == ('Could not import sphinx.unknown_module ' + '(needed for my extension)') + + @pytest.mark.sphinx('dummy') @patch('sphinx.util.console._tw', 40) # terminal width = 40 def test_status_iterator(app, status, warning):