Refactor C and C++ AST types (#12312)

- Define ``__eq__`` and ``__hash__`` methods for AST types
- Improve the base ``__eq__`` method
- Cache the value of ``is_anon()``
- Rename ``ASTIdentifier.identifier``
- Various other serialisation improvements
This commit is contained in:
Adam Turner 2024-04-23 04:56:32 +01:00 committed by GitHub
parent 67493fcbc7
commit 2f2078fba3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1316 additions and 67 deletions

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import sys
import warnings
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Union, cast
from docutils import nodes from docutils import nodes
@ -38,39 +40,40 @@ class ASTBase(ASTBaseBase):
################################################################################ ################################################################################
class ASTIdentifier(ASTBaseBase): class ASTIdentifier(ASTBaseBase):
def __init__(self, identifier: str) -> None: def __init__(self, name: str) -> None:
assert identifier is not None if not isinstance(name, str) or len(name) == 0:
assert len(identifier) != 0 raise AssertionError
self.identifier = identifier self.name = sys.intern(name)
self.is_anonymous = name[0] == '@'
# ASTBaseBase already implements this method, # ASTBaseBase already implements this method,
# but specialising it here improves performance # but specialising it here improves performance
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if type(other) is not ASTIdentifier: if not isinstance(other, ASTIdentifier):
return NotImplemented return NotImplemented
return self.identifier == other.identifier return self.name == other.name
def is_anon(self) -> bool: def is_anon(self) -> bool:
return self.identifier[0] == '@' return self.is_anonymous
# and this is where we finally make a difference between __str__ and the display string # and this is where we finally make a difference between __str__ and the display string
def __str__(self) -> str: def __str__(self) -> str:
return self.identifier return self.name
def get_display_string(self) -> str: def get_display_string(self) -> str:
return "[anonymous]" if self.is_anon() else self.identifier return "[anonymous]" if self.is_anonymous else self.name
def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment,
prefix: str, symbol: Symbol) -> None: prefix: str, symbol: Symbol) -> None:
# note: slightly different signature of describe_signature due to the prefix # note: slightly different signature of describe_signature due to the prefix
verify_description_mode(mode) verify_description_mode(mode)
if self.is_anon(): if self.is_anonymous:
node = addnodes.desc_sig_name(text="[anonymous]") node = addnodes.desc_sig_name(text="[anonymous]")
else: else:
node = addnodes.desc_sig_name(self.identifier, self.identifier) node = addnodes.desc_sig_name(self.name, self.name)
if mode == 'markType': if mode == 'markType':
targetText = prefix + self.identifier targetText = prefix + self.name
pnode = addnodes.pending_xref('', refdomain='c', pnode = addnodes.pending_xref('', refdomain='c',
reftype='identifier', reftype='identifier',
reftarget=targetText, modname=None, reftarget=targetText, modname=None,
@ -87,6 +90,14 @@ class ASTIdentifier(ASTBaseBase):
else: else:
raise Exception('Unknown description mode: %s' % mode) raise Exception('Unknown description mode: %s' % mode)
@property
def identifier(self) -> str:
warnings.warn(
'`ASTIdentifier.identifier` is deprecated, use `ASTIdentifier.name` instead',
DeprecationWarning, stacklevel=2,
)
return self.name
class ASTNestedName(ASTBase): class ASTNestedName(ASTBase):
def __init__(self, names: list[ASTIdentifier], rooted: bool) -> None: def __init__(self, names: list[ASTIdentifier], rooted: bool) -> None:
@ -94,6 +105,14 @@ class ASTNestedName(ASTBase):
self.names = names self.names = names
self.rooted = rooted self.rooted = rooted
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTNestedName):
return NotImplemented
return self.names == other.names and self.rooted == other.rooted
def __hash__(self) -> int:
return hash((self.names, self.rooted))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self return self
@ -186,6 +205,14 @@ class ASTBooleanLiteral(ASTLiteral):
def __init__(self, value: bool) -> None: def __init__(self, value: bool) -> None:
self.value = value self.value = value
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTBooleanLiteral):
return NotImplemented
return self.value == other.value
def __hash__(self) -> int:
return hash(self.value)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
if self.value: if self.value:
return 'true' return 'true'
@ -202,6 +229,14 @@ class ASTNumberLiteral(ASTLiteral):
def __init__(self, data: str) -> None: def __init__(self, data: str) -> None:
self.data = data self.data = data
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTNumberLiteral):
return NotImplemented
return self.data == other.data
def __hash__(self) -> int:
return hash(self.data)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return self.data return self.data
@ -221,6 +256,17 @@ class ASTCharLiteral(ASTLiteral):
else: else:
raise UnsupportedMultiCharacterCharLiteral(decoded) raise UnsupportedMultiCharacterCharLiteral(decoded)
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTCharLiteral):
return NotImplemented
return (
self.prefix == other.prefix
and self.value == other.value
)
def __hash__(self) -> int:
return hash((self.prefix, self.value))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
if self.prefix is None: if self.prefix is None:
return "'" + self.data + "'" return "'" + self.data + "'"
@ -237,6 +283,14 @@ class ASTStringLiteral(ASTLiteral):
def __init__(self, data: str) -> None: def __init__(self, data: str) -> None:
self.data = data self.data = data
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTStringLiteral):
return NotImplemented
return self.data == other.data
def __hash__(self) -> int:
return hash(self.data)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return self.data return self.data
@ -251,6 +305,14 @@ class ASTIdExpression(ASTExpression):
# note: this class is basically to cast a nested name as an expression # note: this class is basically to cast a nested name as an expression
self.name = name self.name = name
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTIdExpression):
return NotImplemented
return self.name == other.name
def __hash__(self) -> int:
return hash(self.name)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return transform(self.name) return transform(self.name)
@ -266,6 +328,14 @@ class ASTParenExpr(ASTExpression):
def __init__(self, expr: ASTExpression) -> None: def __init__(self, expr: ASTExpression) -> None:
self.expr = expr self.expr = expr
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTParenExpr):
return NotImplemented
return self.expr == other.expr
def __hash__(self) -> int:
return hash(self.expr)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return '(' + transform(self.expr) + ')' return '(' + transform(self.expr) + ')'
@ -290,6 +360,14 @@ class ASTPostfixCallExpr(ASTPostfixOp):
def __init__(self, lst: ASTParenExprList | ASTBracedInitList) -> None: def __init__(self, lst: ASTParenExprList | ASTBracedInitList) -> None:
self.lst = lst self.lst = lst
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTPostfixCallExpr):
return NotImplemented
return self.lst == other.lst
def __hash__(self) -> int:
return hash(self.lst)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return transform(self.lst) return transform(self.lst)
@ -302,6 +380,14 @@ class ASTPostfixArray(ASTPostfixOp):
def __init__(self, expr: ASTExpression) -> None: def __init__(self, expr: ASTExpression) -> None:
self.expr = expr self.expr = expr
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTPostfixArray):
return NotImplemented
return self.expr == other.expr
def __hash__(self) -> int:
return hash(self.expr)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return '[' + transform(self.expr) + ']' return '[' + transform(self.expr) + ']'
@ -334,6 +420,14 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp):
def __init__(self, name: ASTNestedName) -> None: def __init__(self, name: ASTNestedName) -> None:
self.name = name self.name = name
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTPostfixMemberOfPointer):
return NotImplemented
return self.name == other.name
def __hash__(self) -> int:
return hash(self.name)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return '->' + transform(self.name) return '->' + transform(self.name)
@ -348,6 +442,14 @@ class ASTPostfixExpr(ASTExpression):
self.prefix = prefix self.prefix = prefix
self.postFixes = postFixes self.postFixes = postFixes
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTPostfixExpr):
return NotImplemented
return self.prefix == other.prefix and self.postFixes == other.postFixes
def __hash__(self) -> int:
return hash((self.prefix, self.postFixes))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return ''.join([transform(self.prefix), *(transform(p) for p in self.postFixes)]) return ''.join([transform(self.prefix), *(transform(p) for p in self.postFixes)])
@ -366,6 +468,14 @@ class ASTUnaryOpExpr(ASTExpression):
self.op = op self.op = op
self.expr = expr self.expr = expr
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTUnaryOpExpr):
return NotImplemented
return self.op == other.op and self.expr == other.expr
def __hash__(self) -> int:
return hash((self.op, self.expr))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
if self.op[0] in 'cn': if self.op[0] in 'cn':
return self.op + " " + transform(self.expr) return self.op + " " + transform(self.expr)
@ -386,6 +496,14 @@ class ASTSizeofType(ASTExpression):
def __init__(self, typ: ASTType) -> None: def __init__(self, typ: ASTType) -> None:
self.typ = typ self.typ = typ
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTSizeofType):
return NotImplemented
return self.typ == other.typ
def __hash__(self) -> int:
return hash(self.typ)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return "sizeof(" + transform(self.typ) + ")" return "sizeof(" + transform(self.typ) + ")"
@ -401,6 +519,14 @@ class ASTSizeofExpr(ASTExpression):
def __init__(self, expr: ASTExpression) -> None: def __init__(self, expr: ASTExpression) -> None:
self.expr = expr self.expr = expr
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTSizeofExpr):
return NotImplemented
return self.expr == other.expr
def __hash__(self) -> int:
return hash(self.expr)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return "sizeof " + transform(self.expr) return "sizeof " + transform(self.expr)
@ -415,6 +541,14 @@ class ASTAlignofExpr(ASTExpression):
def __init__(self, typ: ASTType) -> None: def __init__(self, typ: ASTType) -> None:
self.typ = typ self.typ = typ
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTAlignofExpr):
return NotImplemented
return self.typ == other.typ
def __hash__(self) -> int:
return hash(self.typ)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return "alignof(" + transform(self.typ) + ")" return "alignof(" + transform(self.typ) + ")"
@ -434,6 +568,17 @@ class ASTCastExpr(ASTExpression):
self.typ = typ self.typ = typ
self.expr = expr self.expr = expr
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTCastExpr):
return NotImplemented
return (
self.typ == other.typ
and self.expr == other.expr
)
def __hash__(self) -> int:
return hash((self.typ, self.expr))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
res = ['('] res = ['(']
res.append(transform(self.typ)) res.append(transform(self.typ))
@ -456,6 +601,17 @@ class ASTBinOpExpr(ASTBase):
self.exprs = exprs self.exprs = exprs
self.ops = ops self.ops = ops
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTBinOpExpr):
return NotImplemented
return (
self.exprs == other.exprs
and self.ops == other.ops
)
def __hash__(self) -> int:
return hash((self.exprs, self.ops))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
res = [] res = []
res.append(transform(self.exprs[0])) res.append(transform(self.exprs[0]))
@ -487,6 +643,17 @@ class ASTAssignmentExpr(ASTExpression):
self.exprs = exprs self.exprs = exprs
self.ops = ops self.ops = ops
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTAssignmentExpr):
return NotImplemented
return (
self.exprs == other.exprs
and self.ops == other.ops
)
def __hash__(self) -> int:
return hash((self.exprs, self.ops))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
res = [] res = []
res.append(transform(self.exprs[0])) res.append(transform(self.exprs[0]))
@ -515,6 +682,14 @@ class ASTFallbackExpr(ASTExpression):
def __init__(self, expr: str) -> None: def __init__(self, expr: str) -> None:
self.expr = expr self.expr = expr
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTFallbackExpr):
return NotImplemented
return self.expr == other.expr
def __hash__(self) -> int:
return hash(self.expr)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return self.expr return self.expr
@ -539,6 +714,14 @@ class ASTTrailingTypeSpecFundamental(ASTTrailingTypeSpec):
assert len(names) != 0 assert len(names) != 0
self.names = names self.names = names
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTTrailingTypeSpecFundamental):
return NotImplemented
return self.names == other.names
def __hash__(self) -> int:
return hash(self.names)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return ' '.join(self.names) return ' '.join(self.names)
@ -558,6 +741,17 @@ class ASTTrailingTypeSpecName(ASTTrailingTypeSpec):
self.prefix = prefix self.prefix = prefix
self.nestedName = nestedName self.nestedName = nestedName
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTTrailingTypeSpecName):
return NotImplemented
return (
self.prefix == other.prefix
and self.nestedName == other.nestedName
)
def __hash__(self) -> int:
return hash((self.prefix, self.nestedName))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.nestedName return self.nestedName
@ -583,6 +777,14 @@ class ASTFunctionParameter(ASTBase):
self.arg = arg self.arg = arg
self.ellipsis = ellipsis self.ellipsis = ellipsis
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTFunctionParameter):
return NotImplemented
return self.arg == other.arg and self.ellipsis == other.ellipsis
def __hash__(self) -> int:
return hash((self.arg, self.ellipsis))
def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: def get_id(self, version: int, objectType: str, symbol: Symbol) -> str:
# the anchor will be our parent # the anchor will be our parent
return symbol.parent.declaration.get_id(version, prefixed=False) return symbol.parent.declaration.get_id(version, prefixed=False)
@ -607,6 +809,14 @@ class ASTParameters(ASTBase):
self.args = args self.args = args
self.attrs = attrs self.attrs = attrs
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTParameters):
return NotImplemented
return self.args == other.args and self.attrs == other.attrs
def __hash__(self) -> int:
return hash((self.args, self.attrs))
@property @property
def function_params(self) -> list[ASTFunctionParameter]: def function_params(self) -> list[ASTFunctionParameter]:
return self.args return self.args
@ -674,6 +884,30 @@ class ASTDeclSpecsSimple(ASTBaseBase):
self.const = const self.const = const
self.attrs = attrs self.attrs = attrs
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTDeclSpecsSimple):
return NotImplemented
return (
self.storage == other.storage
and self.threadLocal == other.threadLocal
and self.inline == other.inline
and self.restrict == other.restrict
and self.volatile == other.volatile
and self.const == other.const
and self.attrs == other.attrs
)
def __hash__(self) -> int:
return hash((
self.storage,
self.threadLocal,
self.inline,
self.restrict,
self.volatile,
self.const,
self.attrs,
))
def mergeWith(self, other: ASTDeclSpecsSimple) -> ASTDeclSpecsSimple: def mergeWith(self, other: ASTDeclSpecsSimple) -> ASTDeclSpecsSimple:
if not other: if not other:
return self return self
@ -741,6 +975,24 @@ class ASTDeclSpecs(ASTBase):
self.allSpecs = self.leftSpecs.mergeWith(self.rightSpecs) self.allSpecs = self.leftSpecs.mergeWith(self.rightSpecs)
self.trailingTypeSpec = trailing self.trailingTypeSpec = trailing
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTDeclSpecs):
return NotImplemented
return (
self.outer == other.outer
and self.leftSpecs == other.leftSpecs
and self.rightSpecs == other.rightSpecs
and self.trailingTypeSpec == other.trailingTypeSpec
)
def __hash__(self) -> int:
return hash((
self.outer,
self.leftSpecs,
self.rightSpecs,
self.trailingTypeSpec,
))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
res: list[str] = [] res: list[str] = []
l = transform(self.leftSpecs) l = transform(self.leftSpecs)
@ -796,6 +1048,28 @@ class ASTArray(ASTBase):
if size is not None: if size is not None:
assert not vla assert not vla
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTArray):
return NotImplemented
return (
self.static == other.static
and self.const == other.const
and self.volatile == other.volatile
and self.restrict == other.restrict
and self.vla == other.vla
and self.size == other.size
)
def __hash__(self) -> int:
return hash((
self.static,
self.const,
self.volatile,
self.restrict,
self.vla,
self.size,
))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
el = [] el = []
if self.static: if self.static:
@ -861,6 +1135,18 @@ class ASTDeclaratorNameParam(ASTDeclarator):
self.arrayOps = arrayOps self.arrayOps = arrayOps
self.param = param self.param = param
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTDeclaratorNameParam):
return NotImplemented
return (
self.declId == other.declId
and self.arrayOps == other.arrayOps
and self.param == other.param
)
def __hash__(self) -> int:
return hash((self.declId, self.arrayOps, self.param))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.declId return self.declId
@ -899,6 +1185,14 @@ class ASTDeclaratorNameBitField(ASTDeclarator):
self.declId = declId self.declId = declId
self.size = size self.size = size
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTDeclaratorNameBitField):
return NotImplemented
return self.declId == other.declId and self.size == other.size
def __hash__(self) -> int:
return hash((self.declId, self.size))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.declId return self.declId
@ -937,6 +1231,20 @@ class ASTDeclaratorPtr(ASTDeclarator):
self.const = const self.const = const
self.attrs = attrs self.attrs = attrs
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTDeclaratorPtr):
return NotImplemented
return (
self.next == other.next
and self.restrict == other.restrict
and self.volatile == other.volatile
and self.const == other.const
and self.attrs == other.attrs
)
def __hash__(self) -> int:
return hash((self.next, self.restrict, self.volatile, self.const, self.attrs))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.next.name return self.next.name
@ -1006,6 +1314,14 @@ class ASTDeclaratorParen(ASTDeclarator):
self.next = next self.next = next
# TODO: we assume the name and params are in inner # TODO: we assume the name and params are in inner
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTDeclaratorParen):
return NotImplemented
return self.inner == other.inner and self.next == other.next
def __hash__(self) -> int:
return hash((self.inner, self.next))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.inner.name return self.inner.name
@ -1040,6 +1356,14 @@ class ASTParenExprList(ASTBaseParenExprList):
def __init__(self, exprs: list[ASTExpression]) -> None: def __init__(self, exprs: list[ASTExpression]) -> None:
self.exprs = exprs self.exprs = exprs
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTParenExprList):
return NotImplemented
return self.exprs == other.exprs
def __hash__(self) -> int:
return hash(self.exprs)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
exprs = [transform(e) for e in self.exprs] exprs = [transform(e) for e in self.exprs]
return '(%s)' % ', '.join(exprs) return '(%s)' % ', '.join(exprs)
@ -1064,6 +1388,14 @@ class ASTBracedInitList(ASTBase):
self.exprs = exprs self.exprs = exprs
self.trailingComma = trailingComma self.trailingComma = trailingComma
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTBracedInitList):
return NotImplemented
return self.exprs == other.exprs and self.trailingComma == other.trailingComma
def __hash__(self) -> int:
return hash((self.exprs, self.trailingComma))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
exprs = ', '.join(transform(e) for e in self.exprs) exprs = ', '.join(transform(e) for e in self.exprs)
trailingComma = ',' if self.trailingComma else '' trailingComma = ',' if self.trailingComma else ''
@ -1092,6 +1424,14 @@ class ASTInitializer(ASTBase):
self.value = value self.value = value
self.hasAssign = hasAssign self.hasAssign = hasAssign
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTInitializer):
return NotImplemented
return self.value == other.value and self.hasAssign == other.hasAssign
def __hash__(self) -> int:
return hash((self.value, self.hasAssign))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
val = transform(self.value) val = transform(self.value)
if self.hasAssign: if self.hasAssign:
@ -1116,6 +1456,14 @@ class ASTType(ASTBase):
self.declSpecs = declSpecs self.declSpecs = declSpecs
self.decl = decl self.decl = decl
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTType):
return NotImplemented
return self.declSpecs == other.declSpecs and self.decl == other.decl
def __hash__(self) -> int:
return hash((self.declSpecs, self.decl))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.decl.name return self.decl.name
@ -1161,6 +1509,14 @@ class ASTTypeWithInit(ASTBase):
self.type = type self.type = type
self.init = init self.init = init
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTTypeWithInit):
return NotImplemented
return self.type == other.type and self.init == other.init
def __hash__(self) -> int:
return hash((self.type, self.init))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.type.name return self.type.name
@ -1190,6 +1546,18 @@ class ASTMacroParameter(ASTBase):
self.ellipsis = ellipsis self.ellipsis = ellipsis
self.variadic = variadic self.variadic = variadic
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTMacroParameter):
return NotImplemented
return (
self.arg == other.arg
and self.ellipsis == other.ellipsis
and self.variadic == other.variadic
)
def __hash__(self) -> int:
return hash((self.arg, self.ellipsis, self.variadic))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
if self.ellipsis: if self.ellipsis:
return '...' return '...'
@ -1215,6 +1583,14 @@ class ASTMacro(ASTBase):
self.ident = ident self.ident = ident
self.args = args self.args = args
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTMacro):
return NotImplemented
return self.ident == other.ident and self.args == other.args
def __hash__(self) -> int:
return hash((self.ident, self.args))
@property @property
def name(self) -> ASTNestedName: def name(self) -> ASTNestedName:
return self.ident return self.ident
@ -1254,6 +1630,14 @@ class ASTStruct(ASTBase):
def __init__(self, name: ASTNestedName) -> None: def __init__(self, name: ASTNestedName) -> None:
self.name = name self.name = name
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTStruct):
return NotImplemented
return self.name == other.name
def __hash__(self) -> int:
return hash(self.name)
def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: def get_id(self, version: int, objectType: str, symbol: Symbol) -> str:
return symbol.get_full_nested_name().get_id(version) return symbol.get_full_nested_name().get_id(version)
@ -1270,6 +1654,14 @@ class ASTUnion(ASTBase):
def __init__(self, name: ASTNestedName) -> None: def __init__(self, name: ASTNestedName) -> None:
self.name = name self.name = name
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTUnion):
return NotImplemented
return self.name == other.name
def __hash__(self) -> int:
return hash(self.name)
def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: def get_id(self, version: int, objectType: str, symbol: Symbol) -> str:
return symbol.get_full_nested_name().get_id(version) return symbol.get_full_nested_name().get_id(version)
@ -1286,6 +1678,14 @@ class ASTEnum(ASTBase):
def __init__(self, name: ASTNestedName) -> None: def __init__(self, name: ASTNestedName) -> None:
self.name = name self.name = name
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTEnum):
return NotImplemented
return self.name == other.name
def __hash__(self) -> int:
return hash(self.name)
def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: def get_id(self, version: int, objectType: str, symbol: Symbol) -> str:
return symbol.get_full_nested_name().get_id(version) return symbol.get_full_nested_name().get_id(version)
@ -1305,6 +1705,18 @@ class ASTEnumerator(ASTBase):
self.init = init self.init = init
self.attrs = attrs self.attrs = attrs
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTEnumerator):
return NotImplemented
return (
self.name == other.name
and self.init == other.init
and self.attrs == other.attrs
)
def __hash__(self) -> int:
return hash((self.name, self.init, self.attrs))
def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: def get_id(self, version: int, objectType: str, symbol: Symbol) -> str:
return symbol.get_full_nested_name().get_id(version) return symbol.get_full_nested_name().get_id(version)
@ -1346,6 +1758,18 @@ class ASTDeclaration(ASTBaseBase):
# further changes will be made to this object # further changes will be made to this object
self._newest_id_cache: str | None = None self._newest_id_cache: str | None = None
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTDeclaration):
return NotImplemented
return (
self.objectType == other.objectType
and self.directiveType == other.directiveType
and self.declaration == other.declaration
and self.semicolon == other.semicolon
and self.symbol == other.symbol
and self.enumeratorScopedSymbol == other.enumeratorScopedSymbol
)
def clone(self) -> ASTDeclaration: def clone(self) -> ASTDeclaration:
return ASTDeclaration(self.objectType, self.directiveType, return ASTDeclaration(self.objectType, self.directiveType,
self.declaration.clone(), self.semicolon) self.declaration.clone(), self.semicolon)

View File

@ -114,6 +114,9 @@ class Symbol:
# Do symbol addition after self._children has been initialised. # Do symbol addition after self._children has been initialised.
self._add_function_params() self._add_function_params()
def __repr__(self) -> str:
return f'<Symbol {self.to_string(indent=0)!r}>'
def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None: def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None:
self._assert_invariants() self._assert_invariants()
assert self.declaration is None assert self.declaration is None

File diff suppressed because it is too large Load Diff

View File

@ -155,6 +155,9 @@ class Symbol:
# Do symbol addition after self._children has been initialised. # Do symbol addition after self._children has been initialised.
self._add_template_and_function_params() self._add_template_and_function_params()
def __repr__(self) -> str:
return f'<Symbol {self.to_string(indent=0)!r}>'
def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None: def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None:
self._assert_invariants() self._assert_invariants()
assert self.declaration is None assert self.declaration is None

View File

@ -90,17 +90,11 @@ class NoOldIdError(Exception):
class ASTBaseBase: class ASTBaseBase:
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if type(self) is not type(other): if type(self) is not type(other):
return False return NotImplemented
try: try:
for key, value in self.__dict__.items(): return self.__dict__ == other.__dict__
if value != getattr(other, key):
return False
except AttributeError: except AttributeError:
return False return False
return True
# Defining __hash__ = None is not strictly needed when __eq__ is defined.
__hash__ = None # type: ignore[assignment]
def clone(self) -> Any: def clone(self) -> Any:
return deepcopy(self) return deepcopy(self)
@ -115,7 +109,7 @@ class ASTBaseBase:
return self._stringify(lambda ast: ast.get_display_string()) return self._stringify(lambda ast: ast.get_display_string())
def __repr__(self) -> str: def __repr__(self) -> str:
return '<%s>' % self.__class__.__name__ return f'<{self.__class__.__name__}: {self._stringify(repr)}>'
################################################################################ ################################################################################
@ -131,8 +125,16 @@ class ASTCPPAttribute(ASTAttribute):
def __init__(self, arg: str) -> None: def __init__(self, arg: str) -> None:
self.arg = arg self.arg = arg
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTCPPAttribute):
return NotImplemented
return self.arg == other.arg
def __hash__(self) -> int:
return hash(self.arg)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return "[[" + self.arg + "]]" return f"[[{self.arg}]]"
def describe_signature(self, signode: TextElement) -> None: def describe_signature(self, signode: TextElement) -> None:
signode.append(addnodes.desc_sig_punctuation('[[', '[[')) signode.append(addnodes.desc_sig_punctuation('[[', '[['))
@ -146,35 +148,37 @@ class ASTGnuAttribute(ASTBaseBase):
self.args = args self.args = args
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if type(other) is not ASTGnuAttribute: if not isinstance(other, ASTGnuAttribute):
return NotImplemented return NotImplemented
return self.name == other.name and self.args == other.args return self.name == other.name and self.args == other.args
def __hash__(self) -> int:
return hash((self.name, self.args))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
res = [self.name]
if self.args: if self.args:
res.append(transform(self.args)) return self.name + transform(self.args)
return ''.join(res) return self.name
class ASTGnuAttributeList(ASTAttribute): class ASTGnuAttributeList(ASTAttribute):
def __init__(self, attrs: list[ASTGnuAttribute]) -> None: def __init__(self, attrs: list[ASTGnuAttribute]) -> None:
self.attrs = attrs self.attrs = attrs
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTGnuAttributeList):
return NotImplemented
return self.attrs == other.attrs
def __hash__(self) -> int:
return hash(self.attrs)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
res = ['__attribute__(('] attrs = ', '.join(map(transform, self.attrs))
first = True return f'__attribute__(({attrs}))'
for attr in self.attrs:
if not first:
res.append(', ')
first = False
res.append(transform(attr))
res.append('))')
return ''.join(res)
def describe_signature(self, signode: TextElement) -> None: def describe_signature(self, signode: TextElement) -> None:
txt = str(self) signode.append(nodes.Text(str(self)))
signode.append(nodes.Text(txt))
class ASTIdAttribute(ASTAttribute): class ASTIdAttribute(ASTAttribute):
@ -183,6 +187,14 @@ class ASTIdAttribute(ASTAttribute):
def __init__(self, id: str) -> None: def __init__(self, id: str) -> None:
self.id = id self.id = id
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTIdAttribute):
return NotImplemented
return self.id == other.id
def __hash__(self) -> int:
return hash(self.id)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return self.id return self.id
@ -197,12 +209,19 @@ class ASTParenAttribute(ASTAttribute):
self.id = id self.id = id
self.arg = arg self.arg = arg
def __eq__(self, other: object) -> bool:
if not isinstance(other, ASTParenAttribute):
return NotImplemented
return self.id == other.id and self.arg == other.arg
def __hash__(self) -> int:
return hash((self.id, self.arg))
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return self.id + '(' + self.arg + ')' return f'{self.id}({self.arg})'
def describe_signature(self, signode: TextElement) -> None: def describe_signature(self, signode: TextElement) -> None:
txt = str(self) signode.append(nodes.Text(str(self)))
signode.append(nodes.Text(txt))
class ASTAttributeList(ASTBaseBase): class ASTAttributeList(ASTBaseBase):
@ -210,10 +229,13 @@ class ASTAttributeList(ASTBaseBase):
self.attrs = attrs self.attrs = attrs
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if type(other) is not ASTAttributeList: if not isinstance(other, ASTAttributeList):
return NotImplemented return NotImplemented
return self.attrs == other.attrs return self.attrs == other.attrs
def __hash__(self) -> int:
return hash(self.attrs)
def __len__(self) -> int: def __len__(self) -> int:
return len(self.attrs) return len(self.attrs)
@ -221,7 +243,7 @@ class ASTAttributeList(ASTBaseBase):
return ASTAttributeList(self.attrs + other.attrs) return ASTAttributeList(self.attrs + other.attrs)
def _stringify(self, transform: StringifyTransform) -> str: def _stringify(self, transform: StringifyTransform) -> str:
return ' '.join(transform(attr) for attr in self.attrs) return ' '.join(map(transform, self.attrs))
def describe_signature(self, signode: TextElement) -> None: def describe_signature(self, signode: TextElement) -> None:
if len(self.attrs) == 0: if len(self.attrs) == 0: