Add type-check annotations to sphinx.util

This commit is contained in:
Takeshi KOMIYA 2016-11-09 11:45:12 +09:00
parent db732ac0b8
commit 8cfb281b05
12 changed files with 226 additions and 50 deletions

View File

@ -42,19 +42,25 @@ from sphinx.util.nodes import ( # noqa
caption_ref_re) caption_ref_re)
from sphinx.util.matching import patfilter # noqa from sphinx.util.matching import patfilter # noqa
if False:
# For type annotation
from typing import Any, Callable, Iterable, Pattern, Sequence, Tuple # NOQA
# Generally useful regular expressions. # Generally useful regular expressions.
ws_re = re.compile(r'\s+') ws_re = re.compile(r'\s+') # type: Pattern
url_re = re.compile(r'(?P<schema>.+)://.*') url_re = re.compile(r'(?P<schema>.+)://.*') # type: Pattern
# High-level utility functions. # High-level utility functions.
def docname_join(basedocname, docname): def docname_join(basedocname, docname):
# type: (unicode, unicode) -> unicode
return posixpath.normpath( return posixpath.normpath(
posixpath.join('/' + basedocname, '..', docname))[1:] posixpath.join('/' + basedocname, '..', docname))[1:]
def path_stabilize(filepath): def path_stabilize(filepath):
# type: (unicode) -> unicode
"normalize path separater and unicode string" "normalize path separater and unicode string"
newpath = filepath.replace(os.path.sep, SEP) newpath = filepath.replace(os.path.sep, SEP)
if isinstance(newpath, text_type): if isinstance(newpath, text_type):
@ -63,6 +69,7 @@ def path_stabilize(filepath):
def get_matching_files(dirname, exclude_matchers=()): def get_matching_files(dirname, exclude_matchers=()):
# type: (unicode, Tuple[Callable[[unicode], bool], ...]) -> Iterable[unicode]
"""Get all file names in a directory, recursively. """Get all file names in a directory, recursively.
Exclude files and dirs matching some matcher in *exclude_matchers*. Exclude files and dirs matching some matcher in *exclude_matchers*.
@ -75,9 +82,9 @@ def get_matching_files(dirname, exclude_matchers=()):
relativeroot = root[dirlen:] relativeroot = root[dirlen:]
qdirs = enumerate(path_stabilize(path.join(relativeroot, dn)) qdirs = enumerate(path_stabilize(path.join(relativeroot, dn))
for dn in dirs) for dn in dirs) # type: Iterable[Tuple[int, unicode]]
qfiles = enumerate(path_stabilize(path.join(relativeroot, fn)) qfiles = enumerate(path_stabilize(path.join(relativeroot, fn))
for fn in files) for fn in files) # type: Iterable[Tuple[int, unicode]]
for matcher in exclude_matchers: for matcher in exclude_matchers:
qdirs = [entry for entry in qdirs if not matcher(entry[1])] qdirs = [entry for entry in qdirs if not matcher(entry[1])]
qfiles = [entry for entry in qfiles if not matcher(entry[1])] qfiles = [entry for entry in qfiles if not matcher(entry[1])]
@ -89,6 +96,7 @@ def get_matching_files(dirname, exclude_matchers=()):
def get_matching_docs(dirname, suffixes, exclude_matchers=()): def get_matching_docs(dirname, suffixes, exclude_matchers=()):
# type: (unicode, List[unicode], Tuple[Callable[[unicode], bool], ...]) -> Iterable[unicode] # NOQA
"""Get all file names (without suffixes) matching a suffix in a directory, """Get all file names (without suffixes) matching a suffix in a directory,
recursively. recursively.
@ -97,7 +105,7 @@ def get_matching_docs(dirname, suffixes, exclude_matchers=()):
suffixpatterns = ['*' + s for s in suffixes] suffixpatterns = ['*' + s for s in suffixes]
for filename in get_matching_files(dirname, exclude_matchers): for filename in get_matching_files(dirname, exclude_matchers):
for suffixpattern in suffixpatterns: for suffixpattern in suffixpatterns:
if fnmatch.fnmatch(filename, suffixpattern): if fnmatch.fnmatch(filename, suffixpattern): # type: ignore
yield filename[:-len(suffixpattern)+1] yield filename[:-len(suffixpattern)+1]
break break
@ -109,9 +117,10 @@ class FilenameUniqDict(dict):
appear in. Used for images and downloadable files in the environment. appear in. Used for images and downloadable files in the environment.
""" """
def __init__(self): def __init__(self):
self._existing = set() self._existing = set() # type: Set[unicode]
def add_file(self, docname, newfile): def add_file(self, docname, newfile):
# type: (unicode, unicode) -> unicode
if newfile in self: if newfile in self:
self[newfile][0].add(docname) self[newfile][0].add(docname)
return self[newfile][1] return self[newfile][1]
@ -126,6 +135,7 @@ class FilenameUniqDict(dict):
return uniquename return uniquename
def purge_doc(self, docname): def purge_doc(self, docname):
# type: (unicode) -> None
for filename, (docs, unique) in list(self.items()): for filename, (docs, unique) in list(self.items()):
docs.discard(docname) docs.discard(docname)
if not docs: if not docs:
@ -133,6 +143,7 @@ class FilenameUniqDict(dict):
self._existing.discard(unique) self._existing.discard(unique)
def merge_other(self, docnames, other): def merge_other(self, docnames, other):
# type: (List[unicode], Dict[unicode, Tuple[Set[unicode], Any]]) -> None
for filename, (docs, unique) in other.items(): for filename, (docs, unique) in other.items():
for doc in docs & docnames: for doc in docs & docnames:
self.add_file(doc, filename) self.add_file(doc, filename)
@ -146,6 +157,7 @@ class FilenameUniqDict(dict):
def copy_static_entry(source, targetdir, builder, context={}, def copy_static_entry(source, targetdir, builder, context={},
exclude_matchers=(), level=0): exclude_matchers=(), level=0):
# type: (unicode, unicode, Any, Dict, Tuple[Callable, ...], int) -> None
"""[DEPRECATED] Copy a HTML builder static_path entry from source to targetdir. """[DEPRECATED] Copy a HTML builder static_path entry from source to targetdir.
Handles all possible cases of files, directories and subdirectories. Handles all possible cases of files, directories and subdirectories.
@ -183,6 +195,7 @@ _DEBUG_HEADER = '''\
def save_traceback(app): def save_traceback(app):
# type: (Any) -> unicode
"""Save the current exception's traceback in a temporary file.""" """Save the current exception's traceback in a temporary file."""
import sphinx import sphinx
import jinja2 import jinja2
@ -190,7 +203,7 @@ def save_traceback(app):
import platform import platform
exc = sys.exc_info()[1] exc = sys.exc_info()[1]
if isinstance(exc, SphinxParallelError): if isinstance(exc, SphinxParallelError):
exc_format = '(Error in parallel process)\n' + exc.traceback exc_format = '(Error in parallel process)\n' + exc.traceback # type: ignore
else: else:
exc_format = traceback.format_exc() exc_format = traceback.format_exc()
fd, path = tempfile.mkstemp('.log', 'sphinx-err-') fd, path = tempfile.mkstemp('.log', 'sphinx-err-')
@ -220,6 +233,7 @@ def save_traceback(app):
def get_module_source(modname): def get_module_source(modname):
# type: (str) -> Tuple[unicode, unicode]
"""Try to find the source code for a module. """Try to find the source code for a module.
Can return ('file', 'filename') in which case the source is in the given Can return ('file', 'filename') in which case the source is in the given
@ -259,6 +273,7 @@ def get_module_source(modname):
def get_full_modname(modname, attribute): def get_full_modname(modname, attribute):
# type: (str, unicode) -> unicode
__import__(modname) __import__(modname)
module = sys.modules[modname] module = sys.modules[modname]
@ -277,6 +292,7 @@ _coding_re = re.compile(r'coding[:=]\s*([-\w.]+)')
def detect_encoding(readline): def detect_encoding(readline):
# type: (Callable) -> unicode
"""Like tokenize.detect_encoding() from Py3k, but a bit simplified.""" """Like tokenize.detect_encoding() from Py3k, but a bit simplified."""
def read_or_stop(): def read_or_stop():
@ -433,10 +449,11 @@ def split_index_msg(type, value):
def format_exception_cut_frames(x=1): def format_exception_cut_frames(x=1):
# type: (int) -> unicode
"""Format an exception with traceback, but only the last x frames.""" """Format an exception with traceback, but only the last x frames."""
typ, val, tb = sys.exc_info() typ, val, tb = sys.exc_info()
# res = ['Traceback (most recent call last):\n'] # res = ['Traceback (most recent call last):\n']
res = [] res = [] # type: List[unicode]
tbres = traceback.format_tb(tb) tbres = traceback.format_tb(tb)
res += tbres[-x:] res += tbres[-x:]
res += traceback.format_exception_only(typ, val) res += traceback.format_exception_only(typ, val)
@ -449,7 +466,7 @@ class PeekableIterator(object):
what's the next item. what's the next item.
""" """
def __init__(self, iterable): def __init__(self, iterable):
self.remaining = deque() self.remaining = deque() # type: deque
self._iterator = iter(iterable) self._iterator = iter(iterable)
def __iter__(self): def __iter__(self):
@ -477,6 +494,7 @@ class PeekableIterator(object):
def import_object(objname, source=None): def import_object(objname, source=None):
# type: (str, unicode) -> Any
try: try:
module, name = objname.rsplit('.', 1) module, name = objname.rsplit('.', 1)
except ValueError as err: except ValueError as err:
@ -496,7 +514,8 @@ def import_object(objname, source=None):
def encode_uri(uri): def encode_uri(uri):
split = list(urlsplit(uri)) # type: (unicode) -> unicode
split = list(urlsplit(uri)) # type: Any
split[1] = split[1].encode('idna').decode('ascii') split[1] = split[1].encode('idna').decode('ascii')
split[2] = quote_plus(split[2].encode('utf-8'), '/').decode('ascii') split[2] = quote_plus(split[2].encode('utf-8'), '/').decode('ascii')
query = list((q, quote_plus(v.encode('utf-8'))) query = list((q, quote_plus(v.encode('utf-8')))
@ -506,8 +525,9 @@ def encode_uri(uri):
def split_docinfo(text): def split_docinfo(text):
# type: (unicode) -> Sequence[unicode]
docinfo_re = re.compile('\A((?:\s*:\w+:.*?\n)+)', re.M) docinfo_re = re.compile('\A((?:\s*:\w+:.*?\n)+)', re.M)
result = docinfo_re.split(text, 1) result = docinfo_re.split(text, 1) # type: ignore
if len(result) == 1: if len(result) == 1:
return '', result[0] return '', result[0]
else: else:

View File

@ -20,10 +20,11 @@ except ImportError:
colorama = None colorama = None
_ansi_re = re.compile('\x1b\\[(\\d\\d;){0,2}\\d\\dm') _ansi_re = re.compile('\x1b\\[(\\d\\d;){0,2}\\d\\dm')
codes = {} codes = {} # type: Dict[str, str]
def get_terminal_width(): def get_terminal_width():
# type: () -> int
"""Borrowed from the py lib.""" """Borrowed from the py lib."""
try: try:
import termios import termios
@ -43,6 +44,7 @@ _tw = get_terminal_width()
def term_width_line(text): def term_width_line(text):
# type: (str) -> str
if not codes: if not codes:
# if no coloring, don't output fancy backspaces # if no coloring, don't output fancy backspaces
return text + '\n' return text + '\n'
@ -52,6 +54,7 @@ def term_width_line(text):
def color_terminal(): def color_terminal():
# type: () -> bool
if sys.platform == 'win32' and colorama is not None: if sys.platform == 'win32' and colorama is not None:
colorama.init() colorama.init()
return True return True
@ -68,24 +71,29 @@ def color_terminal():
def nocolor(): def nocolor():
# type: () -> None
if sys.platform == 'win32' and colorama is not None: if sys.platform == 'win32' and colorama is not None:
colorama.deinit() colorama.deinit()
codes.clear() codes.clear()
def coloron(): def coloron():
# type: () -> None
codes.update(_orig_codes) codes.update(_orig_codes)
def colorize(name, text): def colorize(name, text):
# type: (str, str) -> str
return codes.get(name, '') + text + codes.get('reset', '') return codes.get(name, '') + text + codes.get('reset', '')
def strip_colors(s): def strip_colors(s):
# type: (str) -> str
return re.compile('\x1b.*?m').sub('', s) return re.compile('\x1b.*?m').sub('', s)
def create_color_func(name): def create_color_func(name):
# type: (str) -> None
def inner(text): def inner(text):
return colorize(name, text) return colorize(name, text)
globals()[name] = inner globals()[name] = inner

View File

@ -15,8 +15,14 @@ from docutils import nodes
from sphinx import addnodes from sphinx import addnodes
if False:
# For type annotation
from typing import Any, Tuple # NOQA
from sphinx.domains import Domain # NOQA
def _is_single_paragraph(node): def _is_single_paragraph(node):
# type: (nodes.Node) -> bool
"""True if the node only contains one paragraph (and system messages).""" """True if the node only contains one paragraph (and system messages)."""
if len(node) == 0: if len(node) == 0:
return False return False
@ -47,6 +53,7 @@ class Field(object):
def __init__(self, name, names=(), label=None, has_arg=True, rolename=None, def __init__(self, name, names=(), label=None, has_arg=True, rolename=None,
bodyrolename=None): bodyrolename=None):
# type: (unicode, Tuple[unicode, ...], unicode, bool, unicode, unicode) -> None
self.name = name self.name = name
self.names = names self.names = names
self.label = label self.label = label
@ -56,6 +63,7 @@ class Field(object):
def make_xref(self, rolename, domain, target, def make_xref(self, rolename, domain, target,
innernode=addnodes.literal_emphasis, contnode=None): innernode=addnodes.literal_emphasis, contnode=None):
# type: (unicode, unicode, unicode, nodes.Node, nodes.Node) -> nodes.Node
if not rolename: if not rolename:
return contnode or innernode(target, target) return contnode or innernode(target, target)
refnode = addnodes.pending_xref('', refdomain=domain, refexplicit=False, refnode = addnodes.pending_xref('', refdomain=domain, refexplicit=False,
@ -65,12 +73,15 @@ class Field(object):
def make_xrefs(self, rolename, domain, target, def make_xrefs(self, rolename, domain, target,
innernode=addnodes.literal_emphasis, contnode=None): innernode=addnodes.literal_emphasis, contnode=None):
# type: (unicode, unicode, unicode, nodes.Node, nodes.Node) -> List[nodes.Node]
return [self.make_xref(rolename, domain, target, innernode, contnode)] return [self.make_xref(rolename, domain, target, innernode, contnode)]
def make_entry(self, fieldarg, content): def make_entry(self, fieldarg, content):
# type: (List, unicode) -> Tuple[List, unicode]
return (fieldarg, content) return (fieldarg, content)
def make_field(self, types, domain, item): def make_field(self, types, domain, item):
# type: (List, unicode, Tuple) -> nodes.field
fieldarg, content = item fieldarg, content = item
fieldname = nodes.field_name('', self.label) fieldname = nodes.field_name('', self.label)
if fieldarg: if fieldarg:
@ -106,10 +117,12 @@ class GroupedField(Field):
def __init__(self, name, names=(), label=None, rolename=None, def __init__(self, name, names=(), label=None, rolename=None,
can_collapse=False): can_collapse=False):
# type: (unicode, Tuple[unicode, ...], unicode, unicode, bool) -> None
Field.__init__(self, name, names, label, True, rolename) Field.__init__(self, name, names, label, True, rolename)
self.can_collapse = can_collapse self.can_collapse = can_collapse
def make_field(self, types, domain, items): def make_field(self, types, domain, items):
# type: (List, unicode, Tuple) -> nodes.field
fieldname = nodes.field_name('', self.label) fieldname = nodes.field_name('', self.label)
listnode = self.list_type() listnode = self.list_type()
for fieldarg, content in items: for fieldarg, content in items:
@ -151,11 +164,13 @@ class TypedField(GroupedField):
def __init__(self, name, names=(), typenames=(), label=None, def __init__(self, name, names=(), typenames=(), label=None,
rolename=None, typerolename=None, can_collapse=False): rolename=None, typerolename=None, can_collapse=False):
# type: (unicode, Tuple[unicode, ...], Tuple[unicode, ...], unicode, unicode, unicode, bool) -> None # NOQA
GroupedField.__init__(self, name, names, label, rolename, can_collapse) GroupedField.__init__(self, name, names, label, rolename, can_collapse)
self.typenames = typenames self.typenames = typenames
self.typerolename = typerolename self.typerolename = typerolename
def make_field(self, types, domain, items): def make_field(self, types, domain, items):
# type: (List, unicode, Tuple) -> nodes.field
def handle_item(fieldarg, content): def handle_item(fieldarg, content):
par = nodes.paragraph() par = nodes.paragraph()
par.extend(self.make_xrefs(self.rolename, domain, fieldarg, par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
@ -196,6 +211,7 @@ class DocFieldTransformer(object):
""" """
def __init__(self, directive): def __init__(self, directive):
# type: (Any) -> None
self.domain = directive.domain self.domain = directive.domain
if '_doc_field_type_map' not in directive.__class__.__dict__: if '_doc_field_type_map' not in directive.__class__.__dict__:
directive.__class__._doc_field_type_map = \ directive.__class__._doc_field_type_map = \
@ -203,6 +219,7 @@ class DocFieldTransformer(object):
self.typemap = directive._doc_field_type_map self.typemap = directive._doc_field_type_map
def preprocess_fieldtypes(self, types): def preprocess_fieldtypes(self, types):
# type: (List) -> Dict[unicode, Tuple[Any, bool]]
typemap = {} typemap = {}
for fieldtype in types: for fieldtype in types:
for name in fieldtype.names: for name in fieldtype.names:
@ -213,6 +230,7 @@ class DocFieldTransformer(object):
return typemap return typemap
def transform_all(self, node): def transform_all(self, node):
# type: (nodes.Node) -> None
"""Transform all field list children of a node.""" """Transform all field list children of a node."""
# don't traverse, only handle field lists that are immediate children # don't traverse, only handle field lists that are immediate children
for child in node: for child in node:
@ -220,12 +238,13 @@ class DocFieldTransformer(object):
self.transform(child) self.transform(child)
def transform(self, node): def transform(self, node):
# type: (nodes.Node) -> None
"""Transform a single field list *node*.""" """Transform a single field list *node*."""
typemap = self.typemap typemap = self.typemap
entries = [] entries = []
groupindices = {} groupindices = {} # type: Dict[unicode, int]
types = {} types = {} # type: Dict[unicode, Dict]
# step 1: traverse all fields and collect field types and content # step 1: traverse all fields and collect field types and content
for field in node: for field in node:

View File

@ -12,11 +12,19 @@ from __future__ import absolute_import
from copy import copy from copy import copy
from contextlib import contextmanager from contextlib import contextmanager
from docutils.parsers.rst import directives, roles from docutils.parsers.rst import directives, roles
if False:
# For type annotation
from typing import Any, Callable, Iterator, Tuple # NOQA
from docutils import nodes # NOQA
from sphinx.environment import BuildEnvironment # NOQA
@contextmanager @contextmanager
def docutils_namespace(): def docutils_namespace():
# type: () -> Iterator[None]
"""Create namespace for reST parsers.""" """Create namespace for reST parsers."""
try: try:
_directives = copy(directives._directives) _directives = copy(directives._directives)
@ -37,9 +45,10 @@ class sphinx_domains(object):
markup takes precedence. markup takes precedence.
""" """
def __init__(self, env): def __init__(self, env):
# type: (BuildEnvironment) -> None
self.env = env self.env = env
self.directive_func = None self.directive_func = None # type: Callable
self.roles_func = None self.roles_func = None # type: Callable
def __enter__(self): def __enter__(self):
self.enable() self.enable()
@ -59,6 +68,7 @@ class sphinx_domains(object):
roles.role = self.role_func roles.role = self.role_func
def lookup_domain_element(self, type, name): def lookup_domain_element(self, type, name):
# type: (unicode, unicode) -> Tuple[Any, List]
"""Lookup a markup element (directive or role), given its name which can """Lookup a markup element (directive or role), given its name which can
be a full name (with domain). be a full name (with domain).
""" """
@ -87,12 +97,14 @@ class sphinx_domains(object):
raise ElementLookupError raise ElementLookupError
def lookup_directive(self, name, lang_module, document): def lookup_directive(self, name, lang_module, document):
# type: (unicode, unicode, nodes.document) -> Tuple[Any, List]
try: try:
return self.lookup_domain_element('directive', name) return self.lookup_domain_element('directive', name)
except ElementLookupError: except ElementLookupError:
return self.directive_func(name, lang_module, document) return self.directive_func(name, lang_module, document)
def lookup_role(self, name, lang_module, lineno, reporter): def lookup_role(self, name, lang_module, lineno, reporter):
# type: (unicode, unicode, int, Any) -> Tuple[Any, List]
try: try:
return self.lookup_domain_element('role', name) return self.lookup_domain_element('role', name)
except ElementLookupError: except ElementLookupError:

View File

@ -22,9 +22,12 @@ from babel.messages.pofile import read_po
from babel.messages.mofile import write_mo from babel.messages.mofile import write_mo
from sphinx.errors import SphinxError from sphinx.errors import SphinxError
from sphinx.util.osutil import walk from sphinx.util.osutil import SEP, walk
from sphinx.util import SEP
if False:
# For type annotation
from typing import Callable # NOQA
from sphinx.environment import BuildEnvironment # NOQA
LocaleFileInfoBase = namedtuple('CatalogInfo', 'base_dir,domain,charset') LocaleFileInfoBase = namedtuple('CatalogInfo', 'base_dir,domain,charset')
@ -33,32 +36,39 @@ class CatalogInfo(LocaleFileInfoBase):
@property @property
def po_file(self): def po_file(self):
# type: () -> unicode
return self.domain + '.po' return self.domain + '.po'
@property @property
def mo_file(self): def mo_file(self):
# type: () -> unicode
return self.domain + '.mo' return self.domain + '.mo'
@property @property
def po_path(self): def po_path(self):
# type: () -> unicode
return path.join(self.base_dir, self.po_file) return path.join(self.base_dir, self.po_file)
@property @property
def mo_path(self): def mo_path(self):
# type: () -> unicode
return path.join(self.base_dir, self.mo_file) return path.join(self.base_dir, self.mo_file)
def is_outdated(self): def is_outdated(self):
# type: () -> bool
return ( return (
not path.exists(self.mo_path) or not path.exists(self.mo_path) or
path.getmtime(self.mo_path) < path.getmtime(self.po_path)) path.getmtime(self.mo_path) < path.getmtime(self.po_path))
def write_mo(self, locale): def write_mo(self, locale):
# type: (unicode) -> None
with io.open(self.po_path, 'rt', encoding=self.charset) as po: with io.open(self.po_path, 'rt', encoding=self.charset) as po:
with io.open(self.mo_path, 'wb') as mo: with io.open(self.mo_path, 'wb') as mo:
write_mo(mo, read_po(po, locale)) write_mo(mo, read_po(po, locale))
def find_catalog(docname, compaction): def find_catalog(docname, compaction):
# type: (unicode, bool) -> unicode
if compaction: if compaction:
ret = docname.split(SEP, 1)[0] ret = docname.split(SEP, 1)[0]
else: else:
@ -68,18 +78,20 @@ def find_catalog(docname, compaction):
def find_catalog_files(docname, srcdir, locale_dirs, lang, compaction): def find_catalog_files(docname, srcdir, locale_dirs, lang, compaction):
# type: (unicode, unicode, List[unicode], unicode, bool) -> List[unicode]
if not(lang and locale_dirs): if not(lang and locale_dirs):
return [] return []
domain = find_catalog(docname, compaction) domain = find_catalog(docname, compaction)
files = [gettext.find(domain, path.join(srcdir, dir_), [lang]) files = [gettext.find(domain, path.join(srcdir, dir_), [lang]) # type: ignore
for dir_ in locale_dirs] for dir_ in locale_dirs] # type: ignore
files = [path.relpath(f, srcdir) for f in files if f] files = [path.relpath(f, srcdir) for f in files if f] # type: ignore
return files return files # type: ignore
def find_catalog_source_files(locale_dirs, locale, domains=None, gettext_compact=False, def find_catalog_source_files(locale_dirs, locale, domains=None, gettext_compact=False,
charset='utf-8', force_all=False): charset='utf-8', force_all=False):
# type: (List[unicode], unicode, List[unicode], bool, unicode, bool) -> Set[CatalogInfo]
""" """
:param list locale_dirs: :param list locale_dirs:
list of path as `['locale_dir1', 'locale_dir2', ...]` to find list of path as `['locale_dir1', 'locale_dir2', ...]` to find
@ -99,7 +111,7 @@ def find_catalog_source_files(locale_dirs, locale, domains=None, gettext_compact
if not locale: if not locale:
return [] # locale is not specified return [] # locale is not specified
catalogs = set() catalogs = set() # type: Set[CatalogInfo]
for locale_dir in locale_dirs: for locale_dir in locale_dirs:
if not locale_dir: if not locale_dir:
continue # skip system locale directory continue # skip system locale directory
@ -158,6 +170,7 @@ date_format_mappings = {
def babel_format_date(date, format, locale, warn=None, formatter=babel.dates.format_date): def babel_format_date(date, format, locale, warn=None, formatter=babel.dates.format_date):
# type: (datetime, unicode, unicode, Callable, Callable) -> unicode
if locale is None: if locale is None:
locale = 'en' locale = 'en'
@ -180,6 +193,7 @@ def babel_format_date(date, format, locale, warn=None, formatter=babel.dates.for
def format_date(format, date=None, language=None, warn=None): def format_date(format, date=None, language=None, warn=None):
# type: (str, datetime, unicode, Callable) -> unicode
if format is None: if format is None:
format = 'medium' format = 'medium'
@ -226,6 +240,7 @@ def format_date(format, date=None, language=None, warn=None):
def get_image_filename_for_language(filename, env): def get_image_filename_for_language(filename, env):
# type: (unicode, BuildEnvironment) -> unicode
if not env.config.language: if not env.config.language:
return filename return filename
@ -245,6 +260,7 @@ def get_image_filename_for_language(filename, env):
def search_image_for_language(filename, env): def search_image_for_language(filename, env):
# type: (unicode, BuildEnvironment) -> unicode
if not env.config.language: if not env.config.language:
return filename return filename

View File

@ -12,10 +12,14 @@
import re import re
from six import PY3, binary_type from six import PY3, binary_type
from six.moves import builtins from six.moves import builtins # type: ignore
from sphinx.util import force_decode from sphinx.util import force_decode
if False:
# For type annotation
from typing import Any, Callable, Tuple # NOQA
# this imports the standard library inspect module without resorting to # this imports the standard library inspect module without resorting to
# relatively import this module # relatively import this module
inspect = __import__('inspect') inspect = __import__('inspect')
@ -67,7 +71,7 @@ else: # 2.7
"""Like inspect.getargspec but supports functools.partial as well.""" """Like inspect.getargspec but supports functools.partial as well."""
if inspect.ismethod(func): if inspect.ismethod(func):
func = func.__func__ func = func.__func__
parts = 0, () parts = 0, () # type: Tuple[int, Tuple[unicode, ...]]
if type(func) is partial: if type(func) is partial:
keywords = func.keywords keywords = func.keywords
if keywords is None: if keywords is None:
@ -101,6 +105,7 @@ except ImportError:
def isenumattribute(x): def isenumattribute(x):
# type: (Any) -> bool
"""Check if the object is attribute of enum.""" """Check if the object is attribute of enum."""
if enum is None: if enum is None:
return False return False
@ -108,6 +113,7 @@ def isenumattribute(x):
def isdescriptor(x): def isdescriptor(x):
# type: (Any) -> bool
"""Check if the object is some kind of descriptor.""" """Check if the object is some kind of descriptor."""
for item in '__get__', '__set__', '__delete__': for item in '__get__', '__set__', '__delete__':
if hasattr(safe_getattr(x, item, None), '__call__'): if hasattr(safe_getattr(x, item, None), '__call__'):
@ -116,6 +122,7 @@ def isdescriptor(x):
def safe_getattr(obj, name, *defargs): def safe_getattr(obj, name, *defargs):
# type: (Any, unicode, unicode) -> object
"""A getattr() that turns all exceptions into AttributeErrors.""" """A getattr() that turns all exceptions into AttributeErrors."""
try: try:
return getattr(obj, name, *defargs) return getattr(obj, name, *defargs)
@ -138,8 +145,9 @@ def safe_getattr(obj, name, *defargs):
def safe_getmembers(object, predicate=None, attr_getter=safe_getattr): def safe_getmembers(object, predicate=None, attr_getter=safe_getattr):
# type: (Any, Callable[[unicode], bool], Callable) -> List[Tuple[unicode, Any]]
"""A version of inspect.getmembers() that uses safe_getattr().""" """A version of inspect.getmembers() that uses safe_getattr()."""
results = [] results = [] # type: List[Tuple[unicode, Any]]
for key in dir(object): for key in dir(object):
try: try:
value = attr_getter(object, key, None) value = attr_getter(object, key, None)
@ -152,6 +160,7 @@ def safe_getmembers(object, predicate=None, attr_getter=safe_getattr):
def object_description(object): def object_description(object):
# type: (Any) -> unicode
"""A repr() implementation that returns text safe to use in reST context.""" """A repr() implementation that returns text safe to use in reST context."""
try: try:
s = repr(object) s = repr(object)
@ -166,6 +175,7 @@ def object_description(object):
def is_builtin_class_method(obj, attr_name): def is_builtin_class_method(obj, attr_name):
# type: (Any, unicode) -> bool
"""If attr_name is implemented at builtin class, return True. """If attr_name is implemented at builtin class, return True.
>>> is_builtin_class_method(int, '__init__') >>> is_builtin_class_method(int, '__init__')
@ -177,6 +187,6 @@ def is_builtin_class_method(obj, attr_name):
classes = [c for c in inspect.getmro(obj) if attr_name in c.__dict__] classes = [c for c in inspect.getmro(obj) if attr_name in c.__dict__]
cls = classes[0] if classes else object cls = classes[0] if classes else object
if not hasattr(builtins, safe_getattr(cls, '__name__', '')): if not hasattr(builtins, safe_getattr(cls, '__name__', '')): # type: ignore
return False return False
return getattr(builtins, safe_getattr(cls, '__name__', '')) is cls return getattr(builtins, safe_getattr(cls, '__name__', '')) is cls # type: ignore

View File

@ -16,6 +16,10 @@ from six import iteritems, integer_types, string_types
from sphinx.util.pycompat import u from sphinx.util.pycompat import u
if False:
# For type annotation
from typing import Any, IO, Union # NOQA
_str_re = re.compile(r'"(\\\\|\\"|[^"])*"') _str_re = re.compile(r'"(\\\\|\\"|[^"])*"')
_int_re = re.compile(r'\d+') _int_re = re.compile(r'\d+')
_name_re = re.compile(r'[a-zA-Z_]\w*') _name_re = re.compile(r'[a-zA-Z_]\w*')
@ -37,6 +41,7 @@ ESCAPED = re.compile(r'\\u.{4}|\\.')
def encode_string(s): def encode_string(s):
# type: (str) -> str
def replace(match): def replace(match):
s = match.group(0) s = match.group(0)
try: try:
@ -55,6 +60,7 @@ def encode_string(s):
def decode_string(s): def decode_string(s):
# type: (str) -> str
return ESCAPED.sub(lambda m: eval(u + '"' + m.group() + '"'), s) return ESCAPED.sub(lambda m: eval(u + '"' + m.group() + '"'), s)
@ -77,6 +83,7 @@ double in super""".split())
def dumps(obj, key=False): def dumps(obj, key=False):
# type: (Any, bool) -> str
if key: if key:
if not isinstance(obj, string_types): if not isinstance(obj, string_types):
obj = str(obj) obj = str(obj)
@ -88,7 +95,7 @@ def dumps(obj, key=False):
return 'null' return 'null'
elif obj is True or obj is False: elif obj is True or obj is False:
return obj and 'true' or 'false' return obj and 'true' or 'false'
elif isinstance(obj, integer_types + (float,)): elif isinstance(obj, integer_types + (float,)): # type: ignore
return str(obj) return str(obj)
elif isinstance(obj, dict): elif isinstance(obj, dict):
return '{%s}' % ','.join(sorted('%s:%s' % ( return '{%s}' % ','.join(sorted('%s:%s' % (
@ -100,20 +107,22 @@ def dumps(obj, key=False):
elif isinstance(obj, (tuple, list)): elif isinstance(obj, (tuple, list)):
return '[%s]' % ','.join(dumps(x) for x in obj) return '[%s]' % ','.join(dumps(x) for x in obj)
elif isinstance(obj, string_types): elif isinstance(obj, string_types):
return encode_string(obj) return encode_string(obj) # type: ignore
raise TypeError(type(obj)) raise TypeError(type(obj))
def dump(obj, f): def dump(obj, f):
# type: (Any, IO) -> None
f.write(dumps(obj)) f.write(dumps(obj))
def loads(x): def loads(x):
# type: (str) -> Any
"""Loader that can read the JS subset the indexer produces.""" """Loader that can read the JS subset the indexer produces."""
nothing = object() nothing = object()
i = 0 i = 0
n = len(x) n = len(x)
stack = [] stack = [] # type: List[Union[List, Dict]]
obj = nothing obj = nothing
key = False key = False
keys = [] keys = []
@ -164,6 +173,7 @@ def loads(x):
raise ValueError("multiple values") raise ValueError("multiple values")
key = False key = False
else: else:
y = None # type: Any
m = _str_re.match(x, i) m = _str_re.match(x, i)
if m: if m:
y = decode_string(m.group()[1:-1]) y = decode_string(m.group()[1:-1])
@ -193,11 +203,12 @@ def loads(x):
obj[keys[-1]] = y obj[keys[-1]] = y
key = False key = False
else: else:
obj.append(y) obj.append(y) # type: ignore
if obj is nothing: if obj is nothing:
raise ValueError("nothing loaded from string") raise ValueError("nothing loaded from string")
return obj return obj
def load(f): def load(f):
# type: (IO) -> Any
return loads(f.read()) return loads(f.read())

View File

@ -11,15 +11,20 @@
import re import re
if False:
# For type annotation
from typing import Callable, Match, Pattern # NOQA
def _translate_pattern(pat): def _translate_pattern(pat):
# type: (unicode) -> unicode
"""Translate a shell-style glob pattern to a regular expression. """Translate a shell-style glob pattern to a regular expression.
Adapted from the fnmatch module, but enhanced so that single stars don't Adapted from the fnmatch module, but enhanced so that single stars don't
match slashes. match slashes.
""" """
i, n = 0, len(pat) i, n = 0, len(pat)
res = '' res = '' # type: unicode
while i < n: while i < n:
c = pat[i] c = pat[i]
i += 1 i += 1
@ -59,6 +64,7 @@ def _translate_pattern(pat):
def compile_matchers(patterns): def compile_matchers(patterns):
# type: (List[unicode]) -> List[Callable[[unicode], Match[unicode]]]
return [re.compile(_translate_pattern(pat)).match for pat in patterns] return [re.compile(_translate_pattern(pat)).match for pat in patterns]
@ -70,23 +76,27 @@ class Matcher(object):
""" """
def __init__(self, patterns): def __init__(self, patterns):
# type: (List[unicode]) -> None
expanded = [pat[3:] for pat in patterns if pat.startswith('**/')] expanded = [pat[3:] for pat in patterns if pat.startswith('**/')]
self.patterns = compile_matchers(patterns + expanded) self.patterns = compile_matchers(patterns + expanded)
def __call__(self, string): def __call__(self, string):
# type: (unicode) -> bool
return self.match(string) return self.match(string)
def match(self, string): def match(self, string):
# type: (unicode) -> bool
return any(pat(string) for pat in self.patterns) return any(pat(string) for pat in self.patterns)
DOTFILES = Matcher(['**/.*']) DOTFILES = Matcher(['**/.*'])
_pat_cache = {} _pat_cache = {} # type: Dict[unicode, Pattern]
def patmatch(name, pat): def patmatch(name, pat):
# type: (unicode, unicode) -> re.Match
"""Return if name matches pat. Adapted from fnmatch module.""" """Return if name matches pat. Adapted from fnmatch module."""
if pat not in _pat_cache: if pat not in _pat_cache:
_pat_cache[pat] = re.compile(_translate_pattern(pat)) _pat_cache[pat] = re.compile(_translate_pattern(pat))
@ -94,6 +104,7 @@ def patmatch(name, pat):
def patfilter(names, pat): def patfilter(names, pat):
# type: (List[unicode], unicode) -> List[unicode]
"""Return the subset of the list NAMES that match PAT. """Return the subset of the list NAMES that match PAT.
Adapted from fnmatch module. Adapted from fnmatch module.

View File

@ -13,19 +13,28 @@ from __future__ import absolute_import
import re import re
from six import text_type from six import text_type
from docutils import nodes from docutils import nodes
from sphinx import addnodes from sphinx import addnodes
from sphinx.locale import pairindextypes from sphinx.locale import pairindextypes
if False:
# For type annotation
from typing import Any, Callable, Iterable, Tuple, Union # NOQA
from sphinx.builders import Builder # NOQA
from sphinx.utils.tags import Tags # NOQA
class WarningStream(object): class WarningStream(object):
def __init__(self, warnfunc): def __init__(self, warnfunc):
# type: (Callable) -> None
self.warnfunc = warnfunc self.warnfunc = warnfunc
self._re = re.compile(r'\((DEBUG|INFO|WARNING|ERROR|SEVERE)/[0-4]\)') self._re = re.compile(r'\((DEBUG|INFO|WARNING|ERROR|SEVERE)/[0-4]\)')
def write(self, text): def write(self, text):
# type: (str) -> None
text = text.strip() text = text.strip()
if text: if text:
self.warnfunc(self._re.sub(r'\1:', text), None, '') self.warnfunc(self._re.sub(r'\1:', text), None, '')
@ -37,6 +46,7 @@ caption_ref_re = explicit_title_re # b/w compat alias
def apply_source_workaround(node): def apply_source_workaround(node):
# type: (nodes.Node) -> None
# workaround: nodes.term have wrong rawsource if classifier is specified. # workaround: nodes.term have wrong rawsource if classifier is specified.
# The behavior of docutils-0.11, 0.12 is: # The behavior of docutils-0.11, 0.12 is:
# * when ``term text : classifier1 : classifier2`` is specified, # * when ``term text : classifier1 : classifier2`` is specified,
@ -87,6 +97,7 @@ IGNORED_NODES = (
def is_pending_meta(node): def is_pending_meta(node):
# type: (nodes.Node) -> bool
if (isinstance(node, nodes.pending) and if (isinstance(node, nodes.pending) and
isinstance(node.details.get('nodes', [None])[0], addnodes.meta)): isinstance(node.details.get('nodes', [None])[0], addnodes.meta)):
return True return True
@ -95,6 +106,7 @@ def is_pending_meta(node):
def is_translatable(node): def is_translatable(node):
# type: (nodes.Node) -> bool
if isinstance(node, addnodes.translatable): if isinstance(node, addnodes.translatable):
return True return True
@ -137,6 +149,7 @@ META_TYPE_NODES = (
def extract_messages(doctree): def extract_messages(doctree):
# type: (nodes.Node) -> Iterable[Tuple[nodes.Node, unicode]]
"""Extract translatable messages from a document tree.""" """Extract translatable messages from a document tree."""
for node in doctree.traverse(is_translatable): for node in doctree.traverse(is_translatable):
if isinstance(node, addnodes.translatable): if isinstance(node, addnodes.translatable):
@ -164,12 +177,14 @@ def extract_messages(doctree):
def find_source_node(node): def find_source_node(node):
# type: (nodes.Node) -> unicode
for pnode in traverse_parent(node): for pnode in traverse_parent(node):
if pnode.source: if pnode.source:
return pnode.source return pnode.source
def traverse_parent(node, cls=None): def traverse_parent(node, cls=None):
# type: (nodes.Node, Any) -> Iterable[nodes.Node]
while node: while node:
if cls is None or isinstance(node, cls): if cls is None or isinstance(node, cls):
yield node yield node
@ -177,6 +192,7 @@ def traverse_parent(node, cls=None):
def traverse_translatable_index(doctree): def traverse_translatable_index(doctree):
# type: (nodes.Node) -> Iterable[Tuple[nodes.Node, List[unicode]]]
"""Traverse translatable index node from a document tree.""" """Traverse translatable index node from a document tree."""
def is_block_index(node): def is_block_index(node):
return isinstance(node, addnodes.index) and \ return isinstance(node, addnodes.index) and \
@ -190,6 +206,7 @@ def traverse_translatable_index(doctree):
def nested_parse_with_titles(state, content, node): def nested_parse_with_titles(state, content, node):
# type: (Any, List[unicode], nodes.Node) -> unicode
"""Version of state.nested_parse() that allows titles and does not require """Version of state.nested_parse() that allows titles and does not require
titles to have the same decoration as the calling document. titles to have the same decoration as the calling document.
@ -209,6 +226,7 @@ def nested_parse_with_titles(state, content, node):
def clean_astext(node): def clean_astext(node):
# type: (nodes.Node) -> unicode
"""Like node.astext(), but ignore images.""" """Like node.astext(), but ignore images."""
node = node.deepcopy() node = node.deepcopy()
for img in node.traverse(nodes.image): for img in node.traverse(nodes.image):
@ -217,6 +235,7 @@ def clean_astext(node):
def split_explicit_title(text): def split_explicit_title(text):
# type: (str) -> Tuple[bool, unicode, unicode]
"""Split role content into title and target, if given.""" """Split role content into title and target, if given."""
match = explicit_title_re.match(text) match = explicit_title_re.match(text)
if match: if match:
@ -230,7 +249,8 @@ indextypes = [
def process_index_entry(entry, targetid): def process_index_entry(entry, targetid):
indexentries = [] # type: (unicode, unicode) -> List[Tuple[unicode, unicode, unicode, unicode, unicode]]
indexentries = [] # type: List[Tuple[unicode, unicode, unicode, unicode, unicode]]
entry = entry.strip() entry = entry.strip()
oentry = entry oentry = entry
main = '' main = ''
@ -266,6 +286,7 @@ def process_index_entry(entry, targetid):
def inline_all_toctrees(builder, docnameset, docname, tree, colorfunc, traversed): def inline_all_toctrees(builder, docnameset, docname, tree, colorfunc, traversed):
# type: (Builder, Set[unicode], unicode, nodes.Node, Callable, nodes.Node) -> nodes.Node
"""Inline all toctrees in the *tree*. """Inline all toctrees in the *tree*.
Record all docnames in *docnameset*, and output docnames with *colorfunc*. Record all docnames in *docnameset*, and output docnames with *colorfunc*.
@ -299,6 +320,7 @@ def inline_all_toctrees(builder, docnameset, docname, tree, colorfunc, traversed
def make_refnode(builder, fromdocname, todocname, targetid, child, title=None): def make_refnode(builder, fromdocname, todocname, targetid, child, title=None):
# type: (Builder, unicode, unicode, unicode, nodes.Node, unicode) -> nodes.reference
"""Shortcut to create a reference node.""" """Shortcut to create a reference node."""
node = nodes.reference('', '', internal=True) node = nodes.reference('', '', internal=True)
if fromdocname == todocname: if fromdocname == todocname:
@ -313,15 +335,18 @@ def make_refnode(builder, fromdocname, todocname, targetid, child, title=None):
def set_source_info(directive, node): def set_source_info(directive, node):
# type: (Any, nodes.Node) -> None
node.source, node.line = \ node.source, node.line = \
directive.state_machine.get_source_and_line(directive.lineno) directive.state_machine.get_source_and_line(directive.lineno)
def set_role_source_info(inliner, lineno, node): def set_role_source_info(inliner, lineno, node):
# type: (Any, unicode, nodes.Node) -> None
node.source, node.line = inliner.reporter.get_source_and_line(lineno) node.source, node.line = inliner.reporter.get_source_and_line(lineno)
def process_only_nodes(doctree, tags, warn_node=None): def process_only_nodes(doctree, tags, warn_node=None):
# type: (nodes.Node, Tags, Callable) -> None
# A comment on the comment() nodes being inserted: replacing by [] would # A comment on the comment() nodes being inserted: replacing by [] would
# result in a "Losing ids" exception if there is a target node before # result in a "Losing ids" exception if there is a target node before
# the only node, so we make sure docutils can transfer the id to # the only node, so we make sure docutils can transfer the id to
@ -345,6 +370,7 @@ def process_only_nodes(doctree, tags, warn_node=None):
# monkey-patch Element.copy to copy the rawsource and line # monkey-patch Element.copy to copy the rawsource and line
def _new_copy(self): def _new_copy(self):
# type: (nodes.Node) -> nodes.Node
newnode = self.__class__(self.rawsource, **self.attributes) newnode = self.__class__(self.rawsource, **self.attributes)
if isinstance(self, nodes.Element): if isinstance(self, nodes.Element):
newnode.source = self.source newnode.source = self.source

View File

@ -21,9 +21,12 @@ import filecmp
from os import path from os import path
import contextlib import contextlib
from io import BytesIO, StringIO from io import BytesIO, StringIO
from six import PY2, text_type from six import PY2, text_type
if False:
# For type annotation
from typing import Any, Iterator, Tuple, Union # NOQA
# Errnos that we need. # Errnos that we need.
EEXIST = getattr(errno, 'EEXIST', 0) EEXIST = getattr(errno, 'EEXIST', 0)
ENOENT = getattr(errno, 'ENOENT', 0) ENOENT = getattr(errno, 'ENOENT', 0)
@ -39,15 +42,18 @@ SEP = "/"
def os_path(canonicalpath): def os_path(canonicalpath):
# type: (unicode) -> unicode
return canonicalpath.replace(SEP, path.sep) return canonicalpath.replace(SEP, path.sep)
def canon_path(nativepath): def canon_path(nativepath):
# type: (unicode) -> unicode
"""Return path in OS-independent form""" """Return path in OS-independent form"""
return nativepath.replace(path.sep, SEP) return nativepath.replace(path.sep, SEP)
def relative_uri(base, to): def relative_uri(base, to):
# type: (unicode, unicode) -> unicode
"""Return a relative URL from ``base`` to ``to``.""" """Return a relative URL from ``base`` to ``to``."""
if to.startswith(SEP): if to.startswith(SEP):
return to return to
@ -71,6 +77,7 @@ def relative_uri(base, to):
def ensuredir(path): def ensuredir(path):
# type: (unicode) -> None
"""Ensure that a path exists.""" """Ensure that a path exists."""
try: try:
os.makedirs(path) os.makedirs(path)
@ -84,6 +91,7 @@ def ensuredir(path):
# that check UnicodeError. # that check UnicodeError.
# The customization obstacle to replace the function with the os.walk. # The customization obstacle to replace the function with the os.walk.
def walk(top, topdown=True, followlinks=False): def walk(top, topdown=True, followlinks=False):
# type: (unicode, bool, bool) -> Iterator[Tuple[unicode, List[unicode], List[unicode]]]
"""Backport of os.walk from 2.6, where the *followlinks* argument was """Backport of os.walk from 2.6, where the *followlinks* argument was
added. added.
""" """
@ -115,6 +123,7 @@ def walk(top, topdown=True, followlinks=False):
def mtimes_of_files(dirnames, suffix): def mtimes_of_files(dirnames, suffix):
# type: (List[unicode], unicode) -> Iterator[float]
for dirname in dirnames: for dirname in dirnames:
for root, dirs, files in os.walk(dirname): for root, dirs, files in os.walk(dirname):
for sfile in files: for sfile in files:
@ -126,6 +135,7 @@ def mtimes_of_files(dirnames, suffix):
def movefile(source, dest): def movefile(source, dest):
# type: (unicode, unicode) -> None
"""Move a file, removing the destination if it exists.""" """Move a file, removing the destination if it exists."""
if os.path.exists(dest): if os.path.exists(dest):
try: try:
@ -136,6 +146,7 @@ def movefile(source, dest):
def copytimes(source, dest): def copytimes(source, dest):
# type: (unicode, unicode) -> None
"""Copy a file's modification times.""" """Copy a file's modification times."""
st = os.stat(source) st = os.stat(source)
if hasattr(os, 'utime'): if hasattr(os, 'utime'):
@ -143,6 +154,7 @@ def copytimes(source, dest):
def copyfile(source, dest): def copyfile(source, dest):
# type: (unicode, unicode) -> None
"""Copy a file and its modification times, if possible. """Copy a file and its modification times, if possible.
Note: ``copyfile`` skips copying if the file has not been changed""" Note: ``copyfile`` skips copying if the file has not been changed"""
@ -159,10 +171,12 @@ no_fn_re = re.compile(r'[^a-zA-Z0-9_-]')
def make_filename(string): def make_filename(string):
# type: (str) -> unicode
return no_fn_re.sub('', string) or 'sphinx' return no_fn_re.sub('', string) or 'sphinx'
def ustrftime(format, *args): def ustrftime(format, *args):
# type: (unicode, Any) -> unicode
# [DEPRECATED] strftime for unicode strings # [DEPRECATED] strftime for unicode strings
# It will be removed at Sphinx-1.5 # It will be removed at Sphinx-1.5
if not args: if not args:
@ -171,7 +185,7 @@ def ustrftime(format, *args):
source_date_epoch = os.getenv('SOURCE_DATE_EPOCH') source_date_epoch = os.getenv('SOURCE_DATE_EPOCH')
if source_date_epoch is not None: if source_date_epoch is not None:
time_struct = time.gmtime(float(source_date_epoch)) time_struct = time.gmtime(float(source_date_epoch))
args = [time_struct] args = [time_struct] # type: ignore
if PY2: if PY2:
# if a locale is set, the time strings are encoded in the encoding # if a locale is set, the time strings are encoded in the encoding
# given by LC_TIME; if that is available, use it # given by LC_TIME; if that is available, use it
@ -188,16 +202,18 @@ def ustrftime(format, *args):
def safe_relpath(path, start=None): def safe_relpath(path, start=None):
# type: (unicode, unicode) -> unicode
try: try:
return os.path.relpath(path, start) return os.path.relpath(path, start)
except ValueError: except ValueError:
return path return path
fs_encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() fs_encoding = sys.getfilesystemencoding() or sys.getdefaultencoding() # type: unicode
def abspath(pathdir): def abspath(pathdir):
# type: (unicode) -> unicode
pathdir = path.abspath(pathdir) pathdir = path.abspath(pathdir)
if isinstance(pathdir, bytes): if isinstance(pathdir, bytes):
pathdir = pathdir.decode(fs_encoding) pathdir = pathdir.decode(fs_encoding)
@ -205,6 +221,7 @@ def abspath(pathdir):
def getcwd(): def getcwd():
# type: () -> unicode
if hasattr(os, 'getcwdu'): if hasattr(os, 'getcwdu'):
return os.getcwdu() return os.getcwdu()
return os.getcwd() return os.getcwd()
@ -212,6 +229,7 @@ def getcwd():
@contextlib.contextmanager @contextlib.contextmanager
def cd(target_dir): def cd(target_dir):
# type: (unicode) -> Iterator[None]
cwd = getcwd() cwd = getcwd()
try: try:
os.chdir(target_dir) os.chdir(target_dir)
@ -233,10 +251,12 @@ class FileAvoidWrite(object):
Objects can be used as context managers. Objects can be used as context managers.
""" """
def __init__(self, path): def __init__(self, path):
# type: (unicode) -> None
self._path = path self._path = path
self._io = None self._io = None # type: Union[StringIO, BytesIO]
def write(self, data): def write(self, data):
# type: (Union[str, bytes]) -> None
if not self._io: if not self._io:
if isinstance(data, text_type): if isinstance(data, text_type):
self._io = StringIO() self._io = StringIO()
@ -246,6 +266,7 @@ class FileAvoidWrite(object):
self._io.write(data) self._io.write(data)
def close(self): def close(self):
# type: () -> None
"""Stop accepting writes and write file, if needed.""" """Stop accepting writes and write file, if needed."""
if not self._io: if not self._io:
raise Exception('FileAvoidWrite does not support empty files.') raise Exception('FileAvoidWrite does not support empty files.')
@ -288,6 +309,7 @@ class FileAvoidWrite(object):
def rmtree(path): def rmtree(path):
# type: (unicode) -> None
if os.path.isdir(path): if os.path.isdir(path):
shutil.rmtree(path) shutil.rmtree(path)
else: else:

View File

@ -13,16 +13,19 @@ import os
import time import time
import traceback import traceback
from math import sqrt from math import sqrt
from six import iteritems
try: try:
import multiprocessing import multiprocessing
except ImportError: except ImportError:
multiprocessing = None multiprocessing = None
from six import iteritems
from sphinx.errors import SphinxParallelError from sphinx.errors import SphinxParallelError
if False:
# For type annotation
from typing import Any, Callable, Sequence # NOQA
# our parallel functionality only works for the forking Process # our parallel functionality only works for the forking Process
parallel_available = multiprocessing and (os.name == 'posix') parallel_available = multiprocessing and (os.name == 'posix')
@ -31,9 +34,11 @@ class SerialTasks(object):
"""Has the same interface as ParallelTasks, but executes tasks directly.""" """Has the same interface as ParallelTasks, but executes tasks directly."""
def __init__(self, nproc=1): def __init__(self, nproc=1):
# type: (int) -> None
pass pass
def add_task(self, task_func, arg=None, result_func=None): def add_task(self, task_func, arg=None, result_func=None):
# type: (Callable, Any, Callable) -> None
if arg is not None: if arg is not None:
res = task_func(arg) res = task_func(arg)
else: else:
@ -42,6 +47,7 @@ class SerialTasks(object):
result_func(res) result_func(res)
def join(self): def join(self):
# type: () -> None
pass pass
@ -49,23 +55,25 @@ class ParallelTasks(object):
"""Executes *nproc* tasks in parallel after forking.""" """Executes *nproc* tasks in parallel after forking."""
def __init__(self, nproc): def __init__(self, nproc):
# type: (int) -> None
self.nproc = nproc self.nproc = nproc
# (optional) function performed by each task on the result of main task # (optional) function performed by each task on the result of main task
self._result_funcs = {} self._result_funcs = {} # type: Dict[int, Callable]
# task arguments # task arguments
self._args = {} self._args = {} # type: Dict[int, List[Any]]
# list of subprocesses (both started and waiting) # list of subprocesses (both started and waiting)
self._procs = {} self._procs = {} # type: Dict[int, multiprocessing.Process]
# list of receiving pipe connections of running subprocesses # list of receiving pipe connections of running subprocesses
self._precvs = {} self._precvs = {} # type: Dict[int, Any]
# list of receiving pipe connections of waiting subprocesses # list of receiving pipe connections of waiting subprocesses
self._precvsWaiting = {} self._precvsWaiting = {} # type: Dict[int, Any]
# number of working subprocesses # number of working subprocesses
self._pworking = 0 self._pworking = 0
# task number of each subprocess # task number of each subprocess
self._taskid = 0 self._taskid = 0
def _process(self, pipe, func, arg): def _process(self, pipe, func, arg):
# type: (Any, Callable, Any) -> None
try: try:
if arg is None: if arg is None:
ret = func() ret = func()
@ -76,6 +84,7 @@ class ParallelTasks(object):
pipe.send((True, (err, traceback.format_exc()))) pipe.send((True, (err, traceback.format_exc())))
def add_task(self, task_func, arg=None, result_func=None): def add_task(self, task_func, arg=None, result_func=None):
# type: (Callable, Any, Callable) -> None
tid = self._taskid tid = self._taskid
self._taskid += 1 self._taskid += 1
self._result_funcs[tid] = result_func or (lambda arg: None) self._result_funcs[tid] = result_func or (lambda arg: None)
@ -88,10 +97,12 @@ class ParallelTasks(object):
self._join_one() self._join_one()
def join(self): def join(self):
# type: () -> None
while self._pworking: while self._pworking:
self._join_one() self._join_one()
def _join_one(self): def _join_one(self):
# type: () -> None
for tid, pipe in iteritems(self._precvs): for tid, pipe in iteritems(self._precvs):
if pipe.poll(): if pipe.poll():
exc, result = pipe.recv() exc, result = pipe.recv()
@ -111,6 +122,7 @@ class ParallelTasks(object):
def make_chunks(arguments, nproc, maxbatch=10): def make_chunks(arguments, nproc, maxbatch=10):
# type: (Sequence[unicode], int, int) -> List[Any]
# determine how many documents to read in one go # determine how many documents to read in one go
nargs = len(arguments) nargs = len(arguments)
chunksize = nargs // nproc chunksize = nargs // nproc

View File

@ -14,11 +14,13 @@ import sys
import codecs import codecs
import warnings import warnings
from six import class_types from six import PY3, class_types, text_type, exec_
from six.moves import zip_longest from six.moves import zip_longest
from itertools import product from itertools import product
from six import PY3, text_type, exec_ if False:
# For type annotation
from typing import Any, Callable # NOQA
NoneType = type(None) NoneType = type(None)
@ -33,6 +35,7 @@ if PY3:
# safely encode a string for printing to the terminal # safely encode a string for printing to the terminal
def terminal_safe(s): def terminal_safe(s):
# type: (unicode) -> unicode
return s.encode('ascii', 'backslashreplace').decode('ascii') return s.encode('ascii', 'backslashreplace').decode('ascii')
# some kind of default system encoding; should be used with a lenient # some kind of default system encoding; should be used with a lenient
# error handler # error handler
@ -40,6 +43,7 @@ if PY3:
# support for running 2to3 over config files # support for running 2to3 over config files
def convert_with_2to3(filepath): def convert_with_2to3(filepath):
# type: (unicode) -> unicode
from lib2to3.refactor import RefactoringTool, get_fixers_from_package from lib2to3.refactor import RefactoringTool, get_fixers_from_package
from lib2to3.pgen2.parse import ParseError from lib2to3.pgen2.parse import ParseError
fixers = get_fixers_from_package('lib2to3.fixes') fixers = get_fixers_from_package('lib2to3.fixes')
@ -68,13 +72,15 @@ else:
# Python 2 # Python 2
u = 'u' u = 'u'
# no need to refactor on 2.x versions # no need to refactor on 2.x versions
convert_with_2to3 = None convert_with_2to3 = None # type: ignore
def TextIOWrapper(stream, encoding): def TextIOWrapper(stream, encoding):
# type: (file, str) -> unicode
return codecs.lookup(encoding or 'ascii')[2](stream) return codecs.lookup(encoding or 'ascii')[2](stream)
# safely encode a string for printing to the terminal # safely encode a string for printing to the terminal
def terminal_safe(s): def terminal_safe(s):
# type: (unicode) -> unicode
return s.encode('ascii', 'backslashreplace') return s.encode('ascii', 'backslashreplace')
# some kind of default system encoding; should be used with a lenient # some kind of default system encoding; should be used with a lenient
# error handler # error handler
@ -91,6 +97,7 @@ else:
# backport from python3 # backport from python3
def indent(text, prefix, predicate=None): def indent(text, prefix, predicate=None):
# type: (unicode, unicode, Callable) -> unicode
if predicate is None: if predicate is None:
def predicate(line): def predicate(line):
return line.strip() return line.strip()
@ -102,6 +109,7 @@ else:
def execfile_(filepath, _globals, open=open): def execfile_(filepath, _globals, open=open):
# type: (unicode, Any, Callable) -> None
from sphinx.util.osutil import fs_encoding from sphinx.util.osutil import fs_encoding
# get config source -- 'b' is a no-op under 2.x, while 'U' is # get config source -- 'b' is a no-op under 2.x, while 'U' is
# ignored under 3.x (but 3.x compile() accepts \r\n newlines) # ignored under 3.x (but 3.x compile() accepts \r\n newlines)
@ -132,6 +140,7 @@ def execfile_(filepath, _globals, open=open):
class _DeprecationWrapper(object): class _DeprecationWrapper(object):
def __init__(self, mod, deprecated): def __init__(self, mod, deprecated):
# type: (Any, Dict) -> None
self._mod = mod self._mod = mod
self._deprecated = deprecated self._deprecated = deprecated
@ -145,7 +154,7 @@ class _DeprecationWrapper(object):
return getattr(self._mod, attr) return getattr(self._mod, attr)
sys.modules[__name__] = _DeprecationWrapper(sys.modules[__name__], dict( sys.modules[__name__] = _DeprecationWrapper(sys.modules[__name__], dict( # type: ignore
zip_longest = zip_longest, zip_longest = zip_longest,
product = product, product = product,
all = all, all = all,