diff --git a/sphinx/util/nodes.py b/sphinx/util/nodes.py index 7e4dba01e..d3e735f6a 100644 --- a/sphinx/util/nodes.py +++ b/sphinx/util/nodes.py @@ -11,6 +11,7 @@ from __future__ import absolute_import import re +from typing import Any from docutils import nodes from six import text_type @@ -33,6 +34,57 @@ explicit_title_re = re.compile(r'^(.+?)\s*(?$', re.DOTALL) caption_ref_re = explicit_title_re # b/w compat alias +class NodeMatcher(object): + """A helper class for Node.traverse(). + + It checks that given node is an instance of specified node-classes and it has + specified node-attributes. + + For example, following example searches ``reference`` node having ``refdomain`` + and ``reftype`` attributes:: + + matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation') + doctree.traverse(matcher) + # => [, , ...] + + A special value ``typing.Any`` matches any kind of node-attributes. For example, + following example searches ``reference`` node having ``refdomain`` attributes:: + + from typing import Any + matcher = NodeMatcher(nodes.reference, refdomain=Any) + doctree.traverse(matcher) + # => [, , ...] + """ + + def __init__(self, *classes, **attrs): + # type: (nodes.Node, Any) -> None + self.classes = classes + self.attrs = attrs + + def match(self, node): + # type: (nodes.Node) -> bool + try: + if self.classes and not isinstance(node, self.classes): + return False + + for key, value in self.attrs.items(): + if key not in node: + return False + elif value is Any: + continue + elif node.get(key) != value: + return False + else: + return True + except Exception: + # for non-Element nodes + return False + + def __call__(self, node): + # type: (nodes.Node) -> bool + return self.match(node) + + def get_full_module_name(node): # type: (nodes.Node) -> str """ diff --git a/tests/test_util_nodes.py b/tests/test_util_nodes.py index d20b4b892..2fab10c1c 100644 --- a/tests/test_util_nodes.py +++ b/tests/test_util_nodes.py @@ -9,6 +9,7 @@ :license: BSD, see LICENSE for details. """ from textwrap import dedent +from typing import Any import pytest from docutils import frontend @@ -17,7 +18,7 @@ from docutils.parsers import rst from docutils.utils import new_document from sphinx.transforms import ApplySourceWorkaround -from sphinx.util.nodes import extract_messages, clean_astext +from sphinx.util.nodes import NodeMatcher, extract_messages, clean_astext def _transform(doctree): @@ -50,6 +51,38 @@ def assert_node_count(messages, node_type, expect_count): % (node_type, node_list, count, expect_count)) +def test_NodeMatcher(): + doctree = nodes.document(None, None) + doctree += nodes.paragraph('', 'Hello') + doctree += nodes.paragraph('', 'Sphinx', block=1) + doctree += nodes.paragraph('', 'World', block=2) + doctree += nodes.literal_block('', 'blah blah blah', block=3) + + # search by node class + matcher = NodeMatcher(nodes.paragraph) + assert len(doctree.traverse(matcher)) == 3 + + # search by multiple node classes + matcher = NodeMatcher(nodes.paragraph, nodes.literal_block) + assert len(doctree.traverse(matcher)) == 4 + + # search by node attribute + matcher = NodeMatcher(block=1) + assert len(doctree.traverse(matcher)) == 1 + + # search by node attribute (Any) + matcher = NodeMatcher(block=Any) + assert len(doctree.traverse(matcher)) == 3 + + # search by both class and attribute + matcher = NodeMatcher(nodes.paragraph, block=Any) + assert len(doctree.traverse(matcher)) == 2 + + # mismatched + matcher = NodeMatcher(nodes.title) + assert len(doctree.traverse(matcher)) == 0 + + @pytest.mark.parametrize( 'rst,node_cls,count', [