sphinx/sphinx/pycode/ast.py
2022-09-27 18:31:47 +01:00

203 lines
7.0 KiB
Python

"""Helpers for AST (Abstract Syntax Tree)."""
import ast
from typing import Dict, List, Optional, Type, overload
# Source token emitted for each operator node class.  _UnparseVisitor
# generates one ``visit_<ClassName>`` method per entry of this table, so an
# operator is supported by the unparser exactly when it is listed here.
OPERATORS: Dict[Type[ast.AST], str] = {
    ast.Add: "+",
    ast.And: "and",
    ast.BitAnd: "&",
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Invert: "~",
    ast.LShift: "<<",
    ast.MatMult: "@",
    ast.Mult: "*",
    ast.Mod: "%",
    ast.Not: "not",
    ast.Pow: "**",
    ast.Or: "or",
    ast.RShift: ">>",
    ast.Sub: "-",
    # Unary plus/minus share their spelling with the binary forms above.
    ast.UAdd: "+",
    ast.USub: "-",
}
def parse(code: str, mode: str = 'exec') -> "ast.AST":
    """Parse *code* into an AST with the built-in :mod:`ast` module.

    Type comments are collected on the first attempt.  If that attempt fails
    with :exc:`SyntaxError`, the code is parsed again without
    ``type_comments``, so malformed type comments are ignored instead of
    aborting the parse.  A genuine syntax error still propagates from the
    second attempt.
    """
    try:
        tree = ast.parse(code, mode=mode, type_comments=True)
    except SyntaxError:
        # Invalid type comments surface as SyntaxError; retry without them
        # (refs: https://github.com/sphinx-doc/sphinx/issues/8652).
        tree = ast.parse(code, mode=mode)
    return tree
@overload
def unparse(node: None, code: str = '') -> None:
    ...


@overload
def unparse(node: ast.AST, code: str = '') -> str:
    ...


def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]:
    """Convert *node* back into Python source text.

    ``None`` and plain strings are returned unchanged.  Any other node is
    rendered by :class:`_UnparseVisitor`; *code* is the original source (if
    available) and is used to reproduce numeric literals as written.
    """
    if node is None or isinstance(node, str):
        # Nothing to render: hand the value back as-is.
        return node
    visitor = _UnparseVisitor(code)
    return visitor.visit(node)
# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Render the subset of AST nodes that appear in signatures back to source.

    Only node types that have a ``visit_*`` method here are supported; any
    other node reaches :meth:`generic_visit`, which raises
    :exc:`NotImplementedError`.  Operator precedence is not tracked, so
    parentheses present in the original source may be lost.
    """

    def __init__(self, code: str = '') -> None:
        # Original source text.  When non-empty, numeric literals are
        # rendered exactly as written instead of via ``repr``.
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        """Return the source token for an operator node (e.g. ``ast.Add`` -> ``+``)."""
        return OPERATORS[node.__class__]

    # Register ``_visit_op`` as ``visit_Add``, ``visit_Sub``, ... for every
    # operator class in OPERATORS.  Inside a class body ``locals()`` is the
    # class namespace, so these assignments become methods of the class.
    for _op in OPERATORS:
        locals()['visit_{}'.format(_op.__name__)] = _visit_op

    def visit_arg(self, node: ast.arg) -> str:
        """Render a parameter as ``name`` or ``name: annotation``."""
        if node.annotation:
            return "%s: %s" % (node.arg, self.visit(node.annotation))
        else:
            return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: Optional[ast.AST]) -> str:
        """Unparse a single argument to a string, appending its default if any."""
        name = self.visit(arg)
        if default is not None:
            # PEP 8 spacing: ``x: int = 1`` when annotated, ``x=1`` otherwise.
            if arg.annotation:
                name += " = %s" % self.visit(default)
            else:
                name += "=%s" % self.visit(default)
        return name

    def visit_arguments(self, node: ast.arguments) -> str:
        """Render a full parameter list (without the surrounding parentheses)."""
        posonlyargs = len(node.posonlyargs)
        positionals = len(node.args) + posonlyargs

        # ``node.defaults`` only covers the trailing positional parameters;
        # left-pad with None so index i lines up with the i-th positional.
        defaults: List[Optional[ast.expr]] = (
            [None] * (positionals - len(node.defaults)) + list(node.defaults)
        )
        # Same alignment for keyword-only parameters and their defaults.
        kw_defaults: List[Optional[ast.expr]] = (
            [None] * (len(node.kwonlyargs) - len(node.kw_defaults))
            + list(node.kw_defaults)
        )

        args: List[str] = []
        for i, arg in enumerate(node.posonlyargs):
            args.append(self._visit_arg_with_default(arg, defaults[i]))

        if node.posonlyargs:
            args.append('/')

        for i, arg in enumerate(node.args):
            args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs]))

        if node.vararg:
            args.append("*" + self.visit(node.vararg))

        # A bare ``*`` separator is needed only when there is no ``*args``.
        if node.kwonlyargs and not node.vararg:
            args.append('*')
        for i, arg in enumerate(node.kwonlyargs):
            args.append(self._visit_arg_with_default(arg, kw_defaults[i]))

        if node.kwarg:
            args.append("**" + self.visit(node.kwarg))

        return ", ".join(args)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        return "%s.%s" % (self.visit(node.value), node.attr)

    def visit_BinOp(self, node: ast.BinOp) -> str:
        # Special case ``**`` to not have surrounding spaces.
        if isinstance(node.op, ast.Pow):
            return "".join(map(self.visit, (node.left, node.op, node.right)))
        return " ".join(self.visit(e) for e in [node.left, node.op, node.right])

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        op = " %s " % self.visit(node.op)
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        """Render a call expression, including ``*args`` and ``**kwargs``."""
        # ``*iterable`` arguments are ast.Starred nodes inside node.args.
        args = [self.visit(e) for e in node.args]
        for kw in node.keywords:
            if kw.arg is None:
                # ``**mapping`` unpacking is a keyword whose ``arg`` is None.
                args.append("**%s" % self.visit(kw.value))
            else:
                args.append("%s=%s" % (kw.arg, self.visit(kw.value)))
        return "%s(%s)" % (self.visit(node.func), ", ".join(args))

    def visit_Constant(self, node: ast.Constant) -> str:
        if node.value is Ellipsis:
            return "..."
        elif isinstance(node.value, (int, float, complex)):
            # Prefer the literal as written (e.g. ``0x1F``, ``1_000``) when
            # the original source is available.
            if self.code:
                return ast.get_source_segment(self.code, node) or repr(node.value)
            else:
                return repr(node.value)
        else:
            return repr(node.value)

    def visit_Dict(self, node: ast.Dict) -> str:
        """Render a dict display, including ``**mapping`` unpacking entries."""
        items: List[str] = []
        for key, value in zip(node.keys, node.values):
            if key is None:
                # ``**mapping`` unpacking: the key slot is None and the
                # mapping expression sits in ``values``.
                items.append("**" + self.visit(value))
            else:
                items.append(self.visit(key) + ": " + self.visit(value))
        return "{" + ", ".join(items) + "}"

    def visit_Index(self, node: ast.Index) -> str:
        # Subscript wrapper used by Python < 3.9.
        return self.visit(node.value)

    def visit_Lambda(self, node: ast.Lambda) -> str:
        # The body is deliberately elided.
        return "lambda %s: ..." % self.visit(node.args)

    def visit_List(self, node: ast.List) -> str:
        return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"

    def visit_Name(self, node: ast.Name) -> str:
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"

    def visit_Starred(self, node: ast.Starred) -> str:
        """Render ``*value`` (call arguments, tuple/list/set elements)."""
        return "*" + self.visit(node.value)

    def visit_Subscript(self, node: ast.Subscript) -> str:
        def is_simple_tuple(value: ast.AST) -> bool:
            # Non-empty tuple without starred elements: safe to print the
            # elements without parentheses inside ``[...]``.
            return (
                isinstance(value, ast.Tuple) and
                bool(value.elts) and
                not any(isinstance(elt, ast.Starred) for elt in value.elts)
            )

        if is_simple_tuple(node.slice):
            elts = ", ".join(self.visit(e) for e in node.slice.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        elif isinstance(node.slice, ast.Index) and is_simple_tuple(node.slice.value):
            # Python < 3.9 wraps the slice in ast.Index.
            elts = ", ".join(self.visit(e) for e in node.slice.value.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        else:
            return "%s[%s]" % (self.visit(node.value), self.visit(node.slice))

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        # UnaryOp is one of {UAdd, USub, Invert, Not}, which refer to ``+x``,
        # ``-x``, ``~x``, and ``not x``. Only Not needs a space.
        if isinstance(node.op, ast.Not):
            return "%s %s" % (self.visit(node.op), self.visit(node.operand))
        return "%s%s" % (self.visit(node.op), self.visit(node.operand))

    def visit_Tuple(self, node: ast.Tuple) -> str:
        if len(node.elts) == 0:
            return "()"
        elif len(node.elts) == 1:
            # Single-element tuples need the trailing comma.
            return "(%s,)" % self.visit(node.elts[0])
        else:
            return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"

    def generic_visit(self, node: ast.AST) -> None:
        """Reject node types this cut-down unparser does not support."""
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)