Migrate to py3 style type annotation: sphinx.domains.python

This commit is contained in:
Takeshi KOMIYA 2019-06-29 23:45:57 +09:00
parent 3a81e0ad7d
commit ee6e44a04f

View File

@ -10,31 +10,31 @@
import re
import warnings
from typing import Any, Dict, Iterable, Iterator, List, Tuple, Type
from typing import cast
from docutils import nodes
from docutils.nodes import Element, Node
from docutils.parsers.rst import directives
from sphinx import addnodes, locale
from sphinx.addnodes import pending_xref, desc_signature
from sphinx.application import Sphinx
from sphinx.builders import Builder
from sphinx.deprecation import (
DeprecatedDict, RemovedInSphinx30Warning, RemovedInSphinx40Warning
)
from sphinx.directives import ObjectDescription
from sphinx.domains import Domain, ObjType, Index, IndexEntry
from sphinx.environment import BuildEnvironment
from sphinx.locale import _, __
from sphinx.roles import XRefRole
from sphinx.util import logging
from sphinx.util.docfields import Field, GroupedField, TypedField
from sphinx.util.docutils import SphinxDirective
from sphinx.util.nodes import make_refnode
from sphinx.util.typing import TextlikeNode
if False:
# For type annotation
from typing import Any, Dict, Iterable, Iterator, List, Tuple, Type # NOQA
from sphinx.application import Sphinx # NOQA
from sphinx.builders import Builder # NOQA
from sphinx.environment import BuildEnvironment # NOQA
from sphinx.util.typing import TextlikeNode # NOQA
logger = logging.getLogger(__name__)
@ -67,8 +67,7 @@ locale.pairindextypes = DeprecatedDict(
)
def _pseudo_parse_arglist(signode, arglist):
# type: (addnodes.desc_signature, str) -> None
def _pseudo_parse_arglist(signode: desc_signature, arglist: str) -> None:
""""Parse" a list of arguments separated by commas.
Arguments can have "optional" annotations given by enclosing them in
@ -76,7 +75,7 @@ def _pseudo_parse_arglist(signode, arglist):
string literal (e.g. default argument value).
"""
paramlist = addnodes.desc_parameterlist()
stack = [paramlist] # type: List[nodes.Element]
stack = [paramlist] # type: List[Element]
try:
for argument in arglist.split(','):
argument = argument.strip()
@ -119,15 +118,9 @@ def _pseudo_parse_arglist(signode, arglist):
# This override allows our inline type specifiers to behave like :class: link
# when it comes to handling "." and "~" prefixes.
class PyXrefMixin:
def make_xref(self,
rolename, # type: str
domain, # type: str
target, # type: str
innernode=nodes.emphasis, # type: Type[TextlikeNode]
contnode=None, # type: nodes.Node
env=None, # type: BuildEnvironment
):
# type: (...) -> nodes.Node
def make_xref(self, rolename: str, domain: str, target: str,
innernode: Type[TextlikeNode] = nodes.emphasis,
contnode: Node = None, env: BuildEnvironment = None) -> Node:
result = super().make_xref(rolename, domain, target, # type: ignore
innernode, contnode, env)
result['refspecific'] = True
@ -142,15 +135,9 @@ class PyXrefMixin:
break
return result
def make_xrefs(self,
rolename, # type: str
domain, # type: str
target, # type: str
innernode=nodes.emphasis, # type: Type[TextlikeNode]
contnode=None, # type: nodes.Node
env=None, # type: BuildEnvironment
):
# type: (...) -> List[nodes.Node]
def make_xrefs(self, rolename: str, domain: str, target: str,
innernode: Type[TextlikeNode] = nodes.emphasis,
contnode: Node = None, env: BuildEnvironment = None) -> List[Node]:
delims = r'(\s*[\[\]\(\),](?:\s*or\s)?\s*|\s+or\s+)'
delims_re = re.compile(delims)
sub_targets = re.split(delims, target)
@ -172,9 +159,9 @@ class PyXrefMixin:
class PyField(PyXrefMixin, Field):
def make_xref(self, rolename, domain, target,
innernode=nodes.emphasis, contnode=None, env=None):
# type: (str, str, str, Type[TextlikeNode], nodes.Node, BuildEnvironment) -> nodes.Node # NOQA
def make_xref(self, rolename: str, domain: str, target: str,
innernode: Type[TextlikeNode] = nodes.emphasis,
contnode: Node = None, env: BuildEnvironment = None) -> Node:
if rolename == 'class' and target == 'None':
# None is not a type, so use obj role instead.
rolename = 'obj'
@ -187,9 +174,9 @@ class PyGroupedField(PyXrefMixin, GroupedField):
class PyTypedField(PyXrefMixin, TypedField):
def make_xref(self, rolename, domain, target,
innernode=nodes.emphasis, contnode=None, env=None):
# type: (str, str, str, Type[TextlikeNode], nodes.Node, BuildEnvironment) -> nodes.Node # NOQA
def make_xref(self, rolename: str, domain: str, target: str,
innernode: Type[TextlikeNode] = nodes.emphasis,
contnode: Node = None, env: BuildEnvironment = None) -> Node:
if rolename == 'class' and target == 'None':
# None is not a type, so use obj role instead.
rolename = 'obj'
@ -231,22 +218,19 @@ class PyObject(ObjectDescription):
allow_nesting = False
def get_signature_prefix(self, sig):
# type: (str) -> str
def get_signature_prefix(self, sig: str) -> str:
"""May return a prefix to put before the object name in the
signature.
"""
return ''
def needs_arglist(self):
# type: () -> bool
def needs_arglist(self) -> bool:
"""May return true if an empty argument list is to be generated even if
the document contains none.
"""
return False
def handle_signature(self, sig, signode):
# type: (str, addnodes.desc_signature) -> Tuple[str, str]
def handle_signature(self, sig: str, signode: desc_signature) -> Tuple[str, str]:
"""Transform a Python signature into RST nodes.
Return (fully qualified name of the thing, classname if any).
@ -320,13 +304,12 @@ class PyObject(ObjectDescription):
return fullname, prefix
def get_index_text(self, modname, name):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name: Tuple[str, str]) -> str:
"""Return the text for the index entry of the object."""
raise NotImplementedError('must be implemented in subclasses')
def add_target_and_index(self, name_cls, sig, signode):
# type: (Tuple[str, str], str, addnodes.desc_signature) -> None
def add_target_and_index(self, name_cls: Tuple[str, str], sig: str,
signode: desc_signature) -> None:
modname = self.options.get('module', self.env.ref_context.get('py:module'))
fullname = (modname and modname + '.' or '') + name_cls[0]
# note target
@ -345,8 +328,7 @@ class PyObject(ObjectDescription):
self.indexnode['entries'].append(('single', indextext,
fullname, '', None))
def before_content(self):
# type: () -> None
def before_content(self) -> None:
"""Handle object nesting before content
:py:class:`PyObject` represents Python language constructs. For
@ -379,8 +361,7 @@ class PyObject(ObjectDescription):
modules.append(self.env.ref_context.get('py:module'))
self.env.ref_context['py:module'] = self.options['module']
def after_content(self):
# type: () -> None
def after_content(self) -> None:
"""Handle object de-nesting after content
If this class is a nestable object, removing the last nested class prefix
@ -411,19 +392,16 @@ class PyModulelevel(PyObject):
Description of an object on module level (functions, data).
"""
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
warnings.warn('PyClassmember is deprecated.',
RemovedInSphinx40Warning)
return super().run()
def needs_arglist(self):
# type: () -> bool
def needs_arglist(self) -> bool:
return self.objtype == 'function'
def get_index_text(self, modname, name_cls):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name_cls: Tuple[str, str]) -> str:
if self.objtype == 'function':
if not modname:
return _('%s() (built-in function)') % name_cls[0]
@ -444,19 +422,16 @@ class PyFunction(PyObject):
'async': directives.flag,
})
def get_signature_prefix(self, sig):
# type: (str) -> str
def get_signature_prefix(self, sig: str) -> str:
if 'async' in self.options:
return 'async '
else:
return ''
def needs_arglist(self):
# type: () -> bool
def needs_arglist(self) -> bool:
return True
def get_index_text(self, modname, name_cls):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name_cls: Tuple[str, str]) -> str:
name, cls = name_cls
if modname:
return _('%s() (in module %s)') % (name, modname)
@ -467,8 +442,7 @@ class PyFunction(PyObject):
class PyVariable(PyObject):
"""Description of a variable."""
def get_index_text(self, modname, name_cls):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name_cls: Tuple[str, str]) -> str:
name, cls = name_cls
if modname:
return _('%s (in module %s)') % (name, modname)
@ -483,12 +457,10 @@ class PyClasslike(PyObject):
allow_nesting = True
def get_signature_prefix(self, sig):
# type: (str) -> str
def get_signature_prefix(self, sig: str) -> str:
return self.objtype + ' '
def get_index_text(self, modname, name_cls):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name_cls: Tuple[str, str]) -> str:
if self.objtype == 'class':
if not modname:
return _('%s (built-in class)') % name_cls[0]
@ -504,27 +476,23 @@ class PyClassmember(PyObject):
Description of a class member (methods, attributes).
"""
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
warnings.warn('PyClassmember is deprecated.',
RemovedInSphinx40Warning)
return super().run()
def needs_arglist(self):
# type: () -> bool
def needs_arglist(self) -> bool:
return self.objtype.endswith('method')
def get_signature_prefix(self, sig):
# type: (str) -> str
def get_signature_prefix(self, sig: str) -> str:
if self.objtype == 'staticmethod':
return 'static '
elif self.objtype == 'classmethod':
return 'classmethod '
return ''
def get_index_text(self, modname, name_cls):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name_cls: Tuple[str, str]) -> str:
name, cls = name_cls
add_modules = self.env.config.add_module_names
if self.objtype == 'method':
@ -593,15 +561,13 @@ class PyMethod(PyObject):
'staticmethod': directives.flag,
})
def needs_arglist(self):
# type: () -> bool
def needs_arglist(self) -> bool:
if 'property' in self.options:
return False
else:
return True
def get_signature_prefix(self, sig):
# type: (str) -> str
def get_signature_prefix(self, sig: str) -> str:
prefix = []
if 'abstractmethod' in self.options:
prefix.append('abstract')
@ -619,8 +585,7 @@ class PyMethod(PyObject):
else:
return ''
def get_index_text(self, modname, name_cls):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name_cls: Tuple[str, str]) -> str:
name, cls = name_cls
try:
clsname, methname = name.rsplit('.', 1)
@ -647,8 +612,7 @@ class PyClassMethod(PyMethod):
option_spec = PyObject.option_spec.copy()
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
self.name = 'py:method'
self.options['classmethod'] = True
@ -660,8 +624,7 @@ class PyStaticMethod(PyMethod):
option_spec = PyObject.option_spec.copy()
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
self.name = 'py:method'
self.options['staticmethod'] = True
@ -671,8 +634,7 @@ class PyStaticMethod(PyMethod):
class PyAttribute(PyObject):
"""Description of an attribute."""
def get_index_text(self, modname, name_cls):
# type: (str, Tuple[str, str]) -> str
def get_index_text(self, modname: str, name_cls: Tuple[str, str]) -> str:
name, cls = name_cls
try:
clsname, attrname = name.rsplit('.', 1)
@ -691,14 +653,12 @@ class PyDecoratorMixin:
"""
Mixin for decorator directives.
"""
def handle_signature(self, sig, signode):
# type: (str, addnodes.desc_signature) -> Tuple[str, str]
def handle_signature(self, sig: str, signode: desc_signature) -> Tuple[str, str]:
ret = super().handle_signature(sig, signode) # type: ignore
signode.insert(0, addnodes.desc_addname('@', '@'))
return ret
def needs_arglist(self):
# type: () -> bool
def needs_arglist(self) -> bool:
return False
@ -706,8 +666,7 @@ class PyDecoratorFunction(PyDecoratorMixin, PyModulelevel):
"""
Directive to mark functions meant to be used as decorators.
"""
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
# a decorator function is a function after all
self.name = 'py:function'
return super().run()
@ -717,8 +676,7 @@ class PyDecoratorMethod(PyDecoratorMixin, PyClassmember):
"""
Directive to mark methods meant to be used as decorators.
"""
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
self.name = 'py:method'
return super().run()
@ -739,14 +697,13 @@ class PyModule(SphinxDirective):
'deprecated': directives.flag,
}
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
domain = cast(PythonDomain, self.env.get_domain('py'))
modname = self.arguments[0].strip()
noindex = 'noindex' in self.options
self.env.ref_context['py:module'] = modname
ret = [] # type: List[nodes.Node]
ret = [] # type: List[Node]
if not noindex:
# note module to the domain
domain.note_module(modname,
@ -780,8 +737,7 @@ class PyCurrentModule(SphinxDirective):
final_argument_whitespace = False
option_spec = {} # type: Dict
def run(self):
# type: () -> List[nodes.Node]
def run(self) -> List[Node]:
modname = self.arguments[0].strip()
if modname == 'None':
self.env.ref_context.pop('py:module', None)
@ -791,8 +747,8 @@ class PyCurrentModule(SphinxDirective):
class PyXRefRole(XRefRole):
def process_link(self, env, refnode, has_explicit_title, title, target):
# type: (BuildEnvironment, nodes.Element, bool, str, str) -> Tuple[str, str]
def process_link(self, env: BuildEnvironment, refnode: Element,
has_explicit_title: bool, title: str, target: str) -> Tuple[str, str]:
refnode['py:module'] = env.ref_context.get('py:module')
refnode['py:class'] = env.ref_context.get('py:class')
if not has_explicit_title:
@ -822,8 +778,8 @@ class PythonModuleIndex(Index):
localname = _('Python Module Index')
shortname = _('modules')
def generate(self, docnames=None):
# type: (Iterable[str]) -> Tuple[List[Tuple[str, List[IndexEntry]]], bool]
def generate(self, docnames: Iterable[str] = None
) -> Tuple[List[Tuple[str, List[IndexEntry]]], bool]:
content = {} # type: Dict[str, List[IndexEntry]]
# list of prefixes to ignore
ignores = None # type: List[str]
@ -937,12 +893,10 @@ class PythonDomain(Domain):
]
@property
def objects(self):
# type: () -> Dict[str, Tuple[str, str]]
def objects(self) -> Dict[str, Tuple[str, str]]:
return self.data.setdefault('objects', {}) # fullname -> docname, objtype
def note_object(self, name, objtype, location=None):
# type: (str, str, Any) -> None
def note_object(self, name: str, objtype: str, location: Any = None) -> None:
"""Note a python object for cross reference.
.. versionadded:: 2.1
@ -955,20 +909,17 @@ class PythonDomain(Domain):
self.objects[name] = (self.env.docname, objtype)
@property
def modules(self):
# type: () -> Dict[str, Tuple[str, str, str, bool]]
def modules(self) -> Dict[str, Tuple[str, str, str, bool]]:
return self.data.setdefault('modules', {}) # modname -> docname, synopsis, platform, deprecated # NOQA
def note_module(self, name, synopsis, platform, deprecated):
# type: (str, str, str, bool) -> None
def note_module(self, name: str, synopsis: str, platform: str, deprecated: bool) -> None:
"""Note a python module for cross reference.
.. versionadded:: 2.1
"""
self.modules[name] = (self.env.docname, synopsis, platform, deprecated)
def clear_doc(self, docname):
# type: (str) -> None
def clear_doc(self, docname: str) -> None:
for fullname, (fn, _l) in list(self.objects.items()):
if fn == docname:
del self.objects[fullname]
@ -976,8 +927,7 @@ class PythonDomain(Domain):
if fn == docname:
del self.modules[modname]
def merge_domaindata(self, docnames, otherdata):
# type: (List[str], Dict) -> None
def merge_domaindata(self, docnames: List[str], otherdata: Dict) -> None:
# XXX check duplicates?
for fullname, (fn, objtype) in otherdata['objects'].items():
if fn in docnames:
@ -986,8 +936,8 @@ class PythonDomain(Domain):
if data[0] in docnames:
self.modules[modname] = data
def find_obj(self, env, modname, classname, name, type, searchmode=0):
# type: (BuildEnvironment, str, str, str, str, int) -> List[Tuple[str, Any]]
def find_obj(self, env: BuildEnvironment, modname: str, classname: str,
name: str, type: str, searchmode: int = 0) -> List[Tuple[str, Any]]:
"""Find a Python object for "name", perhaps using the given module
and/or classname. Returns a list of (name, object entry) tuples.
"""
@ -1049,9 +999,9 @@ class PythonDomain(Domain):
matches.append((newname, self.objects[newname]))
return matches
def resolve_xref(self, env, fromdocname, builder,
type, target, node, contnode):
# type: (BuildEnvironment, str, Builder, str, str, addnodes.pending_xref, nodes.Element) -> nodes.Element # NOQA
def resolve_xref(self, env: BuildEnvironment, fromdocname: str, builder: Builder,
type: str, target: str, node: pending_xref, contnode: Element
) -> Element:
modname = node.get('py:module')
clsname = node.get('py:class')
searchmode = node.hasattr('refspecific') and 1 or 0
@ -1070,12 +1020,12 @@ class PythonDomain(Domain):
else:
return make_refnode(builder, fromdocname, obj[0], name, contnode, name)
def resolve_any_xref(self, env, fromdocname, builder, target,
node, contnode):
# type: (BuildEnvironment, str, Builder, str, addnodes.pending_xref, nodes.Element) -> List[Tuple[str, nodes.Element]] # NOQA
def resolve_any_xref(self, env: BuildEnvironment, fromdocname: str, builder: Builder,
target: str, node: pending_xref, contnode: Element
) -> List[Tuple[str, Element]]:
modname = node.get('py:module')
clsname = node.get('py:class')
results = [] # type: List[Tuple[str, nodes.Element]]
results = [] # type: List[Tuple[str, Element]]
# always search in "refspecific" mode with the :any: role
matches = self.find_obj(env, modname, clsname, target, None, 1)
@ -1090,8 +1040,8 @@ class PythonDomain(Domain):
contnode, name)))
return results
def _make_module_refnode(self, builder, fromdocname, name, contnode):
# type: (Builder, str, str, nodes.Node) -> nodes.Element
def _make_module_refnode(self, builder: Builder, fromdocname: str, name: str,
contnode: Node) -> Element:
# get additional info for modules
docname, synopsis, platform, deprecated = self.modules[name]
title = name
@ -1104,16 +1054,14 @@ class PythonDomain(Domain):
return make_refnode(builder, fromdocname, docname,
'module-' + name, contnode, title)
def get_objects(self):
# type: () -> Iterator[Tuple[str, str, str, str, str, int]]
def get_objects(self) -> Iterator[Tuple[str, str, str, str, str, int]]:
for modname, info in self.modules.items():
yield (modname, modname, 'module', info[0], 'module-' + modname, 0)
for refname, (docname, type) in self.objects.items():
if type != 'module': # modules are already handled
yield (refname, refname, type, docname, refname, 1)
def get_full_qualified_name(self, node):
# type: (nodes.Element) -> str
def get_full_qualified_name(self, node: Element) -> str:
modname = node.get('py:module')
clsname = node.get('py:class')
target = node.get('reftarget')
@ -1123,8 +1071,7 @@ class PythonDomain(Domain):
return '.'.join(filter(None, [modname, clsname, target]))
def setup(app):
# type: (Sphinx) -> Dict[str, Any]
def setup(app: Sphinx) -> Dict[str, Any]:
app.add_domain(PythonDomain)
return {