Add NodeMatcher; a helper class for Node.traverse()

This commit is contained in:
Takeshi KOMIYA 2018-08-25 16:17:44 +09:00
parent 6e8113da36
commit 9f7afa161e
2 changed files with 86 additions and 1 deletions

View File

@ -11,6 +11,7 @@
from __future__ import absolute_import
import re
from typing import Any
from docutils import nodes
from six import text_type
@ -33,6 +34,57 @@ explicit_title_re = re.compile(r'^(.+?)\s*(?<!\x00)<(.*?)>$', re.DOTALL)
caption_ref_re = explicit_title_re # b/w compat alias
class NodeMatcher(object):
"""A helper class for Node.traverse().
It checks that given node is an instance of specified node-classes and it has
specified node-attributes.
For example, following example searches ``reference`` node having ``refdomain``
and ``reftype`` attributes::
matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation')
doctree.traverse(matcher)
# => [<reference ...>, <reference ...>, ...]
A special value ``typing.Any`` matches any kind of node-attributes. For example,
following example searches ``reference`` node having ``refdomain`` attributes::
from typing import Any
matcher = NodeMatcher(nodes.reference, refdomain=Any)
doctree.traverse(matcher)
# => [<reference ...>, <reference ...>, ...]
"""
def __init__(self, *classes, **attrs):
# type: (nodes.Node, Any) -> None
self.classes = classes
self.attrs = attrs
def match(self, node):
# type: (nodes.Node) -> bool
try:
if self.classes and not isinstance(node, self.classes):
return False
for key, value in self.attrs.items():
if key not in node:
return False
elif value is Any:
continue
elif node.get(key) != value:
return False
else:
return True
except Exception:
# for non-Element nodes
return False
def __call__(self, node):
# type: (nodes.Node) -> bool
return self.match(node)
def get_full_module_name(node):
# type: (nodes.Node) -> str
"""

View File

@ -9,6 +9,7 @@
:license: BSD, see LICENSE for details.
"""
from textwrap import dedent
from typing import Any
import pytest
from docutils import frontend
@ -17,7 +18,7 @@ from docutils.parsers import rst
from docutils.utils import new_document
from sphinx.transforms import ApplySourceWorkaround
from sphinx.util.nodes import extract_messages, clean_astext
from sphinx.util.nodes import NodeMatcher, extract_messages, clean_astext
def _transform(doctree):
@ -50,6 +51,38 @@ def assert_node_count(messages, node_type, expect_count):
% (node_type, node_list, count, expect_count))
def test_NodeMatcher():
doctree = nodes.document(None, None)
doctree += nodes.paragraph('', 'Hello')
doctree += nodes.paragraph('', 'Sphinx', block=1)
doctree += nodes.paragraph('', 'World', block=2)
doctree += nodes.literal_block('', 'blah blah blah', block=3)
# search by node class
matcher = NodeMatcher(nodes.paragraph)
assert len(doctree.traverse(matcher)) == 3
# search by multiple node classes
matcher = NodeMatcher(nodes.paragraph, nodes.literal_block)
assert len(doctree.traverse(matcher)) == 4
# search by node attribute
matcher = NodeMatcher(block=1)
assert len(doctree.traverse(matcher)) == 1
# search by node attribute (Any)
matcher = NodeMatcher(block=Any)
assert len(doctree.traverse(matcher)) == 3
# search by both class and attribute
matcher = NodeMatcher(nodes.paragraph, block=Any)
assert len(doctree.traverse(matcher)) == 2
# mismatched
matcher = NodeMatcher(nodes.title)
assert len(doctree.traverse(matcher)) == 0
@pytest.mark.parametrize(
'rst,node_cls,count',
[