Merge pull request #6425 from tk0miya/refactor_type_annotation2

Migrate to py3 style type annotation: sphinx.util
This commit is contained in:
Takeshi KOMIYA 2019-06-02 21:44:16 +09:00 committed by GitHub
commit ef4ad32025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 147 additions and 237 deletions

View File

@ -24,6 +24,9 @@ from datetime import datetime
from hashlib import md5
from os import path
from time import mktime, strptime
from typing import (
Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Set, Tuple, Type
)
from urllib.parse import urlsplit, urlunsplit, quote_plus, parse_qsl, urlencode
from docutils.utils import relative_path
@ -34,6 +37,7 @@ from sphinx.locale import __
from sphinx.util import logging
from sphinx.util.console import strip_colors, colorize, bold, term_width_line # type: ignore
from sphinx.util.fileutil import copy_asset_file
from sphinx.util.typing import PathMatcher
from sphinx.util import smartypants # noqa
# import other utilities; partly for backwards compatibility, so don't
@ -46,10 +50,11 @@ from sphinx.util.nodes import ( # noqa
caption_ref_re)
from sphinx.util.matching import patfilter # noqa
if False:
# For type annotation
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Set, Tuple, Type, Union # NOQA
from sphinx.application import Sphinx
from sphinx.builders import Builder
logger = logging.getLogger(__name__)
@ -60,21 +65,19 @@ url_re = re.compile(r'(?P<schema>.+)://.*') # type: Pattern
# High-level utility functions.
def docname_join(basedocname, docname):
# type: (str, str) -> str
def docname_join(basedocname: str, docname: str) -> str:
return posixpath.normpath(
posixpath.join('/' + basedocname, '..', docname))[1:]
def path_stabilize(filepath):
# type: (str) -> str
def path_stabilize(filepath: str) -> str:
"normalize path separater and unicode string"
newpath = filepath.replace(os.path.sep, SEP)
return unicodedata.normalize('NFC', newpath)
def get_matching_files(dirname, exclude_matchers=()):
# type: (str, Tuple[Callable[[str], bool], ...]) -> Iterable[str]
def get_matching_files(dirname: str,
exclude_matchers: Tuple[PathMatcher, ...] = ()) -> Iterable[str]: # NOQA
"""Get all file names in a directory, recursively.
Exclude files and dirs matching some matcher in *exclude_matchers*.
@ -100,8 +103,8 @@ def get_matching_files(dirname, exclude_matchers=()):
yield filename
def get_matching_docs(dirname, suffixes, exclude_matchers=()):
# type: (str, List[str], Tuple[Callable[[str], bool], ...]) -> Iterable[str] # NOQA
def get_matching_docs(dirname: str, suffixes: List[str],
exclude_matchers: Tuple[PathMatcher, ...] = ()) -> Iterable[str]:
"""Get all file names (without suffixes) matching a suffix in a directory,
recursively.
@ -123,12 +126,10 @@ class FilenameUniqDict(dict):
interpreted as filenames, and keeps track of a set of docnames they
appear in. Used for images and downloadable files in the environment.
"""
def __init__(self):
# type: () -> None
def __init__(self) -> None:
self._existing = set() # type: Set[str]
def add_file(self, docname, newfile):
# type: (str, str) -> str
def add_file(self, docname: str, newfile: str) -> str:
if newfile in self:
self[newfile][0].add(docname)
return self[newfile][1]
@ -142,26 +143,22 @@ class FilenameUniqDict(dict):
self._existing.add(uniquename)
return uniquename
def purge_doc(self, docname):
# type: (str) -> None
def purge_doc(self, docname: str) -> None:
for filename, (docs, unique) in list(self.items()):
docs.discard(docname)
if not docs:
del self[filename]
self._existing.discard(unique)
def merge_other(self, docnames, other):
# type: (Set[str], Dict[str, Tuple[Set[str], Any]]) -> None
def merge_other(self, docnames: Set[str], other: Dict[str, Tuple[Set[str], Any]]) -> None:
for filename, (docs, unique) in other.items():
for doc in docs & set(docnames):
self.add_file(doc, filename)
def __getstate__(self):
# type: () -> Set[str]
def __getstate__(self) -> Set[str]:
return self._existing
def __setstate__(self, state):
# type: (Set[str]) -> None
def __setstate__(self, state: Set[str]) -> None:
self._existing = state
@ -172,8 +169,7 @@ class DownloadFiles(dict):
Hence don't hack this directly.
"""
def add_file(self, docname, filename):
# type: (str, str) -> None
def add_file(self, docname: str, filename: str) -> None:
if filename not in self:
digest = md5(filename.encode()).hexdigest()
dest = '%s/%s' % (digest, os.path.basename(filename))
@ -182,23 +178,20 @@ class DownloadFiles(dict):
self[filename][0].add(docname)
return self[filename][1]
def purge_doc(self, docname):
# type: (str) -> None
def purge_doc(self, docname: str) -> None:
for filename, (docs, dest) in list(self.items()):
docs.discard(docname)
if not docs:
del self[filename]
def merge_other(self, docnames, other):
# type: (Set[str], Dict[str, Tuple[Set[str], Any]]) -> None
def merge_other(self, docnames: Set[str], other: Dict[str, Tuple[Set[str], Any]]) -> None:
for filename, (docs, dest) in other.items():
for docname in docs & set(docnames):
self.add_file(docname, filename)
def copy_static_entry(source, targetdir, builder, context={},
exclude_matchers=(), level=0):
# type: (str, str, Any, Dict, Tuple[Callable, ...], int) -> None
def copy_static_entry(source: str, targetdir: str, builder: "Builder", context: Dict = {},
exclude_matchers: Tuple[PathMatcher, ...] = (), level: int = 0) -> None:
"""[DEPRECATED] Copy a HTML builder static_path entry from source to targetdir.
Handles all possible cases of files, directories and subdirectories.
@ -237,8 +230,7 @@ _DEBUG_HEADER = '''\
'''
def save_traceback(app):
# type: (Any) -> str
def save_traceback(app: "Sphinx") -> str:
"""Save the current exception's traceback in a temporary file."""
import sphinx
import jinja2
@ -273,8 +265,7 @@ def save_traceback(app):
return path
def get_module_source(modname):
# type: (str) -> Tuple[str, str]
def get_module_source(modname: str) -> Tuple[str, str]:
"""Try to find the source code for a module.
Can return ('file', 'filename') in which case the source is in the given
@ -321,8 +312,7 @@ def get_module_source(modname):
return 'file', filename
def get_full_modname(modname, attribute):
# type: (str, str) -> str
def get_full_modname(modname: str, attribute: str) -> str:
if modname is None:
# Prevents a TypeError: if the last getattr() call will return None
# then it's better to return it directly
@ -344,8 +334,7 @@ def get_full_modname(modname, attribute):
_coding_re = re.compile(r'coding[:=]\s*([-\w.]+)')
def detect_encoding(readline):
# type: (Callable[[], bytes]) -> str
def detect_encoding(readline: Callable[[], bytes]) -> str:
"""Like tokenize.detect_encoding() from Py3k, but a bit simplified."""
def read_or_stop():
@ -401,12 +390,10 @@ def detect_encoding(readline):
class UnicodeDecodeErrorHandler:
"""Custom error handler for open() that warns and replaces."""
def __init__(self, docname):
# type: (str) -> None
def __init__(self, docname: str) -> None:
self.docname = docname
def __call__(self, error):
# type: (UnicodeDecodeError) -> Tuple[Union[str, str], int]
def __call__(self, error: UnicodeDecodeError) -> Tuple[str, int]:
linestart = error.object.rfind(b'\n', 0, error.start)
lineend = error.object.find(b'\n', error.start)
if lineend == -1:
@ -426,26 +413,22 @@ class Tee:
"""
File-like object writing to two streams.
"""
def __init__(self, stream1, stream2):
# type: (IO, IO) -> None
def __init__(self, stream1: IO, stream2: IO) -> None:
self.stream1 = stream1
self.stream2 = stream2
def write(self, text):
# type: (str) -> None
def write(self, text: str) -> None:
self.stream1.write(text)
self.stream2.write(text)
def flush(self):
# type: () -> None
def flush(self) -> None:
if hasattr(self.stream1, 'flush'):
self.stream1.flush()
if hasattr(self.stream2, 'flush'):
self.stream2.flush()
def parselinenos(spec, total):
# type: (str, int) -> List[int]
def parselinenos(spec: str, total: int) -> List[int]:
"""Parse a line number spec (such as "1,2,4-6") and return a list of
wanted line numbers.
"""
@ -472,8 +455,7 @@ def parselinenos(spec, total):
return items
def force_decode(string, encoding):
# type: (str, str) -> str
def force_decode(string: str, encoding: str) -> str:
"""Forcibly get a unicode string out of a bytestring."""
warnings.warn('force_decode() is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
@ -491,26 +473,22 @@ def force_decode(string, encoding):
class attrdict(dict):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
warnings.warn('The attrdict class is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
def __getattr__(self, key):
# type: (str) -> str
def __getattr__(self, key: str) -> str:
return self[key]
def __setattr__(self, key, val):
# type: (str, str) -> None
def __setattr__(self, key: str, val: str) -> None:
self[key] = val
def __delattr__(self, key):
# type: (str) -> None
def __delattr__(self, key: str) -> None:
del self[key]
def rpartition(s, t):
# type: (str, str) -> Tuple[str, str]
def rpartition(s: str, t: str) -> Tuple[str, str]:
"""Similar to str.rpartition from 2.5, but doesn't return the separator."""
i = s.rfind(t)
if i != -1:
@ -518,8 +496,7 @@ def rpartition(s, t):
return '', s
def split_into(n, type, value):
# type: (int, str, str) -> List[str]
def split_into(n: int, type: str, value: str) -> List[str]:
"""Split an index entry into a given number of parts at semicolons."""
parts = [x.strip() for x in value.split(';', n - 1)]
if sum(1 for part in parts if part) < n:
@ -527,8 +504,7 @@ def split_into(n, type, value):
return parts
def split_index_msg(type, value):
# type: (str, str) -> List[str]
def split_index_msg(type: str, value: str) -> List[str]:
# new entry types must be listed in directives/other.py!
if type == 'single':
try:
@ -549,8 +525,7 @@ def split_index_msg(type, value):
return result
def format_exception_cut_frames(x=1):
# type: (int) -> str
def format_exception_cut_frames(x: int = 1) -> str:
"""Format an exception with traceback, but only the last x frames."""
typ, val, tb = sys.exc_info()
# res = ['Traceback (most recent call last):\n']
@ -566,19 +541,16 @@ class PeekableIterator:
An iterator which wraps any iterable and makes it possible to peek to see
what's the next item.
"""
def __init__(self, iterable):
# type: (Iterable) -> None
def __init__(self, iterable: Iterable) -> None:
self.remaining = deque() # type: deque
self._iterator = iter(iterable)
warnings.warn('PeekableIterator is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
def __iter__(self):
# type: () -> PeekableIterator
def __iter__(self) -> "PeekableIterator":
return self
def __next__(self):
# type: () -> Any
def __next__(self) -> Any:
"""Return the next item from the iterator."""
if self.remaining:
return self.remaining.popleft()
@ -586,23 +558,20 @@ class PeekableIterator:
next = __next__ # Python 2 compatibility
def push(self, item):
# type: (Any) -> None
def push(self, item: Any) -> None:
"""Push the `item` on the internal stack, it will be returned on the
next :meth:`next` call.
"""
self.remaining.append(item)
def peek(self):
# type: () -> Any
def peek(self) -> Any:
"""Return the next item without changing the state of the iterator."""
item = next(self)
self.push(item)
return item
def import_object(objname, source=None):
# type: (str, str) -> Any
def import_object(objname: str, source: str = None) -> Any:
"""Import python object by qualname."""
try:
objpath = objname.split('.')
@ -625,8 +594,7 @@ def import_object(objname, source=None):
raise ExtensionError('Could not import %s' % objname, exc)
def encode_uri(uri):
# type: (str) -> str
def encode_uri(uri: str) -> str:
split = list(urlsplit(uri))
split[1] = split[1].encode('idna').decode('ascii')
split[2] = quote_plus(split[2].encode(), '/')
@ -635,8 +603,7 @@ def encode_uri(uri):
return urlunsplit(split)
def display_chunk(chunk):
# type: (Any) -> str
def display_chunk(chunk: Any) -> str:
if isinstance(chunk, (list, tuple)):
if len(chunk) == 1:
return str(chunk[0])
@ -644,8 +611,8 @@ def display_chunk(chunk):
return str(chunk)
def old_status_iterator(iterable, summary, color="darkgreen", stringify_func=display_chunk):
# type: (Iterable, str, str, Callable[[Any], str]) -> Iterator
def old_status_iterator(iterable: Iterable, summary: str, color: str = "darkgreen",
stringify_func: Callable[[Any], str] = display_chunk) -> Iterator:
l = 0
for item in iterable:
if l == 0:
@ -659,9 +626,9 @@ def old_status_iterator(iterable, summary, color="darkgreen", stringify_func=dis
# new version with progress info
def status_iterator(iterable, summary, color="darkgreen", length=0, verbosity=0,
stringify_func=display_chunk):
# type: (Iterable, str, str, int, int, Callable[[Any], str]) -> Iterable
def status_iterator(iterable: Iterable, summary: str, color: str = "darkgreen",
length: int = 0, verbosity: int = 0,
stringify_func: Callable[[Any], str] = display_chunk) -> Iterable:
if length == 0:
yield from old_status_iterator(iterable, summary, color, stringify_func)
return
@ -685,16 +652,13 @@ class SkipProgressMessage(Exception):
class progress_message:
def __init__(self, message):
# type: (str) -> None
def __init__(self, message: str) -> None:
self.message = message
def __enter__(self):
# type: () -> None
def __enter__(self) -> None:
logger.info(bold(self.message + '... '), nonl=True)
def __exit__(self, exc_type, exc_value, traceback):
# type: (Type[Exception], Exception, Any) -> bool
def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: Any) -> bool: # NOQA
if isinstance(exc_value, SkipProgressMessage):
logger.info(__('skipped'))
if exc_value.args:
@ -707,8 +671,7 @@ class progress_message:
return False
def __call__(self, f):
# type: (Callable) -> Callable
def __call__(self, f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(*args, **kwargs):
with self:
@ -717,8 +680,7 @@ class progress_message:
return wrapper
def epoch_to_rfc1123(epoch):
# type: (float) -> str
def epoch_to_rfc1123(epoch: float) -> str:
"""Convert datetime format epoch to RFC1123."""
from babel.dates import format_datetime
@ -727,13 +689,11 @@ def epoch_to_rfc1123(epoch):
return format_datetime(dt, fmt, locale='en') + ' GMT'
def rfc1123_to_epoch(rfc1123):
# type: (str) -> float
def rfc1123_to_epoch(rfc1123: str) -> float:
return mktime(strptime(rfc1123, '%a, %d %b %Y %H:%M:%S %Z'))
def xmlname_checker():
# type: () -> Pattern
def xmlname_checker() -> Pattern:
# https://www.w3.org/TR/REC-xml/#NT-Name
name_start_chars = [
':', ['A', 'Z'], '_', ['a', 'z'], ['\u00C0', '\u00D6'],
@ -747,8 +707,7 @@ def xmlname_checker():
['\u203F', '\u2040']
]
def convert(entries, splitter='|'):
# type: (Any, str) -> str
def convert(entries: Any, splitter: str = '|') -> str:
results = []
for entry in entries:
if isinstance(entry, list):

View File

@ -9,14 +9,10 @@
"""
import sys
if False:
# For type annotation
from typing import List # NOQA
from typing import List
def prepare_docstring(s, ignore=1, tabsize=8):
# type: (str, int, int) -> List[str]
def prepare_docstring(s: str, ignore: int = 1, tabsize: int = 8) -> List[str]:
"""Convert a docstring into lines of parseable reST. Remove common leading
indentation, where the indentation of a given number of lines (usually just
one) is ignored.
@ -49,8 +45,7 @@ def prepare_docstring(s, ignore=1, tabsize=8):
return lines
def prepare_commentdoc(s):
# type: (str) -> List[str]
def prepare_commentdoc(s: str) -> List[str]:
"""Extract documentation comment lines (starting with #:) and return them
as a list of lines. Returns an empty list if there is no documentation.
"""

View File

@ -10,20 +10,20 @@
import os
import posixpath
from typing import Dict
from docutils.utils import relative_path
from sphinx.util.osutil import copyfile, ensuredir
from sphinx.util.typing import PathMatcher
if False:
# For type annotation
from typing import Callable, Dict, Union # NOQA
from sphinx.util.matching import Matcher # NOQA
from sphinx.util.template import BaseRenderer # NOQA
from sphinx.util.template import BaseRenderer
def copy_asset_file(source, destination, context=None, renderer=None):
# type: (str, str, Dict, BaseRenderer) -> None
def copy_asset_file(source: str, destination: str,
context: Dict = None, renderer: "BaseRenderer" = None) -> None:
"""Copy an asset file to destination.
On copying, it expands the template variables if context argument is given and
@ -55,8 +55,8 @@ def copy_asset_file(source, destination, context=None, renderer=None):
copyfile(source, destination)
def copy_asset(source, destination, excluded=lambda path: False, context=None, renderer=None):
# type: (str, str, Union[Callable[[str], bool], Matcher], Dict, BaseRenderer) -> None
def copy_asset(source: str, destination: str, excluded: PathMatcher = lambda path: False,
context: Dict = None, renderer: "BaseRenderer" = None) -> None:
"""Copy asset files to destination recursively.
On copying, it expands the template variables if context argument is given and

View File

@ -14,6 +14,7 @@ import warnings
from collections import namedtuple
from datetime import datetime
from os import path
from typing import Callable, Generator, List, Set, Tuple
import babel.dates
from babel.messages.mofile import write_mo
@ -26,13 +27,12 @@ from sphinx.util import logging
from sphinx.util.matching import Matcher
from sphinx.util.osutil import SEP, canon_path, relpath
logger = logging.getLogger(__name__)
if False:
# For type annotation
from typing import Callable, Generator, List, Set, Tuple # NOQA
from sphinx.environment import BuildEnvironment # NOQA
from sphinx.environment import BuildEnvironment
logger = logging.getLogger(__name__)
LocaleFileInfoBase = namedtuple('CatalogInfo', 'base_dir,domain,charset')
@ -40,33 +40,27 @@ LocaleFileInfoBase = namedtuple('CatalogInfo', 'base_dir,domain,charset')
class CatalogInfo(LocaleFileInfoBase):
@property
def po_file(self):
# type: () -> str
def po_file(self) -> str:
return self.domain + '.po'
@property
def mo_file(self):
# type: () -> str
def mo_file(self) -> str:
return self.domain + '.mo'
@property
def po_path(self):
# type: () -> str
def po_path(self) -> str:
return path.join(self.base_dir, self.po_file)
@property
def mo_path(self):
# type: () -> str
def mo_path(self) -> str:
return path.join(self.base_dir, self.mo_file)
def is_outdated(self):
# type: () -> bool
def is_outdated(self) -> bool:
return (
not path.exists(self.mo_path) or
path.getmtime(self.mo_path) < path.getmtime(self.po_path))
def write_mo(self, locale):
# type: (str) -> None
def write_mo(self, locale: str) -> None:
with open(self.po_path, encoding=self.charset) as file_po:
try:
po = read_po(file_po, locale)
@ -84,16 +78,15 @@ class CatalogInfo(LocaleFileInfoBase):
class CatalogRepository:
"""A repository for message catalogs."""
def __init__(self, basedir, locale_dirs, language, encoding):
# type: (str, List[str], str, str) -> None
def __init__(self, basedir: str, locale_dirs: List[str],
language: str, encoding: str) -> None:
self.basedir = basedir
self._locale_dirs = locale_dirs
self.language = language
self.encoding = encoding
@property
def locale_dirs(self):
# type: () -> Generator[str, None, None]
def locale_dirs(self) -> Generator[str, None, None]:
if not self.language:
return
@ -103,8 +96,7 @@ class CatalogRepository:
yield locale_dir
@property
def pofiles(self):
# type: () -> Generator[Tuple[str, str], None, None]
def pofiles(self) -> Generator[Tuple[str, str], None, None]:
for locale_dir in self.locale_dirs:
basedir = path.join(locale_dir, self.language, 'LC_MESSAGES')
for root, dirnames, filenames in os.walk(basedir):
@ -119,15 +111,13 @@ class CatalogRepository:
yield basedir, relpath(fullpath, basedir)
@property
def catalogs(self):
# type: () -> Generator[CatalogInfo, None, None]
def catalogs(self) -> Generator[CatalogInfo, None, None]:
for basedir, filename in self.pofiles:
domain = canon_path(path.splitext(filename)[0])
yield CatalogInfo(basedir, domain, self.encoding)
def find_catalog(docname, compaction):
# type: (str, bool) -> str
def find_catalog(docname: str, compaction: bool) -> str:
warnings.warn('find_catalog() is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
if compaction:
@ -138,8 +128,7 @@ def find_catalog(docname, compaction):
return ret
def docname_to_domain(docname, compation):
# type: (str, bool) -> str
def docname_to_domain(docname: str, compation: bool) -> str:
"""Convert docname to domain for catalogs."""
if compation:
return docname.split(SEP, 1)[0]
@ -147,8 +136,8 @@ def docname_to_domain(docname, compation):
return docname
def find_catalog_files(docname, srcdir, locale_dirs, lang, compaction):
# type: (str, str, List[str], str, bool) -> List[str]
def find_catalog_files(docname: str, srcdir: str, locale_dirs: List[str],
lang: str, compaction: bool) -> List[str]:
warnings.warn('find_catalog_files() is deprecated.',
RemovedInSphinx40Warning, stacklevel=2)
if not(lang and locale_dirs):
@ -161,10 +150,10 @@ def find_catalog_files(docname, srcdir, locale_dirs, lang, compaction):
return files
def find_catalog_source_files(locale_dirs, locale, domains=None, gettext_compact=None,
charset='utf-8', force_all=False,
excluded=Matcher([])):
# type: (List[str], str, List[str], bool, str, bool, Matcher) -> Set[CatalogInfo]
def find_catalog_source_files(locale_dirs: List[str], locale: str, domains: List[str] = None,
gettext_compact: bool = None, charset: str = 'utf-8',
force_all: bool = False, excluded: Matcher = Matcher([])
) -> Set[CatalogInfo]:
"""
:param list locale_dirs:
list of path as `['locale_dir1', 'locale_dir2', ...]` to find
@ -256,8 +245,8 @@ date_format_mappings = {
date_format_re = re.compile('(%s)' % '|'.join(date_format_mappings))
def babel_format_date(date, format, locale, formatter=babel.dates.format_date):
# type: (datetime, str, str, Callable) -> str
def babel_format_date(date: datetime, format: str, locale: str,
formatter: Callable = babel.dates.format_date) -> str:
if locale is None:
locale = 'en'
@ -277,8 +266,7 @@ def babel_format_date(date, format, locale, formatter=babel.dates.format_date):
return format
def format_date(format, date=None, language=None):
# type: (str, datetime, str) -> str
def format_date(format: str, date: datetime = None, language: str = None) -> str:
if date is None:
# If time is not specified, try to use $SOURCE_DATE_EPOCH variable
# See https://wiki.debian.org/ReproducibleBuilds/TimestampsProposal
@ -312,8 +300,7 @@ def format_date(format, date=None, language=None):
return "".join(result)
def get_image_filename_for_language(filename, env):
# type: (str, BuildEnvironment) -> str
def get_image_filename_for_language(filename: str, env: "BuildEnvironment") -> str:
if not env.config.language:
return filename
@ -332,8 +319,7 @@ def get_image_filename_for_language(filename, env):
raise SphinxError('Invalid figure_language_filename: %r' % exc)
def search_image_for_language(filename, env):
# type: (str, BuildEnvironment) -> str
def search_image_for_language(filename: str, env: "BuildEnvironment") -> str:
if not env.config.language:
return filename

View File

@ -14,7 +14,7 @@ import warnings
from collections import OrderedDict
from io import BytesIO
from os import path
from typing import NamedTuple
from typing import IO, NamedTuple, Tuple
import imagesize
@ -25,10 +25,6 @@ try:
except ImportError:
Image = None
if False:
# For type annotation
from typing import IO, Tuple # NOQA
mime_suffixes = OrderedDict([
('.gif', 'image/gif'),
('.jpg', 'image/jpeg'),
@ -43,8 +39,7 @@ DataURI = NamedTuple('DataURI', [('mimetype', str),
('data', bytes)])
def get_image_size(filename):
# type: (str) -> Tuple[int, int]
def get_image_size(filename: str) -> Tuple[int, int]:
try:
size = imagesize.get(filename)
if size[0] == -1:
@ -63,8 +58,7 @@ def get_image_size(filename):
return None
def guess_mimetype_for_stream(stream, default=None):
# type: (IO, str) -> str
def guess_mimetype_for_stream(stream: IO, default: str = None) -> str:
imgtype = imghdr.what(stream) # type: ignore
if imgtype:
return 'image/' + imgtype
@ -72,8 +66,7 @@ def guess_mimetype_for_stream(stream, default=None):
return default
def guess_mimetype(filename='', content=None, default=None):
# type: (str, bytes, str) -> str
def guess_mimetype(filename: str = '', content: bytes = None, default: str = None) -> str:
_, ext = path.splitext(filename.lower())
if ext in mime_suffixes:
return mime_suffixes[ext]
@ -90,8 +83,7 @@ def guess_mimetype(filename='', content=None, default=None):
return default
def get_image_extension(mimetype):
# type: (str) -> str
def get_image_extension(mimetype: str) -> str:
for ext, _mimetype in mime_suffixes.items():
if mimetype == _mimetype:
return ext
@ -99,8 +91,7 @@ def get_image_extension(mimetype):
return None
def parse_data_uri(uri):
# type: (str) -> DataURI
def parse_data_uri(uri: str) -> DataURI:
if not uri.startswith('data:'):
return None
@ -121,8 +112,7 @@ def parse_data_uri(uri):
return DataURI(mimetype, charset, image_data)
def test_svg(h, f):
# type: (bytes, IO) -> str
def test_svg(h: bytes, f: IO) -> str:
"""An additional imghdr library helper; test the header is SVG's or not."""
try:
if '<svg' in h.decode().lower():

View File

@ -20,15 +20,12 @@ from inspect import ( # NOQA
isclass, ismethod, ismethoddescriptor, isroutine
)
from io import StringIO
from typing import Any, Callable, Mapping, List, Tuple
from sphinx.deprecation import RemovedInSphinx30Warning
from sphinx.util import logging
from sphinx.util.typing import NoneType
if False:
# For type annotation
from typing import Any, Callable, Mapping, List, Tuple, Type # NOQA
if sys.version_info > (3, 7):
from types import (
ClassMethodDescriptorType,
@ -114,26 +111,22 @@ def getargspec(func):
kwonlyargs, kwdefaults, annotations)
def isenumclass(x):
# type: (Type) -> bool
def isenumclass(x: Any) -> bool:
"""Check if the object is subclass of enum."""
return inspect.isclass(x) and issubclass(x, enum.Enum)
def isenumattribute(x):
# type: (Any) -> bool
def isenumattribute(x: Any) -> bool:
"""Check if the object is attribute of enum."""
return isinstance(x, enum.Enum)
def ispartial(obj):
# type: (Any) -> bool
def ispartial(obj: Any) -> bool:
"""Check if the object is partial."""
return isinstance(obj, (partial, partialmethod))
def isclassmethod(obj):
# type: (Any) -> bool
def isclassmethod(obj: Any) -> bool:
"""Check if the object is classmethod."""
if isinstance(obj, classmethod):
return True
@ -143,8 +136,7 @@ def isclassmethod(obj):
return False
def isstaticmethod(obj, cls=None, name=None):
# type: (Any, Any, str) -> bool
def isstaticmethod(obj: Any, cls: Any = None, name: str = None) -> bool:
"""Check if the object is staticmethod."""
if isinstance(obj, staticmethod):
return True
@ -163,8 +155,7 @@ def isstaticmethod(obj, cls=None, name=None):
return False
def isdescriptor(x):
# type: (Any) -> bool
def isdescriptor(x: Any) -> bool:
"""Check if the object is some kind of descriptor."""
for item in '__get__', '__set__', '__delete__':
if hasattr(safe_getattr(x, item, None), '__call__'):
@ -172,14 +163,12 @@ def isdescriptor(x):
return False
def isabstractmethod(obj):
# type: (Any) -> bool
def isabstractmethod(obj: Any) -> bool:
"""Check if the object is an abstractmethod."""
return safe_getattr(obj, '__isabstractmethod__', False) is True
def isattributedescriptor(obj):
# type: (Any) -> bool
def isattributedescriptor(obj: Any) -> bool:
"""Check if the object is an attribute like descriptor."""
if inspect.isdatadescriptor(object):
# data descriptor is kind of attribute
@ -206,20 +195,17 @@ def isattributedescriptor(obj):
return False
def isfunction(obj):
# type: (Any) -> bool
def isfunction(obj: Any) -> bool:
"""Check if the object is function."""
return inspect.isfunction(obj) or ispartial(obj) and inspect.isfunction(obj.func)
def isbuiltin(obj):
# type: (Any) -> bool
def isbuiltin(obj: Any) -> bool:
"""Check if the object is builtin."""
return inspect.isbuiltin(obj) or ispartial(obj) and inspect.isbuiltin(obj.func)
def iscoroutinefunction(obj):
# type: (Any) -> bool
def iscoroutinefunction(obj: Any) -> bool:
"""Check if the object is coroutine-function."""
if inspect.iscoroutinefunction(obj):
return True
@ -230,14 +216,12 @@ def iscoroutinefunction(obj):
return False
def isproperty(obj):
# type: (Any) -> bool
def isproperty(obj: Any) -> bool:
"""Check if the object is property."""
return isinstance(obj, property)
def safe_getattr(obj, name, *defargs):
# type: (Any, str, Any) -> Any
def safe_getattr(obj: Any, name: str, *defargs) -> Any:
"""A getattr() that turns all exceptions into AttributeErrors."""
try:
return getattr(obj, name, *defargs)
@ -259,8 +243,8 @@ def safe_getattr(obj, name, *defargs):
raise AttributeError(name)
def safe_getmembers(object, predicate=None, attr_getter=safe_getattr):
# type: (Any, Callable[[str], bool], Callable) -> List[Tuple[str, Any]]
def safe_getmembers(object: Any, predicate: Callable[[str], bool] = None,
attr_getter: Callable = safe_getattr) -> List[Tuple[str, Any]]:
"""A version of inspect.getmembers() that uses safe_getattr()."""
results = [] # type: List[Tuple[str, Any]]
for key in dir(object):
@ -274,8 +258,7 @@ def safe_getmembers(object, predicate=None, attr_getter=safe_getattr):
return results
def object_description(object):
# type: (Any) -> str
def object_description(object: Any) -> str:
"""A repr() implementation that returns text safe to use in reST context."""
if isinstance(object, dict):
try:
@ -312,8 +295,7 @@ def object_description(object):
return s.replace('\n', ' ')
def is_builtin_class_method(obj, attr_name):
# type: (Any, str) -> bool
def is_builtin_class_method(obj: Any, attr_name: str) -> bool:
"""If attr_name is implemented at builtin class, return True.
>>> is_builtin_class_method(int, '__init__')
@ -339,8 +321,8 @@ class Parameter:
VAR_KEYWORD = 4
empty = object()
def __init__(self, name, kind=POSITIONAL_OR_KEYWORD, default=empty):
# type: (str, int, Any) -> None
def __init__(self, name: str, kind: int = POSITIONAL_OR_KEYWORD,
default: Any = empty) -> None:
self.name = name
self.kind = kind
self.default = default
@ -355,8 +337,8 @@ class Signature:
its return annotation.
"""
def __init__(self, subject, bound_method=False, has_retval=True):
# type: (Callable, bool, bool) -> None
def __init__(self, subject: Callable, bound_method: bool = False,
has_retval: bool = True) -> None:
# check subject is not a built-in class (ex. int, str)
if (isinstance(subject, type) and
is_builtin_class_method(subject, "__new__") and
@ -401,16 +383,14 @@ class Signature:
self.skip_first_argument = False
@property
def parameters(self):
# type: () -> Mapping
def parameters(self) -> Mapping:
if self.partialmethod_with_noargs:
return {}
else:
return self.signature.parameters
@property
def return_annotation(self):
# type: () -> Any
def return_annotation(self) -> Any:
if self.signature:
if self.has_retval:
return self.signature.return_annotation
@ -419,8 +399,7 @@ class Signature:
else:
return None
def format_args(self, show_annotation=True):
# type: (bool) -> str
def format_args(self, show_annotation: bool = True) -> str:
args = []
last_kind = None
for i, param in enumerate(self.parameters.values()):
@ -475,8 +454,7 @@ class Signature:
return '(%s) -> %s' % (', '.join(args), annotation)
def format_annotation(self, annotation):
# type: (Any) -> str
def format_annotation(self, annotation: Any) -> str:
"""Return formatted representation of a type annotation.
Show qualified names for types and additional details for types from
@ -502,8 +480,7 @@ class Signature:
else:
return self.format_annotation_old(annotation)
def format_annotation_new(self, annotation):
# type: (Any) -> str
def format_annotation_new(self, annotation: Any) -> str:
"""format_annotation() for py37+"""
module = getattr(annotation, '__module__', None)
if module == 'typing':
@ -539,8 +516,7 @@ class Signature:
return qualname
def format_annotation_old(self, annotation):
# type: (Any) -> str
def format_annotation_old(self, annotation: Any) -> str:
"""format_annotation() for py36 or below"""
module = getattr(annotation, '__module__', None)
if module == 'typing':
@ -641,8 +617,8 @@ class Signature:
return qualname
def getdoc(obj, attrgetter=safe_getattr, allow_inherited=False):
# type: (Any, Callable, bool) -> str
def getdoc(obj: Any, attrgetter: Callable = safe_getattr,
allow_inherited: bool = False) -> str:
"""Get the docstring for the object.
This tries to obtain the docstring for some kind of objects additionally:

View File

@ -23,6 +23,9 @@ TextlikeNode = Union[nodes.Text, nodes.TextElement]
# type of None
NoneType = type(None)
# path matcher
PathMatcher = Callable[[str], bool]
# common role functions
RoleFunction = Callable[[str, str, str, int, Inliner, Dict, List[str]],
Tuple[List[nodes.Node], List[nodes.system_message]]]

View File

@ -793,7 +793,7 @@ def test_autodoc_imported_members(app):
"imported-members": None,
"ignore-module-all": None}
actual = do_autodoc(app, 'module', 'target', options)
assert '.. py:function:: save_traceback(app)' in actual
assert '.. py:function:: save_traceback(app: Sphinx) -> str' in actual
@pytest.mark.sphinx('html', testroot='ext-autodoc')
@ -1795,7 +1795,7 @@ def test_autodoc_default_options(app):
actual = do_autodoc(app, 'class', 'target.CustomIter')
assert ' .. py:method:: target.CustomIter' not in actual
actual = do_autodoc(app, 'module', 'target')
assert '.. py:function:: save_traceback(app)' not in actual
assert '.. py:function:: save_traceback(app: Sphinx) -> str' not in actual
# with :members:
app.config.autodoc_default_options = {'members': None}
@ -1866,7 +1866,8 @@ def test_autodoc_default_options(app):
'ignore-module-all': None,
}
actual = do_autodoc(app, 'module', 'target')
assert '.. py:function:: save_traceback(app)' in actual
print('\n'.join(actual))
assert '.. py:function:: save_traceback(app: Sphinx) -> str' in actual
@pytest.mark.sphinx('html', testroot='ext-autodoc')