diff --git a/sphinx/domains/python.py b/sphinx/domains/python.py index 1b551c70b..5403f499a 100644 --- a/sphinx/domains/python.py +++ b/sphinx/domains/python.py @@ -10,6 +10,7 @@ import re import warnings +from inspect import Parameter from typing import Any, Dict, Iterable, Iterator, List, Tuple from typing import cast @@ -30,6 +31,7 @@ from sphinx.roles import XRefRole from sphinx.util import logging from sphinx.util.docfields import Field, GroupedField, TypedField from sphinx.util.docutils import SphinxDirective +from sphinx.util.inspect import signature_from_str from sphinx.util.nodes import make_refnode from sphinx.util.typing import TextlikeNode @@ -62,6 +64,47 @@ pairindextypes = { } +def _parse_arglist(arglist: str) -> addnodes.desc_parameterlist: + """Parse a list of arguments using AST parser""" + params = addnodes.desc_parameterlist(arglist) + sig = signature_from_str('(%s)' % arglist) + last_kind = None + for param in sig.parameters.values(): + if param.kind != param.POSITIONAL_ONLY and last_kind == param.POSITIONAL_ONLY: + # PEP-570: Separator for Positional Only Parameter: / + params += nodes.Text('/') + if param.kind == param.KEYWORD_ONLY and last_kind in (param.POSITIONAL_OR_KEYWORD, + param.POSITIONAL_ONLY, + None): + # PEP-3102: Separator for Keyword Only Parameter: * + params += nodes.Text('*') + + node = addnodes.desc_parameter() + if param.kind == param.VAR_POSITIONAL: + node += nodes.Text('*' + param.name) + elif param.kind == param.VAR_KEYWORD: + node += nodes.Text('**' + param.name) + else: + node += nodes.Text(param.name) + + if param.annotation is not param.empty: + node += nodes.Text(': ' + param.annotation) + if param.default is not param.empty: + if param.annotation is not param.empty: + node += nodes.Text(' = ' + str(param.default)) + else: + node += nodes.Text('=' + str(param.default)) + + params += node + last_kind = param.kind + + if last_kind == Parameter.POSITIONAL_ONLY: + # PEP-570: Separator for Positional Only Parameter: / + params += nodes.Text('/') + + return params + + def _pseudo_parse_arglist(signode: desc_signature, arglist: str) -> None: """"Parse" a list of arguments separated by commas. @@ -284,7 +327,15 @@ class PyObject(ObjectDescription): signode += addnodes.desc_name(name, name) if arglist: - _pseudo_parse_arglist(signode, arglist) + try: + signode += _parse_arglist(arglist) + except SyntaxError: + # fallback to parse arglist original parser. + # it supports to represent optional arguments (ex. "func(foo [, bar])") + _pseudo_parse_arglist(signode, arglist) + except NotImplementedError as exc: + logger.warning(exc) + _pseudo_parse_arglist(signode, arglist) else: if self.needs_arglist(): # for callables, add an empty parameter list diff --git a/tests/test_domain_py.py b/tests/test_domain_py.py index f78c1e9d8..1a3af913a 100644 --- a/tests/test_domain_py.py +++ b/tests/test_domain_py.py @@ -8,6 +8,7 @@ :license: BSD, see LICENSE for details. """ +import sys from unittest.mock import Mock import pytest @@ -241,7 +242,73 @@ def test_pyfunction_signature(app): desc_content)])) assert_node(doctree[1], addnodes.desc, desctype="function", domain="py", objtype="function", noindex=False) - assert_node(doctree[1][0][1], [desc_parameterlist, desc_parameter, "name: str"]) + assert_node(doctree[1][0][1], + [desc_parameterlist, desc_parameter, ("name", + ": str")]) + + +def test_pyfunction_signature_full(app): + text = (".. py:function:: hello(a: str, b = 1, *args: str, " + "c: bool = True, **kwargs: str) -> str") + doctree = restructuredtext.parse(app, text) + assert_node(doctree, (addnodes.index, + [desc, ([desc_signature, ([desc_name, "hello"], + desc_parameterlist, + [desc_returns, "str"])], + desc_content)])) + assert_node(doctree[1], addnodes.desc, desctype="function", + domain="py", objtype="function", noindex=False) + assert_node(doctree[1][0][1], + [desc_parameterlist, ([desc_parameter, ("a", + ": str")], + [desc_parameter, ("b", + "=1")], + [desc_parameter, ("*args", + ": str")], + [desc_parameter, ("c", + ": bool", + " = True")], + [desc_parameter, ("**kwargs", + ": str")])]) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason='python 3.8+ is required.') +def test_pyfunction_signature_full_py38(app): + # case: separator at head + text = ".. py:function:: hello(*, a)" + doctree = restructuredtext.parse(app, text) + assert_node(doctree[1][0][1], + [desc_parameterlist, ("*", + [desc_parameter, ("a", + "=None")])]) + + # case: separator in the middle + text = ".. py:function:: hello(a, /, b, *, c)" + doctree = restructuredtext.parse(app, text) + assert_node(doctree[1][0][1], + [desc_parameterlist, ([desc_parameter, "a"], + "/", + [desc_parameter, "b"], + "*", + [desc_parameter, ("c", + "=None")])]) + + # case: separator in the middle (2) + text = ".. py:function:: hello(a, /, *, b)" + doctree = restructuredtext.parse(app, text) + assert_node(doctree[1][0][1], + [desc_parameterlist, ([desc_parameter, "a"], + "/", + "*", + [desc_parameter, ("b", + "=None")])]) + + # case: separator at tail + text = ".. py:function:: hello(a, /)" + doctree = restructuredtext.parse(app, text) + assert_node(doctree[1][0][1], + [desc_parameterlist, ([desc_parameter, "a"], + "/")]) def test_optional_pyfunction_signature(app):