From 91650d89c9bb53c78e22eb94a31edc5656faa0b5 Mon Sep 17 00:00:00 2001 From: Takeshi KOMIYA Date: Wed, 15 Aug 2018 01:42:07 +0900 Subject: [PATCH] Fix #5291: autodoc crashed by ForwardRef types --- CHANGES | 2 ++ sphinx/util/inspect.py | 4 ++++ tests/test_util_inspect.py | 3 +++ tests/typing_test_data.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/CHANGES b/CHANGES index 2cdb52a6c..7472232da 100644 --- a/CHANGES +++ b/CHANGES @@ -28,6 +28,8 @@ Bugs fixed font with XeLaTeX/LuaLateX (refs: #5251) * #5280: autodoc: Fix wrong type annotations for complex typing * autodoc: Optional types are wrongly rendered +* #5291: autodoc crashed by ForwardRef types + Testing -------- diff --git a/sphinx/util/inspect.py b/sphinx/util/inspect.py index c3fe0178d..64e93f267 100644 --- a/sphinx/util/inspect.py +++ b/sphinx/util/inspect.py @@ -477,6 +477,8 @@ class Signature(object): qualname = annotation._name elif getattr(annotation, '__qualname__', None): qualname = annotation.__qualname__ + elif getattr(annotation, '__forward_arg__', None): + qualname = annotation.__forward_arg__ else: qualname = self.format_annotation(annotation.__origin__) # ex. Union elif hasattr(annotation, '__qualname__'): @@ -510,6 +512,8 @@ class Signature(object): qualname = annotation._name elif getattr(annotation, '__qualname__', None): qualname = annotation.__qualname__ + elif getattr(annotation, '__forward_arg__', None): + qualname = annotation.__forward_arg__ else: qualname = self.format_annotation(annotation.__origin__) # ex. Union elif hasattr(annotation, '__qualname__'): diff --git a/tests/test_util_inspect.py b/tests/test_util_inspect.py index 51c2d06e1..a48aedcf0 100644 --- a/tests/test_util_inspect.py +++ b/tests/test_util_inspect.py @@ -297,6 +297,9 @@ def test_Signature_annotations(): sig = inspect.Signature(Node.children).format_args() assert sig == '(self) -> List[typing_test_data.Node]' + sig = inspect.Signature(Node.__init__).format_args() + assert sig == '(self, parent: Optional[Node]) -> None' + def test_safe_getattr_with_default(): class Foo(object): diff --git a/tests/typing_test_data.py b/tests/typing_test_data.py index 35064b042..5b161eac4 100644 --- a/tests/typing_test_data.py +++ b/tests/typing_test_data.py @@ -73,5 +73,8 @@ def f13() -> Optional[str]: class Node: + def __init__(self, parent: Optional['Node']) -> None: + pass + def children(self) -> List['Node']: pass