diff --git a/sphinx/directives/__init__.py b/sphinx/directives/__init__.py index c54c6ed4d..dc7989311 100644 --- a/sphinx/directives/__init__.py +++ b/sphinx/directives/__init__.py @@ -36,7 +36,7 @@ if False: from sphinx.application import Sphinx # NOQA from sphinx.config import Config # NOQA from sphinx.environment import BuildEnvironment # NOQA - from sphinx.util.typing import unicode # NOQA + from sphinx.util.typing import N_co, unicode # NOQA # RE to strip backslash escapes @@ -116,7 +116,7 @@ class ObjectDescription(SphinxDirective): pass def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] """ Main directive entry function, called by docutils upon encountering the directive. @@ -198,7 +198,7 @@ class DefaultRole(SphinxDirective): final_argument_whitespace = False def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] if not self.arguments: if '' in roles._roles: # restore the "default" default role @@ -230,7 +230,7 @@ class DefaultDomain(SphinxDirective): option_spec = {} # type: Dict def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] domain_name = self.arguments[0].lower() # if domain_name not in env.domains: # # try searching by label diff --git a/sphinx/directives/code.py b/sphinx/directives/code.py index 19363155a..4cbcd0862 100644 --- a/sphinx/directives/code.py +++ b/sphinx/directives/code.py @@ -29,7 +29,7 @@ if False: from typing import Any, Dict, List, Tuple # NOQA from sphinx.application import Sphinx # NOQA from sphinx.config import Config # NOQA - from sphinx.util.typing import unicode # NOQA + from sphinx.util.typing import N_co, unicode # NOQA logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ class Highlight(SphinxDirective): } def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] linenothreshold = self.options.get('linenothreshold', sys.maxsize) return [addnodes.highlightlang(lang=self.arguments[0].strip(), linenothreshold=linenothreshold)] @@ -59,7 +59,7 @@ class HighlightLang(Highlight): """highlightlang directive (deprecated)""" def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] warnings.warn('highlightlang directive is deprecated. ' 'Please use highlight directive instead.', RemovedInSphinx40Warning, stacklevel=2) @@ -94,13 +94,16 @@ def container_wrapper(directive, literal_node, caption): if isinstance(parsed[0], nodes.system_message): msg = __('Invalid caption: %s' % parsed[0].astext()) raise ValueError(msg) - caption_node = nodes.caption(parsed[0].rawsource, '', - *parsed[0].children) - caption_node.source = literal_node.source - caption_node.line = literal_node.line - container_node += caption_node - container_node += literal_node - return container_node + elif isinstance(parsed[0], nodes.Element): + caption_node = nodes.caption(parsed[0].rawsource, '', + *parsed[0].children) + caption_node.source = literal_node.source + caption_node.line = literal_node.line + container_node += caption_node + container_node += literal_node + return container_node + else: + raise RuntimeError # never reached class CodeBlock(SphinxDirective): @@ -124,7 +127,7 @@ class CodeBlock(SphinxDirective): } def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] document = self.state.document code = u'\n'.join(self.content) location = self.state_machine.get_source_and_line(self.lineno) @@ -151,7 +154,7 @@ class CodeBlock(SphinxDirective): lines = dedent_lines(lines, self.options['dedent'], location=location) code = '\n'.join(lines) - literal = nodes.literal_block(code, code) + literal = nodes.literal_block(code, code) # type: nodes.Element literal['language'] = self.arguments[0] literal['linenos'] = 'linenos' in self.options or \ 'lineno-start' in self.options @@ -416,7 +419,7 @@ class LiteralInclude(SphinxDirective): } def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] document = self.state.document if not document.settings.file_insertion_enabled: return [document.reporter.warning('File insertion disabled', @@ -434,7 +437,7 @@ class LiteralInclude(SphinxDirective): reader = LiteralIncludeReader(filename, self.options, self.config) text, lines = reader.read(location=location) - retnode = nodes.literal_block(text, text, source=filename) + retnode = nodes.literal_block(text, text, source=filename) # type: nodes.Element set_source_info(self, retnode) if self.options.get('diff'): # if diff is set, set udiff retnode['language'] = 'udiff' diff --git a/sphinx/directives/other.py b/sphinx/directives/other.py index dbc1190b5..53f008c2b 100644 --- a/sphinx/directives/other.py +++ b/sphinx/directives/other.py @@ -28,7 +28,7 @@ if False: # For type annotation from typing import Any, Dict, Generator, List, Tuple # NOQA from sphinx.application import Sphinx # NOQA - from sphinx.util.typing import unicode # NOQA + from sphinx.util.typing import N_co, unicode # NOQA glob_re = re.compile(r'.*[*?\[].*') @@ -63,7 +63,7 @@ class TocTree(SphinxDirective): } def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] subnode = addnodes.toctree() subnode['parent'] = self.env.docname @@ -163,10 +163,10 @@ class Author(SphinxDirective): option_spec = {} # type: Dict def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] if not self.config.show_authors: return [] - para = nodes.paragraph(translatable=False) # type: nodes.Node + para = nodes.paragraph(translatable=False) # type: nodes.Element emph = nodes.emphasis() para += emph if self.name == 'sectionauthor': @@ -195,7 +195,7 @@ class Index(SphinxDirective): option_spec = {} # type: Dict def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] arguments = self.arguments[0].split('\n') targetid = 'index-%s' % self.env.new_serialno('index') targetnode = nodes.target('', '', ids=[targetid]) @@ -227,7 +227,7 @@ class TabularColumns(SphinxDirective): option_spec = {} # type: Dict def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] node = addnodes.tabular_col_spec() node['spec'] = self.arguments[0] set_source_info(self, node) @@ -245,10 +245,10 @@ class Centered(SphinxDirective): option_spec = {} # type: Dict def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] if not self.arguments: return [] - subnode = addnodes.centered() # type: nodes.Node + subnode = addnodes.centered() # type: nodes.Element inodes, messages = self.state.inline_text(self.arguments[0], self.lineno) subnode.extend(inodes) @@ -266,7 +266,7 @@ class Acks(SphinxDirective): option_spec = {} # type: Dict def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] node = addnodes.acks() node.document = self.state.document self.state.nested_parse(self.content, self.content_offset, node) @@ -290,7 +290,7 @@ class HList(SphinxDirective): } def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] ncolumns = self.options.get('columns', 2) node = nodes.paragraph() node.document = self.state.document @@ -325,7 +325,7 @@ class Only(SphinxDirective): option_spec = {} # type: Dict def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] node = addnodes.only() node.document = self.state.document set_source_info(self, node) @@ -379,7 +379,7 @@ class Include(BaseInclude, SphinxDirective): """ def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] if self.arguments[0].startswith('<') and \ self.arguments[0].endswith('>'): # docutils "standard" includes, do not do path processing diff --git a/sphinx/directives/patches.py b/sphinx/directives/patches.py index b1104fc97..ee423fab8 100644 --- a/sphinx/directives/patches.py +++ b/sphinx/directives/patches.py @@ -22,7 +22,7 @@ if False: # For type annotation from typing import Dict, List, Tuple # NOQA from sphinx.application import Sphinx # NOQA - from sphinx.util.typing import unicode # NOQA + from sphinx.util.typing import N_co, unicode # NOQA class Figure(images.Figure): @@ -31,9 +31,9 @@ class Figure(images.Figure): """ def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] name = self.options.pop('name', None) - result = super(Figure, self).run() + result = super(Figure, self).run() # type: List[nodes.Node] if len(result) == 2 or isinstance(result[0], nodes.system_message): return result @@ -54,8 +54,8 @@ class Figure(images.Figure): class Meta(html.Meta, SphinxDirective): def run(self): - # type: () -> List[nodes.Node] - result = super(Meta, self).run() + # type: () -> List[N_co] + result = super(Meta, self).run() # type: List[nodes.Node] for node in result: if (isinstance(node, nodes.pending) and isinstance(node.details['nodes'][0], html.MetaBody.meta)): @@ -124,7 +124,7 @@ class MathDirective(SphinxDirective): } def run(self): - # type: () -> List[nodes.Node] + # type: () -> List[N_co] latex = '\n'.join(self.content) if self.arguments and self.arguments[0]: latex = self.arguments[0] + '\n\n' + latex diff --git a/sphinx/util/typing.py b/sphinx/util/typing.py index 2ac03c916..e0cf0104b 100644 --- a/sphinx/util/typing.py +++ b/sphinx/util/typing.py @@ -9,7 +9,7 @@ :license: BSD, see LICENSE for details. """ -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Tuple, TypeVar from docutils import nodes from docutils.parsers.rst.states import Inliner @@ -24,6 +24,9 @@ else: unicode = str +N_co = TypeVar('N_co', bound=nodes.Node, covariant=True) + + # common role functions RoleFunction = Callable[[text_type, text_type, text_type, int, Inliner, Dict, List[text_type]], Tuple[List[nodes.Node], List[nodes.system_message]]]