pycode: Detect @overload decorators

This commit is contained in:
Takeshi KOMIYA
2020-05-24 19:18:21 +09:00
parent a59f83b6bd
commit 640bb2e586
3 changed files with 115 additions and 0 deletions

View File

@@ -12,6 +12,7 @@ import re
import tokenize
import warnings
from importlib import import_module
from inspect import Signature
from io import StringIO
from os import path
from typing import Any, Dict, IO, List, Tuple, Optional
@@ -145,6 +146,7 @@ class ModuleAnalyzer:
self.annotations = None # type: Dict[Tuple[str, str], str]
self.attr_docs = None # type: Dict[Tuple[str, str], List[str]]
self.finals = None # type: List[str]
self.overloads = None # type: Dict[str, List[Signature]]
self.tagorder = None # type: Dict[str, int]
self.tags = None # type: Dict[str, Tuple[str, int, int]]
@@ -163,6 +165,7 @@ class ModuleAnalyzer:
self.annotations = parser.annotations
self.finals = parser.finals
self.overloads = parser.overloads
self.tags = parser.definitions
self.tagorder = parser.deforders
except Exception as exc:

View File

@@ -12,12 +12,14 @@ import itertools
import re
import sys
import tokenize
from inspect import Signature
from token import NAME, NEWLINE, INDENT, DEDENT, NUMBER, OP, STRING
from tokenize import COMMENT, NL
from typing import Any, Dict, List, Optional, Tuple
from sphinx.pycode.ast import ast # for py37 or older
from sphinx.pycode.ast import parse, unparse
from sphinx.util.inspect import signature_from_ast
comment_re = re.compile('^\\s*#: ?(.*)\r?\n?$')
@@ -232,8 +234,10 @@ class VariableCommentPicker(ast.NodeVisitor):
self.previous = None # type: ast.AST
self.deforders = {} # type: Dict[str, int]
self.finals = [] # type: List[str]
self.overloads = {} # type: Dict[str, List[Signature]]
self.typing = None # type: str
self.typing_final = None # type: str
self.typing_overload = None # type: str
super().__init__()
def get_qualname_for(self, name: str) -> Optional[List[str]]:
@@ -257,6 +261,12 @@ class VariableCommentPicker(ast.NodeVisitor):
if qualname:
self.finals.append(".".join(qualname))
def add_overload_entry(self, func: ast.FunctionDef) -> None:
qualname = self.get_qualname_for(func.name)
if qualname:
overloads = self.overloads.setdefault(".".join(qualname), [])
overloads.append(signature_from_ast(func))
def add_variable_comment(self, name: str, comment: str) -> None:
qualname = self.get_qualname_for(name)
if qualname:
@@ -285,6 +295,22 @@ class VariableCommentPicker(ast.NodeVisitor):
return False
def is_overload(self, decorators: List[ast.expr]) -> bool:
overload = []
if self.typing:
overload.append('%s.overload' % self.typing)
if self.typing_overload:
overload.append(self.typing_overload)
for decorator in decorators:
try:
if unparse(decorator) in overload:
return True
except NotImplementedError:
pass
return False
def get_self(self) -> ast.arg:
"""Returns the name of first argument if in function."""
if self.current_function and self.current_function.args.args:
@@ -310,6 +336,8 @@ class VariableCommentPicker(ast.NodeVisitor):
self.typing = name.asname or name.name
elif name.name == 'typing.final':
self.typing_final = name.asname or name.name
elif name.name == 'typing.overload':
self.typing_overload = name.asname or name.name
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handles Import node and record it to definition orders."""
@@ -318,6 +346,8 @@ class VariableCommentPicker(ast.NodeVisitor):
if node.module == 'typing' and name.name == 'final':
self.typing_final = name.asname or name.name
elif node.module == 'typing' and name.name == 'overload':
self.typing_overload = name.asname or name.name
def visit_Assign(self, node: ast.Assign) -> None:
"""Handles Assign node and pick up a variable comment."""
@@ -417,6 +447,8 @@ class VariableCommentPicker(ast.NodeVisitor):
self.add_entry(node.name) # should be called before setting self.current_function
if self.is_final(node.decorator_list):
self.add_final_entry(node.name)
if self.is_overload(node.decorator_list):
self.add_overload_entry(node)
self.context.append(node.name)
self.current_function = node
for child in node.body:
@@ -518,6 +550,7 @@ class Parser:
self.deforders = {} # type: Dict[str, int]
self.definitions = {} # type: Dict[str, Tuple[str, int, int]]
self.finals = [] # type: List[str]
self.overloads = {} # type: Dict[str, List[Signature]]
def parse(self) -> None:
"""Parse the source code."""
@@ -533,6 +566,7 @@ class Parser:
self.comments = picker.comments
self.deforders = picker.deforders
self.finals = picker.finals
self.overloads = picker.overloads
def parse_definition(self) -> None:
"""Parse the location of definitions from the code."""

View File

@@ -13,6 +13,7 @@ import sys
import pytest
from sphinx.pycode.parser import Parser
from sphinx.util.inspect import signature_from_str
def test_comment_picker_basic():
@@ -452,3 +453,80 @@ def test_typing_final_not_imported():
parser = Parser(source)
parser.parse()
assert parser.finals == []
def test_typing_overload():
source = ('import typing\n'
'\n'
'@typing.overload\n'
'def func(x: int, y: int) -> int: pass\n'
'\n'
'@typing.overload\n'
'def func(x: str, y: str) -> str: pass\n'
'\n'
'def func(x, y): pass\n')
parser = Parser(source)
parser.parse()
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
signature_from_str('(x: str, y: str) -> str')]}
def test_typing_overload_from_import():
source = ('from typing import overload\n'
'\n'
'@overload\n'
'def func(x: int, y: int) -> int: pass\n'
'\n'
'@overload\n'
'def func(x: str, y: str) -> str: pass\n'
'\n'
'def func(x, y): pass\n')
parser = Parser(source)
parser.parse()
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
signature_from_str('(x: str, y: str) -> str')]}
def test_typing_overload_import_as():
source = ('import typing as foo\n'
'\n'
'@foo.overload\n'
'def func(x: int, y: int) -> int: pass\n'
'\n'
'@foo.overload\n'
'def func(x: str, y: str) -> str: pass\n'
'\n'
'def func(x, y): pass\n')
parser = Parser(source)
parser.parse()
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
signature_from_str('(x: str, y: str) -> str')]}
def test_typing_overload_from_import_as():
source = ('from typing import overload as bar\n'
'\n'
'@bar\n'
'def func(x: int, y: int) -> int: pass\n'
'\n'
'@bar\n'
'def func(x: str, y: str) -> str: pass\n'
'\n'
'def func(x, y): pass\n')
parser = Parser(source)
parser.parse()
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
signature_from_str('(x: str, y: str) -> str')]}
def test_typing_overload_not_imported():
source = ('@typing.final\n'
'def func(x: int, y: int) -> int: pass\n'
'\n'
'@typing.final\n'
'def func(x: str, y: str) -> str: pass\n'
'\n'
'def func(x, y): pass\n')
parser = Parser(source)
parser.parse()
assert parser.overloads == {}