diff --git a/sphinx/builders/devhelp.py b/sphinx/builders/devhelp.py index fc2c0b1c9..f81154984 100644 --- a/sphinx/builders/devhelp.py +++ b/sphinx/builders/devhelp.py @@ -15,6 +15,7 @@ from __future__ import absolute_import import gzip import re from os import path +from typing import Any from docutils import nodes @@ -23,6 +24,7 @@ from sphinx.builders.html import StandaloneHTMLBuilder from sphinx.environment.adapters.indexentries import IndexEntries from sphinx.locale import __ from sphinx.util import logging +from sphinx.util.nodes import NodeMatcher from sphinx.util.osutil import make_filename try: @@ -32,7 +34,7 @@ except ImportError: if False: # For type annotation - from typing import Any, Dict, List # NOQA + from typing import Dict, List # NOQA from sphinx.application import Sphinx # NOQA @@ -100,12 +102,8 @@ class DevhelpBuilder(StandaloneHTMLBuilder): parent.attrib['link'] = node['refuri'] parent.attrib['name'] = node.astext() - def istoctree(node): - # type: (nodes.Node) -> bool - return isinstance(node, addnodes.compact_paragraph) and \ - 'toctree' in node - - for node in tocdoc.traverse(istoctree): + matcher = NodeMatcher(addnodes.compact_paragraph, toctree=Any) + for node in tocdoc.traverse(matcher): write_toc(node, chapters) # Index diff --git a/sphinx/builders/latex/transforms.py b/sphinx/builders/latex/transforms.py index 160c8c324..afc580f9d 100644 --- a/sphinx/builders/latex/transforms.py +++ b/sphinx/builders/latex/transforms.py @@ -16,6 +16,7 @@ from sphinx.builders.latex.nodes import ( captioned_literal_block, footnotemark, footnotetext, math_reference, thebibliography ) from sphinx.transforms import SphinxTransform +from sphinx.util.nodes import NodeMatcher if False: # For type annotation @@ -30,7 +31,7 @@ class FootnoteDocnameUpdater(SphinxTransform): TARGET_NODES = (nodes.footnote, nodes.footnote_reference) def apply(self): - for node in self.document.traverse(lambda n: isinstance(n, self.TARGET_NODES)): + for node in self.document.traverse(NodeMatcher(*self.TARGET_NODES)): node['docname'] = self.env.docname @@ -536,14 +537,14 @@ class CitationReferenceTransform(SphinxTransform): if self.app.builder.name != 'latex': return + matcher = NodeMatcher(addnodes.pending_xref, refdomain='std', reftype='citation') citations = self.env.get_domain('std').data['citations'] - for node in self.document.traverse(addnodes.pending_xref): - if node['refdomain'] == 'std' and node['reftype'] == 'citation': - docname, labelid, _ = citations.get(node['reftarget'], ('', '', 0)) - if docname: - citation_ref = nodes.citation_reference('', *node.children, - docname=docname, refname=labelid) - node.replace_self(citation_ref) + for node in self.document.traverse(matcher): + docname, labelid, _ = citations.get(node['reftarget'], ('', '', 0)) + if docname: + citation_ref = nodes.citation_reference('', *node.children, + docname=docname, refname=labelid) + node.replace_self(citation_ref) class MathReferenceTransform(SphinxTransform): @@ -577,10 +578,10 @@ class LiteralBlockTransform(SphinxTransform): if self.app.builder.name != 'latex': return - for node in self.document.traverse(nodes.container): - if node.get('literal_block') is True: - newnode = captioned_literal_block('', *node.children, **node.attributes) - node.replace_self(newnode) + matcher = NodeMatcher(nodes.container, literal_block=True) + for node in self.document.traverse(matcher): + newnode = captioned_literal_block('', *node.children, **node.attributes) + node.replace_self(newnode) class DocumentTargetTransform(SphinxTransform): diff --git a/sphinx/transforms/i18n.py b/sphinx/transforms/i18n.py index f49e27df3..727a12be6 100644 --- a/sphinx/transforms/i18n.py +++ b/sphinx/transforms/i18n.py @@ -10,6 +10,7 @@ """ from os import path +from typing import Any from docutils import nodes from docutils.io import StringInput @@ -22,14 +23,14 @@ from sphinx.transforms import SphinxTransform from sphinx.util import split_index_msg, logging from sphinx.util.i18n import find_catalog from sphinx.util.nodes import ( - LITERAL_TYPE_NODES, IMAGE_TYPE_NODES, + LITERAL_TYPE_NODES, IMAGE_TYPE_NODES, NodeMatcher, extract_messages, is_pending_meta, traverse_translatable_index, ) from sphinx.util.pycompat import indent if False: # For type annotation - from typing import Any, Dict, List, Tuple # NOQA + from typing import Dict, List, Tuple # NOQA from sphinx.application import Sphinx # NOQA from sphinx.config import Config # NOQA @@ -183,11 +184,8 @@ class Locale(SphinxTransform): self.document.note_implicit_target(section_node) # replace target's refname to new target name - def is_named_target(node): - # type: (nodes.Node) -> bool - return isinstance(node, nodes.target) and \ - node.get('refname') == old_name - for old_target in self.document.traverse(is_named_target): + matcher = NodeMatcher(nodes.target, refname=old_name) + for old_target in self.document.traverse(matcher): old_target['refname'] = new_name processed = True @@ -276,16 +274,14 @@ class Locale(SphinxTransform): continue # skip # auto-numbered foot note reference should use original 'ids'. - def is_autofootnote_ref(node): - # type: (nodes.Node) -> bool - return isinstance(node, nodes.footnote_reference) and node.get('auto') - def list_replace_or_append(lst, old, new): # type: (List, Any, Any) -> None if old in lst: lst[lst.index(old)] = new else: lst.append(new) + + is_autofootnote_ref = NodeMatcher(nodes.footnote_reference, auto=Any) old_foot_refs = node.traverse(is_autofootnote_ref) new_foot_refs = patch.traverse(is_autofootnote_ref) if len(old_foot_refs) != len(new_foot_refs): @@ -328,10 +324,7 @@ class Locale(SphinxTransform): # * reference target ".. _Python: ..." is not translatable. # * use translated refname for section refname. # * inline reference "`Python <...>`_" has no 'refname'. - def is_refnamed_ref(node): - # type: (nodes.Node) -> bool - return isinstance(node, nodes.reference) and \ - 'refname' in node + is_refnamed_ref = NodeMatcher(nodes.reference, refname=Any) old_refs = node.traverse(is_refnamed_ref) new_refs = patch.traverse(is_refnamed_ref) if len(old_refs) != len(new_refs): @@ -358,10 +351,7 @@ class Locale(SphinxTransform): self.document.note_refname(new) # refnamed footnote should use original 'ids'. - def is_refnamed_footnote_ref(node): - # type: (nodes.Node) -> bool - return isinstance(node, nodes.footnote_reference) and \ - 'refname' in node + is_refnamed_footnote_ref = NodeMatcher(nodes.footnote_reference, refname=Any) old_foot_refs = node.traverse(is_refnamed_footnote_ref) new_foot_refs = patch.traverse(is_refnamed_footnote_ref) refname_ids_map = {} @@ -380,10 +370,7 @@ class Locale(SphinxTransform): new["ids"] = refname_ids_map[refname] # citation should use original 'ids'. - def is_citation_ref(node): - # type: (nodes.Node) -> bool - return isinstance(node, nodes.citation_reference) and \ - 'refname' in node + is_citation_ref = NodeMatcher(nodes.citation_reference, refname=Any) old_cite_refs = node.traverse(is_citation_ref) new_cite_refs = patch.traverse(is_citation_ref) refname_ids_map = {} @@ -474,10 +461,7 @@ class Locale(SphinxTransform): node['entries'] = new_entries # remove translated attribute that is used for avoiding double translation. - def has_translatable(node): - # type: (nodes.Node) -> bool - return isinstance(node, nodes.Element) and 'translated' in node - for node in self.document.traverse(has_translatable): + for node in self.document.traverse(NodeMatcher(translated=Any)): node.delattr('translated') @@ -492,7 +476,8 @@ class RemoveTranslatableInline(SphinxTransform): from sphinx.builders.gettext import MessageCatalogBuilder if isinstance(self.app.builder, MessageCatalogBuilder): return - for inline in self.document.traverse(nodes.inline): - if 'translatable' in inline: - inline.parent.remove(inline) - inline.parent += inline.children + + matcher = NodeMatcher(nodes.inline, translatable=Any) + for inline in self.document.traverse(matcher): + inline.parent.remove(inline) + inline.parent += inline.children diff --git a/sphinx/util/nodes.py b/sphinx/util/nodes.py index 7e4dba01e..9d500de76 100644 --- a/sphinx/util/nodes.py +++ b/sphinx/util/nodes.py @@ -11,6 +11,7 @@ from __future__ import absolute_import import re +from typing import Any from docutils import nodes from six import text_type @@ -33,6 +34,57 @@ explicit_title_re = re.compile(r'^(.+?)\s*(?$', re.DOTALL) caption_ref_re = explicit_title_re # b/w compat alias +class NodeMatcher(object): + """A helper class for Node.traverse(). + + It checks that given node is an instance of specified node-classes and it has + specified node-attributes. + + For example, following example searches ``reference`` node having ``refdomain`` + and ``reftype`` attributes:: + + matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation') + doctree.traverse(matcher) + # => [, , ...] + + A special value ``typing.Any`` matches any kind of node-attributes. For example, + following example searches ``reference`` node having ``refdomain`` attributes:: + + from typing import Any + matcher = NodeMatcher(nodes.reference, refdomain=Any) + doctree.traverse(matcher) + # => [, , ...] + """ + + def __init__(self, *classes, **attrs): + # type: (nodes.Node, Any) -> None + self.classes = classes + self.attrs = attrs + + def match(self, node): + # type: (nodes.Node) -> bool + try: + if self.classes and not isinstance(node, self.classes): + return False + + for key, value in self.attrs.items(): + if key not in node: + return False + elif value is Any: + continue + elif node.get(key) != value: + return False + else: + return True + except Exception: + # for non-Element nodes + return False + + def __call__(self, node): + # type: (nodes.Node) -> bool + return self.match(node) + + def get_full_module_name(node): # type: (nodes.Node) -> str """ @@ -241,11 +293,7 @@ def traverse_parent(node, cls=None): def traverse_translatable_index(doctree): # type: (nodes.Node) -> Iterable[Tuple[nodes.Node, List[unicode]]] """Traverse translatable index node from a document tree.""" - def is_block_index(node): - # type: (nodes.Node) -> bool - return isinstance(node, addnodes.index) and \ - node.get('inline') is False - for node in doctree.traverse(is_block_index): + for node in doctree.traverse(NodeMatcher(addnodes.index, inline=False)): if 'raw_entries' in node: entries = node['raw_entries'] else: diff --git a/sphinx/writers/manpage.py b/sphinx/writers/manpage.py index 45a800533..80d0d820c 100644 --- a/sphinx/writers/manpage.py +++ b/sphinx/writers/manpage.py @@ -21,6 +21,7 @@ from sphinx import addnodes from sphinx.locale import admonitionlabels, _ from sphinx.util import logging from sphinx.util.i18n import format_date +from sphinx.util.nodes import NodeMatcher if False: # For type annotation @@ -63,16 +64,13 @@ class NestedInlineTransform(object): def apply(self): # type: () -> None - def is_inline(node): - # type: (nodes.Node) -> bool - return isinstance(node, (nodes.literal, nodes.emphasis, nodes.strong)) - - for node in self.document.traverse(is_inline): - if any(is_inline(subnode) for subnode in node): + matcher = NodeMatcher(nodes.literal, nodes.emphasis, nodes.strong) + for node in self.document.traverse(matcher): + if any(matcher(subnode) for subnode in node): pos = node.parent.index(node) for subnode in reversed(node[1:]): node.remove(subnode) - if is_inline(subnode): + if matcher(subnode): node.parent.insert(pos + 1, subnode) else: newnode = node.__class__('', subnode, **node.attributes) diff --git a/tests/test_util_nodes.py b/tests/test_util_nodes.py index d20b4b892..2fab10c1c 100644 --- a/tests/test_util_nodes.py +++ b/tests/test_util_nodes.py @@ -9,6 +9,7 @@ :license: BSD, see LICENSE for details. """ from textwrap import dedent +from typing import Any import pytest from docutils import frontend @@ -17,7 +18,7 @@ from docutils.parsers import rst from docutils.utils import new_document from sphinx.transforms import ApplySourceWorkaround -from sphinx.util.nodes import extract_messages, clean_astext +from sphinx.util.nodes import NodeMatcher, extract_messages, clean_astext def _transform(doctree): @@ -50,6 +51,38 @@ def assert_node_count(messages, node_type, expect_count): % (node_type, node_list, count, expect_count)) +def test_NodeMatcher(): + doctree = nodes.document(None, None) + doctree += nodes.paragraph('', 'Hello') + doctree += nodes.paragraph('', 'Sphinx', block=1) + doctree += nodes.paragraph('', 'World', block=2) + doctree += nodes.literal_block('', 'blah blah blah', block=3) + + # search by node class + matcher = NodeMatcher(nodes.paragraph) + assert len(doctree.traverse(matcher)) == 3 + + # search by multiple node classes + matcher = NodeMatcher(nodes.paragraph, nodes.literal_block) + assert len(doctree.traverse(matcher)) == 4 + + # search by node attribute + matcher = NodeMatcher(block=1) + assert len(doctree.traverse(matcher)) == 1 + + # search by node attribute (Any) + matcher = NodeMatcher(block=Any) + assert len(doctree.traverse(matcher)) == 3 + + # search by both class and attribute + matcher = NodeMatcher(nodes.paragraph, block=Any) + assert len(doctree.traverse(matcher)) == 2 + + # mismatched + matcher = NodeMatcher(nodes.title) + assert len(doctree.traverse(matcher)) == 0 + + @pytest.mark.parametrize( 'rst,node_cls,count', [