Migrate to py3 style type annotation: sphinx.util

This commit is contained in:
Takeshi KOMIYA 2019-05-30 01:19:21 +09:00
parent 27dd8367c6
commit afbf6d811d
3 changed files with 71 additions and 108 deletions

View File

@ -24,6 +24,9 @@ from datetime import datetime
from hashlib import md5
from os import path
from time import mktime, strptime
from typing import (
Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Set, Tuple, Type
)
from urllib.parse import urlsplit, urlunsplit, quote_plus, parse_qsl, urlencode
from docutils.utils import relative_path
@ -34,6 +37,7 @@ from sphinx.locale import __
from sphinx.util import logging
from sphinx.util.console import strip_colors, colorize, bold, term_width_line # type: ignore
from sphinx.util.fileutil import copy_asset_file
from sphinx.util.typing import PathMatcher
from sphinx.util import smartypants # noqa
# import other utilities; partly for backwards compatibility, so don't
@ -46,10 +50,11 @@ from sphinx.util.nodes import ( # noqa
caption_ref_re)
from sphinx.util.matching import patfilter # noqa
if False:
# For type annotation
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Set, Tuple, Type, Union # NOQA
from sphinx.application import Sphinx
from sphinx.builders import Builder
logger = logging.getLogger(__name__)
@ -60,21 +65,19 @@ url_re = re.compile(r'(?P<schema>.+)://.*') # type: Pattern
# High-level utility functions.
def docname_join(basedocname, docname):
# type: (str, str) -> str
def docname_join(basedocname: str, docname: str) -> str:
return posixpath.normpath(
posixpath.join('/' + basedocname, '..', docname))[1:]
def path_stabilize(filepath):
# type: (str) -> str
def path_stabilize(filepath: str) -> str:
"normalize path separater and unicode string"
newpath = filepath.replace(os.path.sep, SEP)
return unicodedata.normalize('NFC', newpath)
def get_matching_files(dirname, exclude_matchers=()):
# type: (str, Tuple[Callable[[str], bool], ...]) -> Iterable[str]
def get_matching_files(dirname: str,
exclude_matchers: Tuple[PathMatcher, ...] = ()) -> Iterable[str]: # NOQA
"""Get all file names in a directory, recursively.
Exclude files and dirs matching some matcher in *exclude_matchers*.
@ -100,8 +103,8 @@ def get_matching_files(dirname, exclude_matchers=()):
yield filename
def get_matching_docs(dirname, suffixes, exclude_matchers=()):
# type: (str, List[str], Tuple[Callable[[str], bool], ...]) -> Iterable[str] # NOQA
def get_matching_docs(dirname: str, suffixes: List[str],
exclude_matchers: Tuple[PathMatcher, ...] = ()) -> Iterable[str]:
"""Get all file names (without suffixes) matching a suffix in a directory,
recursively.
@ -123,12 +126,10 @@ class FilenameUniqDict(dict):
interpreted as filenames, and keeps track of a set of docnames they
appear in. Used for images and downloadable files in the environment.
"""
def __init__(self):
# type: () -> None
def __init__(self) -> None:
self._existing = set() # type: Set[str]
def add_file(self, docname, newfile):
# type: (str, str) -> str
def add_file(self, docname: str, newfile: str) -> str:
if newfile in self:
self[newfile][0].add(docname)
return self[newfile][1]
@ -142,26 +143,22 @@ class FilenameUniqDict(dict):
self._existing.add(uniquename)
return uniquename
def purge_doc(self, docname):
# type: (str) -> None
def purge_doc(self, docname: str) -> None:
for filename, (docs, unique) in list(self.items()):
docs.discard(docname)
if not docs:
del self[filename]
self._existing.discard(unique)
def merge_other(self, docnames, other):
# type: (Set[str], Dict[str, Tuple[Set[str], Any]]) -> None
def merge_other(self, docnames: Set[str], other: Dict[str, Tuple[Set[str], Any]]) -> None:
for filename, (docs, unique) in other.items():
for doc in docs & set(docnames):
self.add_file(doc, filename)
def __getstate__(self):
# type: () -> Set[str]
def __getstate__(self) -> Set[str]:
return self._existing
def __setstate__(self, state):
# type: (Set[str]) -> None
def __setstate__(self, state: Set[str]) -> None:
self._existing = state
@ -172,8 +169,7 @@ class DownloadFiles(dict):
Hence don't hack this directly.
"""
def add_file(self, docname, filename):
# type: (str, str) -> None
def add_file(self, docname: str, filename: str) -> None:
if filename not in self:
digest = md5(filename.encode()).hexdigest()
dest = '%s/%s' % (digest, os.path.basename(filename))
@ -182,23 +178,20 @@ class DownloadFiles(dict):
self[filename][0].add(docname)
return self[filename][1]
def purge_doc(self, docname):
# type: (str) -> None
def purge_doc(self, docname: str) -> None:
for filename, (docs, dest) in list(self.items()):
docs.discard(docname)
if not docs:
del self[filename]
def merge_other(self, docnames, other):
# type: (Set[str], Dict[str, Tuple[Set[str], Any]]) -> None
def merge_other(self, docnames: Set[str], other: Dict[str, Tuple[Set[str], Any]]) -> None:
for filename, (docs, dest) in other.items():
for docname in docs & set(docnames):
self.add_file(docname, filename)
def copy_static_entry(source, targetdir, builder, context={},
exclude_matchers=(), level=0):
# type: (str, str, Any, Dict, Tuple[Callable, ...], int) -> None
def copy_static_entry(source: str, targetdir: str, builder: "Builder", context: Dict = {},
exclude_matchers: Tuple[PathMatcher, ...] = (), level: int = 0) -> None:
"""[DEPRECATED] Copy a HTML builder static_path entry from source to targetdir.
Handles all possible cases of files, directories and subdirectories.
@ -237,8 +230,7 @@ _DEBUG_HEADER = '''\
'''
def save_traceback(app):
# type: (Any) -> str
def save_traceback(app: "Sphinx") -> str:
"""Save the current exception's traceback in a temporary file."""
import sphinx
import jinja2
@ -273,8 +265,7 @@ def save_traceback(app):
return path
def get_module_source(modname):
# type: (str) -> Tuple[str, str]
def get_module_source(modname: str) -> Tuple[str, str]:
"""Try to find the source code for a module.
Can return ('file', 'filename') in which case the source is in the given
@ -321,8 +312,7 @@ def get_module_source(modname):
return 'file', filename
def get_full_modname(modname, attribute):
# type: (str, str) -> str
def get_full_modname(modname: str, attribute: str) -> str:
if modname is None:
# Prevents a TypeError: if the last getattr() call will return None
# then it's better to return it directly
@ -344,8 +334,7 @@ def get_full_modname(modname, attribute):
_coding_re = re.compile(r'coding[:=]\s*([-\w.]+)')
def detect_encoding(readline):
# type: (Callable[[], bytes]) -> str
def detect_encoding(readline: Callable[[], bytes]) -> str:
"""Like tokenize.detect_encoding() from Py3k, but a bit simplified."""
def read_or_stop():
@ -401,12 +390,10 @@ def detect_encoding(readline):
class UnicodeDecodeErrorHandler:
"""Custom error handler for open() that warns and replaces."""
def __init__(self, docname):
# type: (str) -> None
def __init__(self, docname: str) -> None:
self.docname = docname
def __call__(self, error):
# type: (UnicodeDecodeError) -> Tuple[Union[str, str], int]
def __call__(self, error: UnicodeDecodeError) -> Tuple[str, int]:
linestart = error.object.rfind(b'\n', 0, error.start)
lineend = error.object.find(b'\n', error.start)
if lineend == -1:
@ -426,26 +413,22 @@ class Tee:
"""
File-like object writing to two streams.
"""
def __init__(self, stream1, stream2):
# type: (IO, IO) -> None
def __init__(self, stream1: IO, stream2: IO) -> None:
self.stream1 = stream1
self.stream2 = stream2
def write(self, text):
# type: (str) -> None
def write(self, text: str) -> None:
self.stream1.write(text)
self.stream2.write(text)
def flush(self):
# type: () -> None
def flush(self) -> None:
if hasattr(self.stream1, 'flush'):
self.stream1.flush()
if hasattr(self.stream2, 'flush'):
self.stream2.flush()
def parselinenos(spec, total):
# type: (str, int) -> List[int]
def parselinenos(spec: str, total: int) -> List[int]:
"""Parse a line number spec (such as "1,2,4-6") and return a list of
wanted line numbers.
"""
@ -472,8 +455,7 @@ def parselinenos(spec, total):
return items
def force_decode(string, encoding):
# type: (str, str) -> str
def force_decode(string: str, encoding: str) -> str:
"""Forcibly get a unicode string out of a bytestring."""
warnings.warn('force_decode() is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
@ -491,26 +473,22 @@ def force_decode(string, encoding):
class attrdict(dict):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
warnings.warn('The attrdict class is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
def __getattr__(self, key):
# type: (str) -> str
def __getattr__(self, key: str) -> str:
return self[key]
def __setattr__(self, key, val):
# type: (str, str) -> None
def __setattr__(self, key: str, val: str) -> None:
self[key] = val
def __delattr__(self, key):
# type: (str) -> None
def __delattr__(self, key: str) -> None:
del self[key]
def rpartition(s, t):
# type: (str, str) -> Tuple[str, str]
def rpartition(s: str, t: str) -> Tuple[str, str]:
"""Similar to str.rpartition from 2.5, but doesn't return the separator."""
i = s.rfind(t)
if i != -1:
@ -518,8 +496,7 @@ def rpartition(s, t):
return '', s
def split_into(n, type, value):
# type: (int, str, str) -> List[str]
def split_into(n: int, type: str, value: str) -> List[str]:
"""Split an index entry into a given number of parts at semicolons."""
parts = [x.strip() for x in value.split(';', n - 1)]
if sum(1 for part in parts if part) < n:
@ -527,8 +504,7 @@ def split_into(n, type, value):
return parts
def split_index_msg(type, value):
# type: (str, str) -> List[str]
def split_index_msg(type: str, value: str) -> List[str]:
# new entry types must be listed in directives/other.py!
if type == 'single':
try:
@ -549,8 +525,7 @@ def split_index_msg(type, value):
return result
def format_exception_cut_frames(x=1):
# type: (int) -> str
def format_exception_cut_frames(x: int = 1) -> str:
"""Format an exception with traceback, but only the last x frames."""
typ, val, tb = sys.exc_info()
# res = ['Traceback (most recent call last):\n']
@ -566,19 +541,16 @@ class PeekableIterator:
An iterator which wraps any iterable and makes it possible to peek to see
what's the next item.
"""
def __init__(self, iterable):
# type: (Iterable) -> None
def __init__(self, iterable: Iterable) -> None:
self.remaining = deque() # type: deque
self._iterator = iter(iterable)
warnings.warn('PeekableIterator is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
def __iter__(self):
# type: () -> PeekableIterator
def __iter__(self) -> "PeekableIterator":
return self
def __next__(self):
# type: () -> Any
def __next__(self) -> Any:
"""Return the next item from the iterator."""
if self.remaining:
return self.remaining.popleft()
@ -586,23 +558,20 @@ class PeekableIterator:
next = __next__ # Python 2 compatibility
def push(self, item):
# type: (Any) -> None
def push(self, item: Any) -> None:
"""Push the `item` on the internal stack, it will be returned on the
next :meth:`next` call.
"""
self.remaining.append(item)
def peek(self):
# type: () -> Any
def peek(self) -> Any:
"""Return the next item without changing the state of the iterator."""
item = next(self)
self.push(item)
return item
def import_object(objname, source=None):
# type: (str, str) -> Any
def import_object(objname: str, source: str = None) -> Any:
"""Import python object by qualname."""
try:
objpath = objname.split('.')
@ -625,8 +594,7 @@ def import_object(objname, source=None):
raise ExtensionError('Could not import %s' % objname, exc)
def encode_uri(uri):
# type: (str) -> str
def encode_uri(uri: str) -> str:
split = list(urlsplit(uri))
split[1] = split[1].encode('idna').decode('ascii')
split[2] = quote_plus(split[2].encode(), '/')
@ -635,8 +603,7 @@ def encode_uri(uri):
return urlunsplit(split)
def display_chunk(chunk):
# type: (Any) -> str
def display_chunk(chunk: Any) -> str:
if isinstance(chunk, (list, tuple)):
if len(chunk) == 1:
return str(chunk[0])
@ -644,8 +611,8 @@ def display_chunk(chunk):
return str(chunk)
def old_status_iterator(iterable, summary, color="darkgreen", stringify_func=display_chunk):
# type: (Iterable, str, str, Callable[[Any], str]) -> Iterator
def old_status_iterator(iterable: Iterable, summary: str, color: str = "darkgreen",
stringify_func: Callable[[Any], str] = display_chunk) -> Iterator:
l = 0
for item in iterable:
if l == 0:
@ -659,9 +626,9 @@ def old_status_iterator(iterable, summary, color="darkgreen", stringify_func=dis
# new version with progress info
def status_iterator(iterable, summary, color="darkgreen", length=0, verbosity=0,
stringify_func=display_chunk):
# type: (Iterable, str, str, int, int, Callable[[Any], str]) -> Iterable
def status_iterator(iterable: Iterable, summary: str, color: str = "darkgreen",
length: int = 0, verbosity: int = 0,
stringify_func: Callable[[Any], str] = display_chunk) -> Iterable:
if length == 0:
yield from old_status_iterator(iterable, summary, color, stringify_func)
return
@ -685,16 +652,13 @@ class SkipProgressMessage(Exception):
class progress_message:
def __init__(self, message):
# type: (str) -> None
def __init__(self, message: str) -> None:
self.message = message
def __enter__(self):
# type: () -> None
def __enter__(self) -> None:
logger.info(bold(self.message + '... '), nonl=True)
def __exit__(self, exc_type, exc_value, traceback):
# type: (Type[Exception], Exception, Any) -> bool
def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: Any) -> bool: # NOQA
if isinstance(exc_value, SkipProgressMessage):
logger.info(__('skipped'))
if exc_value.args:
@ -707,8 +671,7 @@ class progress_message:
return False
def __call__(self, f):
# type: (Callable) -> Callable
def __call__(self, f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(*args, **kwargs):
with self:
@ -717,8 +680,7 @@ class progress_message:
return wrapper
def epoch_to_rfc1123(epoch):
# type: (float) -> str
def epoch_to_rfc1123(epoch: float) -> str:
"""Convert datetime format epoch to RFC1123."""
from babel.dates import format_datetime
@ -727,13 +689,11 @@ def epoch_to_rfc1123(epoch):
return format_datetime(dt, fmt, locale='en') + ' GMT'
def rfc1123_to_epoch(rfc1123):
# type: (str) -> float
def rfc1123_to_epoch(rfc1123: str) -> float:
return mktime(strptime(rfc1123, '%a, %d %b %Y %H:%M:%S %Z'))
def xmlname_checker():
# type: () -> Pattern
def xmlname_checker() -> Pattern:
# https://www.w3.org/TR/REC-xml/#NT-Name
name_start_chars = [
':', ['A', 'Z'], '_', ['a', 'z'], ['\u00C0', '\u00D6'],
@ -747,8 +707,7 @@ def xmlname_checker():
['\u203F', '\u2040']
]
def convert(entries, splitter='|'):
# type: (Any, str) -> str
def convert(entries: Any, splitter: str = '|') -> str:
results = []
for entry in entries:
if isinstance(entry, list):

View File

@ -23,6 +23,9 @@ TextlikeNode = Union[nodes.Text, nodes.TextElement]
# type of None
NoneType = type(None)
# path matcher
PathMatcher = Callable[[str], bool]
# common role functions
RoleFunction = Callable[[str, str, str, int, Inliner, Dict, List[str]],
Tuple[List[nodes.Node], List[nodes.system_message]]]

View File

@ -793,7 +793,7 @@ def test_autodoc_imported_members(app):
"imported-members": None,
"ignore-module-all": None}
actual = do_autodoc(app, 'module', 'target', options)
assert '.. py:function:: save_traceback(app)' in actual
assert '.. py:function:: save_traceback(app: Sphinx) -> str' in actual
@pytest.mark.sphinx('html', testroot='ext-autodoc')
@ -1795,7 +1795,7 @@ def test_autodoc_default_options(app):
actual = do_autodoc(app, 'class', 'target.CustomIter')
assert ' .. py:method:: target.CustomIter' not in actual
actual = do_autodoc(app, 'module', 'target')
assert '.. py:function:: save_traceback(app)' not in actual
assert '.. py:function:: save_traceback(app: Sphinx) -> str' not in actual
# with :members:
app.config.autodoc_default_options = {'members': None}
@ -1866,7 +1866,8 @@ def test_autodoc_default_options(app):
'ignore-module-all': None,
}
actual = do_autodoc(app, 'module', 'target')
assert '.. py:function:: save_traceback(app)' in actual
print('\n'.join(actual))
assert '.. py:function:: save_traceback(app: Sphinx) -> str' in actual
@pytest.mark.sphinx('html', testroot='ext-autodoc')