From 2f2078fba3ab369ecf43e738ac4bb7d3c0aab9de Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+AA-Turner@users.noreply.github.com> Date: Tue, 23 Apr 2024 04:56:32 +0100 Subject: [PATCH] Refactor C and C++ AST types (#12312) - Define ``__eq__`` and ``__hash__`` methods for AST types - Improve the base ``__eq__`` method - Cache the value of ``is_anon()`` - Rename ``ASTIdentifier.identifier`` - Various other serialisation improvements --- sphinx/domains/c/_ast.py | 448 +++++++++++++++++- sphinx/domains/c/_symbol.py | 3 + sphinx/domains/cpp/_ast.py | 847 +++++++++++++++++++++++++++++++++- sphinx/domains/cpp/_symbol.py | 3 + sphinx/util/cfamily.py | 82 ++-- 5 files changed, 1316 insertions(+), 67 deletions(-) diff --git a/sphinx/domains/c/_ast.py b/sphinx/domains/c/_ast.py index 3a8e2a2a4..6082a56fe 100644 --- a/sphinx/domains/c/_ast.py +++ b/sphinx/domains/c/_ast.py @@ -1,5 +1,7 @@ from __future__ import annotations +import sys +import warnings from typing import TYPE_CHECKING, Any, Union, cast from docutils import nodes @@ -38,39 +40,40 @@ class ASTBase(ASTBaseBase): ################################################################################ class ASTIdentifier(ASTBaseBase): - def __init__(self, identifier: str) -> None: - assert identifier is not None - assert len(identifier) != 0 - self.identifier = identifier + def __init__(self, name: str) -> None: + if not isinstance(name, str) or len(name) == 0: + raise AssertionError + self.name = sys.intern(name) + self.is_anonymous = name[0] == '@' # ASTBaseBase already implements this method, # but specialising it here improves performance def __eq__(self, other: object) -> bool: - if type(other) is not ASTIdentifier: + if not isinstance(other, ASTIdentifier): return NotImplemented - return self.identifier == other.identifier + return self.name == other.name def is_anon(self) -> bool: - return self.identifier[0] == '@' + return self.is_anonymous # and this is where we finally make a difference between __str__ and the display string def __str__(self) -> str: - return self.identifier + return self.name def get_display_string(self) -> str: - return "[anonymous]" if self.is_anon() else self.identifier + return "[anonymous]" if self.is_anonymous else self.name def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, prefix: str, symbol: Symbol) -> None: # note: slightly different signature of describe_signature due to the prefix verify_description_mode(mode) - if self.is_anon(): + if self.is_anonymous: node = addnodes.desc_sig_name(text="[anonymous]") else: - node = addnodes.desc_sig_name(self.identifier, self.identifier) + node = addnodes.desc_sig_name(self.name, self.name) if mode == 'markType': - targetText = prefix + self.identifier + targetText = prefix + self.name pnode = addnodes.pending_xref('', refdomain='c', reftype='identifier', reftarget=targetText, modname=None, @@ -87,6 +90,14 @@ class ASTIdentifier(ASTBaseBase): else: raise Exception('Unknown description mode: %s' % mode) + @property + def identifier(self) -> str: + warnings.warn( + '`ASTIdentifier.identifier` is deprecated, use `ASTIdentifier.name` instead', + DeprecationWarning, stacklevel=2, + ) + return self.name + class ASTNestedName(ASTBase): def __init__(self, names: list[ASTIdentifier], rooted: bool) -> None: @@ -94,6 +105,14 @@ class ASTNestedName(ASTBase): self.names = names self.rooted = rooted + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedName): + return NotImplemented + return self.names == other.names and self.rooted == other.rooted + + def __hash__(self) -> int: + return hash((self.names, self.rooted)) + @property def name(self) -> ASTNestedName: return self @@ -186,6 +205,14 @@ class ASTBooleanLiteral(ASTLiteral): def __init__(self, value: bool) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBooleanLiteral): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: if self.value: return 'true' @@ -202,6 +229,14 @@ class ASTNumberLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNumberLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -221,6 +256,17 @@ class ASTCharLiteral(ASTLiteral): else: raise UnsupportedMultiCharacterCharLiteral(decoded) + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCharLiteral): + return NotImplemented + return ( + self.prefix == other.prefix + and self.value == other.value + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.value)) + def _stringify(self, transform: StringifyTransform) -> str: if self.prefix is None: return "'" + self.data + "'" @@ -237,6 +283,14 @@ class ASTStringLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTStringLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -251,6 +305,14 @@ class ASTIdExpression(ASTExpression): # note: this class is basically to cast a nested name as an expression self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTIdExpression): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.name) @@ -266,6 +328,14 @@ class ASTParenExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '(' + transform(self.expr) + ')' @@ -290,6 +360,14 @@ class ASTPostfixCallExpr(ASTPostfixOp): def __init__(self, lst: ASTParenExprList | ASTBracedInitList) -> None: self.lst = lst + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixCallExpr): + return NotImplemented + return self.lst == other.lst + + def __hash__(self) -> int: + return hash(self.lst) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.lst) @@ -302,6 +380,14 @@ class ASTPostfixArray(ASTPostfixOp): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixArray): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '[' + transform(self.expr) + ']' @@ -334,6 +420,14 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMemberOfPointer): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '->' + transform(self.name) @@ -348,6 +442,14 @@ class ASTPostfixExpr(ASTExpression): self.prefix = prefix self.postFixes = postFixes + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixExpr): + return NotImplemented + return self.prefix == other.prefix and self.postFixes == other.postFixes + + def __hash__(self) -> int: + return hash((self.prefix, self.postFixes)) + def _stringify(self, transform: StringifyTransform) -> str: return ''.join([transform(self.prefix), *(transform(p) for p in self.postFixes)]) @@ -366,6 +468,14 @@ class ASTUnaryOpExpr(ASTExpression): self.op = op self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnaryOpExpr): + return NotImplemented + return self.op == other.op and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.op, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: if self.op[0] in 'cn': return self.op + " " + transform(self.expr) @@ -386,6 +496,14 @@ class ASTSizeofType(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofType): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof(" + transform(self.typ) + ")" @@ -401,6 +519,14 @@ class ASTSizeofExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof " + transform(self.expr) @@ -415,6 +541,14 @@ class ASTAlignofExpr(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAlignofExpr): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "alignof(" + transform(self.typ) + ")" @@ -434,6 +568,17 @@ class ASTCastExpr(ASTExpression): self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCastExpr): + return NotImplemented + return ( + self.typ == other.typ + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] res.append(transform(self.typ)) @@ -456,6 +601,17 @@ class ASTBinOpExpr(ASTBase): self.exprs = exprs self.ops = ops + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBinOpExpr): + return NotImplemented + return ( + self.exprs == other.exprs + and self.ops == other.ops + ) + + def __hash__(self) -> int: + return hash((self.exprs, self.ops)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.exprs[0])) @@ -487,6 +643,17 @@ class ASTAssignmentExpr(ASTExpression): self.exprs = exprs self.ops = ops + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAssignmentExpr): + return NotImplemented + return ( + self.exprs == other.exprs + and self.ops == other.ops + ) + + def __hash__(self) -> int: + return hash((self.exprs, self.ops)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.exprs[0])) @@ -515,6 +682,14 @@ class ASTFallbackExpr(ASTExpression): def __init__(self, expr: str) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFallbackExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return self.expr @@ -539,6 +714,14 @@ class ASTTrailingTypeSpecFundamental(ASTTrailingTypeSpec): assert len(names) != 0 self.names = names + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecFundamental): + return NotImplemented + return self.names == other.names + + def __hash__(self) -> int: + return hash(self.names) + def _stringify(self, transform: StringifyTransform) -> str: return ' '.join(self.names) @@ -558,6 +741,17 @@ class ASTTrailingTypeSpecName(ASTTrailingTypeSpec): self.prefix = prefix self.nestedName = nestedName + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecName): + return NotImplemented + return ( + self.prefix == other.prefix + and self.nestedName == other.nestedName + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.nestedName)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -583,6 +777,14 @@ class ASTFunctionParameter(ASTBase): self.arg = arg self.ellipsis = ellipsis + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFunctionParameter): + return NotImplemented + return self.arg == other.arg and self.ellipsis == other.ellipsis + + def __hash__(self) -> int: + return hash((self.arg, self.ellipsis)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: # the anchor will be our parent return symbol.parent.declaration.get_id(version, prefixed=False) @@ -607,6 +809,14 @@ class ASTParameters(ASTBase): self.args = args self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParameters): + return NotImplemented + return self.args == other.args and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.args, self.attrs)) + @property def function_params(self) -> list[ASTFunctionParameter]: return self.args @@ -674,6 +884,30 @@ class ASTDeclSpecsSimple(ASTBaseBase): self.const = const self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecsSimple): + return NotImplemented + return ( + self.storage == other.storage + and self.threadLocal == other.threadLocal + and self.inline == other.inline + and self.restrict == other.restrict + and self.volatile == other.volatile + and self.const == other.const + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash(( + self.storage, + self.threadLocal, + self.inline, + self.restrict, + self.volatile, + self.const, + self.attrs, + )) + def mergeWith(self, other: ASTDeclSpecsSimple) -> ASTDeclSpecsSimple: if not other: return self @@ -741,6 +975,24 @@ class ASTDeclSpecs(ASTBase): self.allSpecs = self.leftSpecs.mergeWith(self.rightSpecs) self.trailingTypeSpec = trailing + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecs): + return NotImplemented + return ( + self.outer == other.outer + and self.leftSpecs == other.leftSpecs + and self.rightSpecs == other.rightSpecs + and self.trailingTypeSpec == other.trailingTypeSpec + ) + + def __hash__(self) -> int: + return hash(( + self.outer, + self.leftSpecs, + self.rightSpecs, + self.trailingTypeSpec, + )) + def _stringify(self, transform: StringifyTransform) -> str: res: list[str] = [] l = transform(self.leftSpecs) @@ -796,6 +1048,28 @@ class ASTArray(ASTBase): if size is not None: assert not vla + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTArray): + return NotImplemented + return ( + self.static == other.static + and self.const == other.const + and self.volatile == other.volatile + and self.restrict == other.restrict + and self.vla == other.vla + and self.size == other.size + ) + + def __hash__(self) -> int: + return hash(( + self.static, + self.const, + self.volatile, + self.restrict, + self.vla, + self.size, + )) + def _stringify(self, transform: StringifyTransform) -> str: el = [] if self.static: @@ -861,6 +1135,18 @@ class ASTDeclaratorNameParam(ASTDeclarator): self.arrayOps = arrayOps self.param = param + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameParam): + return NotImplemented + return ( + self.declId == other.declId + and self.arrayOps == other.arrayOps + and self.param == other.param + ) + + def __hash__(self) -> int: + return hash((self.declId, self.arrayOps, self.param)) + @property def name(self) -> ASTNestedName: return self.declId @@ -899,6 +1185,14 @@ class ASTDeclaratorNameBitField(ASTDeclarator): self.declId = declId self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameBitField): + return NotImplemented + return self.declId == other.declId and self.size == other.size + + def __hash__(self) -> int: + return hash((self.declId, self.size)) + @property def name(self) -> ASTNestedName: return self.declId @@ -937,6 +1231,20 @@ class ASTDeclaratorPtr(ASTDeclarator): self.const = const self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorPtr): + return NotImplemented + return ( + self.next == other.next + and self.restrict == other.restrict + and self.volatile == other.volatile + and self.const == other.const + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.next, self.restrict, self.volatile, self.const, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -1006,6 +1314,14 @@ class ASTDeclaratorParen(ASTDeclarator): self.next = next # TODO: we assume the name and params are in inner + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParen): + return NotImplemented + return self.inner == other.inner and self.next == other.next + + def __hash__(self) -> int: + return hash((self.inner, self.next)) + @property def name(self) -> ASTNestedName: return self.inner.name @@ -1040,6 +1356,14 @@ class ASTParenExprList(ASTBaseParenExprList): def __init__(self, exprs: list[ASTExpression]) -> None: self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExprList): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def _stringify(self, transform: StringifyTransform) -> str: exprs = [transform(e) for e in self.exprs] return '(%s)' % ', '.join(exprs) @@ -1064,6 +1388,14 @@ class ASTBracedInitList(ASTBase): self.exprs = exprs self.trailingComma = trailingComma + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBracedInitList): + return NotImplemented + return self.exprs == other.exprs and self.trailingComma == other.trailingComma + + def __hash__(self) -> int: + return hash((self.exprs, self.trailingComma)) + def _stringify(self, transform: StringifyTransform) -> str: exprs = ', '.join(transform(e) for e in self.exprs) trailingComma = ',' if self.trailingComma else '' @@ -1092,6 +1424,14 @@ class ASTInitializer(ASTBase): self.value = value self.hasAssign = hasAssign + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTInitializer): + return NotImplemented + return self.value == other.value and self.hasAssign == other.hasAssign + + def __hash__(self) -> int: + return hash((self.value, self.hasAssign)) + def _stringify(self, transform: StringifyTransform) -> str: val = transform(self.value) if self.hasAssign: @@ -1116,6 +1456,14 @@ class ASTType(ASTBase): self.declSpecs = declSpecs self.decl = decl + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTType): + return NotImplemented + return self.declSpecs == other.declSpecs and self.decl == other.decl + + def __hash__(self) -> int: + return hash((self.declSpecs, self.decl)) + @property def name(self) -> ASTNestedName: return self.decl.name @@ -1161,6 +1509,14 @@ class ASTTypeWithInit(ASTBase): self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -1190,6 +1546,18 @@ class ASTMacroParameter(ASTBase): self.ellipsis = ellipsis self.variadic = variadic + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTMacroParameter): + return NotImplemented + return ( + self.arg == other.arg + and self.ellipsis == other.ellipsis + and self.variadic == other.variadic + ) + + def __hash__(self) -> int: + return hash((self.arg, self.ellipsis, self.variadic)) + def _stringify(self, transform: StringifyTransform) -> str: if self.ellipsis: return '...' @@ -1215,6 +1583,14 @@ class ASTMacro(ASTBase): self.ident = ident self.args = args + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTMacro): + return NotImplemented + return self.ident == other.ident and self.args == other.args + + def __hash__(self) -> int: + return hash((self.ident, self.args)) + @property def name(self) -> ASTNestedName: return self.ident @@ -1254,6 +1630,14 @@ class ASTStruct(ASTBase): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTStruct): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1270,6 +1654,14 @@ class ASTUnion(ASTBase): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnion): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1286,6 +1678,14 @@ class ASTEnum(ASTBase): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnum): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1305,6 +1705,18 @@ class ASTEnumerator(ASTBase): self.init = init self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnumerator): + return NotImplemented + return ( + self.name == other.name + and self.init == other.init + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.init, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -1346,6 +1758,18 @@ class ASTDeclaration(ASTBaseBase): # further changes will be made to this object self._newest_id_cache: str | None = None + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaration): + return NotImplemented + return ( + self.objectType == other.objectType + and self.directiveType == other.directiveType + and self.declaration == other.declaration + and self.semicolon == other.semicolon + and self.symbol == other.symbol + and self.enumeratorScopedSymbol == other.enumeratorScopedSymbol + ) + def clone(self) -> ASTDeclaration: return ASTDeclaration(self.objectType, self.directiveType, self.declaration.clone(), self.semicolon) diff --git a/sphinx/domains/c/_symbol.py b/sphinx/domains/c/_symbol.py index 5205204c4..fd1c0d05d 100644 --- a/sphinx/domains/c/_symbol.py +++ b/sphinx/domains/c/_symbol.py @@ -114,6 +114,9 @@ class Symbol: # Do symbol addition after self._children has been initialised. self._add_function_params() + def __repr__(self) -> str: + return f'' + def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None: self._assert_invariants() assert self.declaration is None diff --git a/sphinx/domains/cpp/_ast.py b/sphinx/domains/cpp/_ast.py index ad57695d1..141d5112c 100644 --- a/sphinx/domains/cpp/_ast.py +++ b/sphinx/domains/cpp/_ast.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +import sys +import warnings +from typing import TYPE_CHECKING, Any, ClassVar, Literal from docutils import nodes @@ -44,60 +46,64 @@ class ASTBase(ASTBaseBase): ################################################################################ class ASTIdentifier(ASTBase): - def __init__(self, identifier: str) -> None: - assert identifier is not None - assert len(identifier) != 0 - self.identifier = identifier + def __init__(self, name: str) -> None: + if not isinstance(name, str) or len(name) == 0: + raise AssertionError + self.name = sys.intern(name) + self.is_anonymous = name[0] == '@' # ASTBaseBase already implements this method, # but specialising it here improves performance def __eq__(self, other: object) -> bool: - if type(other) is not ASTIdentifier: + if not isinstance(other, ASTIdentifier): return NotImplemented - return self.identifier == other.identifier + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) def _stringify(self, transform: StringifyTransform) -> str: - return transform(self.identifier) + return transform(self.name) def is_anon(self) -> bool: - return self.identifier[0] == '@' + return self.is_anonymous def get_id(self, version: int) -> str: - if self.is_anon() and version < 3: + if self.is_anonymous and version < 3: raise NoOldIdError if version == 1: - if self.identifier == 'size_t': + if self.name == 'size_t': return 's' else: - return self.identifier - if self.identifier == "std": + return self.name + if self.name == "std": return 'St' - elif self.identifier[0] == "~": + elif self.name[0] == "~": # a destructor, just use an arbitrary version of dtors return 'D0' else: - if self.is_anon(): - return 'Ut%d_%s' % (len(self.identifier) - 1, self.identifier[1:]) + if self.is_anonymous: + return 'Ut%d_%s' % (len(self.name) - 1, self.name[1:]) else: - return str(len(self.identifier)) + self.identifier + return str(len(self.name)) + self.name # and this is where we finally make a difference between __str__ and the display string def __str__(self) -> str: - return self.identifier + return self.name def get_display_string(self) -> str: - return "[anonymous]" if self.is_anon() else self.identifier + return "[anonymous]" if self.is_anonymous else self.name def describe_signature(self, signode: TextElement, mode: str, env: BuildEnvironment, prefix: str, templateArgs: str, symbol: Symbol) -> None: verify_description_mode(mode) - if self.is_anon(): + if self.is_anonymous: node = addnodes.desc_sig_name(text="[anonymous]") else: - node = addnodes.desc_sig_name(self.identifier, self.identifier) + node = addnodes.desc_sig_name(self.name, self.name) if mode == 'markType': - targetText = prefix + self.identifier + templateArgs + targetText = prefix + self.name + templateArgs pnode = addnodes.pending_xref('', refdomain='cpp', reftype='identifier', reftarget=targetText, modname=None, @@ -118,8 +124,8 @@ class ASTIdentifier(ASTBase): # the target is 'operator""id' instead of just 'id' assert len(prefix) == 0 assert len(templateArgs) == 0 - assert not self.is_anon() - targetText = 'operator""' + self.identifier + assert not self.is_anonymous + targetText = 'operator""' + self.name pnode = addnodes.pending_xref('', refdomain='cpp', reftype='identifier', reftarget=targetText, modname=None, @@ -130,6 +136,14 @@ class ASTIdentifier(ASTBase): else: raise Exception('Unknown description mode: %s' % mode) + @property + def identifier(self) -> str: + warnings.warn( + '`ASTIdentifier.identifier` is deprecated, use `ASTIdentifier.name` instead', + DeprecationWarning, stacklevel=2, + ) + return self.name + class ASTNestedNameElement(ASTBase): def __init__(self, identOrOp: ASTIdentifier | ASTOperator, @@ -137,6 +151,14 @@ class ASTNestedNameElement(ASTBase): self.identOrOp = identOrOp self.templateArgs = templateArgs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedNameElement): + return NotImplemented + return self.identOrOp == other.identOrOp and self.templateArgs == other.templateArgs + + def __hash__(self) -> int: + return hash((self.identOrOp, self.templateArgs)) + def is_operator(self) -> bool: return False @@ -169,6 +191,18 @@ class ASTNestedName(ASTBase): assert len(self.names) == len(self.templates) self.rooted = rooted + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNestedName): + return NotImplemented + return ( + self.names == other.names + and self.templates == other.templates + and self.rooted == other.rooted + ) + + def __hash__(self) -> int: + return hash((self.names, self.templates, self.rooted)) + @property def name(self) -> ASTNestedName: return self @@ -316,6 +350,12 @@ class ASTLiteral(ASTExpression): class ASTPointerLiteral(ASTLiteral): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPointerLiteral) + + def __hash__(self) -> int: + return hash('nullptr') + def _stringify(self, transform: StringifyTransform) -> str: return 'nullptr' @@ -331,6 +371,14 @@ class ASTBooleanLiteral(ASTLiteral): def __init__(self, value: bool) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBooleanLiteral): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: if self.value: return 'true' @@ -352,6 +400,14 @@ class ASTNumberLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNumberLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -368,6 +424,14 @@ class ASTStringLiteral(ASTLiteral): def __init__(self, data: str) -> None: self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTStringLiteral): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + def _stringify(self, transform: StringifyTransform) -> str: return self.data @@ -392,6 +456,17 @@ class ASTCharLiteral(ASTLiteral): else: raise UnsupportedMultiCharacterCharLiteral(decoded) + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCharLiteral): + return NotImplemented + return ( + self.prefix == other.prefix + and self.value == other.value + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.value)) + def _stringify(self, transform: StringifyTransform) -> str: if self.prefix is None: return "'" + self.data + "'" @@ -415,6 +490,14 @@ class ASTUserDefinedLiteral(ASTLiteral): self.literal = literal self.ident = ident + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUserDefinedLiteral): + return NotImplemented + return self.literal == other.literal and self.ident == other.ident + + def __hash__(self) -> int: + return hash((self.literal, self.ident)) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.literal) + transform(self.ident) @@ -431,6 +514,12 @@ class ASTUserDefinedLiteral(ASTLiteral): ################################################################################ class ASTThisLiteral(ASTExpression): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTThisLiteral) + + def __hash__(self) -> int: + return hash("this") + def _stringify(self, transform: StringifyTransform) -> str: return "this" @@ -450,6 +539,18 @@ class ASTFoldExpr(ASTExpression): self.op = op self.rightExpr = rightExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFoldExpr): + return NotImplemented + return ( + self.leftExpr == other.leftExpr + and self.op == other.op + and self.rightExpr == other.rightExpr + ) + + def __hash__(self) -> int: + return hash((self.leftExpr, self.op, self.rightExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] if self.leftExpr: @@ -508,6 +609,14 @@ class ASTParenExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '(' + transform(self.expr) + ')' @@ -526,6 +635,14 @@ class ASTIdExpression(ASTExpression): # note: this class is basically to cast a nested name as an expression self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTIdExpression): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.name) @@ -553,6 +670,14 @@ class ASTPostfixArray(ASTPostfixOp): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixArray): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return '[' + transform(self.expr) + ']' @@ -570,6 +695,14 @@ class ASTPostfixMember(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMember): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '.' + transform(self.name) @@ -586,6 +719,14 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp): def __init__(self, name: ASTNestedName) -> None: self.name = name + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixMemberOfPointer): + return NotImplemented + return self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + def _stringify(self, transform: StringifyTransform) -> str: return '->' + transform(self.name) @@ -599,6 +740,12 @@ class ASTPostfixMemberOfPointer(ASTPostfixOp): class ASTPostfixInc(ASTPostfixOp): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPostfixInc) + + def __hash__(self) -> int: + return hash('++') + def _stringify(self, transform: StringifyTransform) -> str: return '++' @@ -611,6 +758,12 @@ class ASTPostfixInc(ASTPostfixOp): class ASTPostfixDec(ASTPostfixOp): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTPostfixDec) + + def __hash__(self) -> int: + return hash('--') + def _stringify(self, transform: StringifyTransform) -> str: return '--' @@ -626,6 +779,14 @@ class ASTPostfixCallExpr(ASTPostfixOp): def __init__(self, lst: ASTParenExprList | ASTBracedInitList) -> None: self.lst = lst + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixCallExpr): + return NotImplemented + return self.lst == other.lst + + def __hash__(self) -> int: + return hash(self.lst) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.lst) @@ -647,6 +808,14 @@ class ASTPostfixExpr(ASTExpression): self.prefix = prefix self.postFixes = postFixes + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPostfixExpr): + return NotImplemented + return self.prefix == other.prefix and self.postFixes == other.postFixes + + def __hash__(self) -> int: + return hash((self.prefix, self.postFixes)) + def _stringify(self, transform: StringifyTransform) -> str: return ''.join([transform(self.prefix), *(transform(p) for p in self.postFixes)]) @@ -670,6 +839,14 @@ class ASTExplicitCast(ASTExpression): self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTExplicitCast): + return NotImplemented + return self.cast == other.cast and self.typ == other.typ and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.cast, self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [self.cast] res.append('<') @@ -700,6 +877,14 @@ class ASTTypeId(ASTExpression): self.typeOrExpr = typeOrExpr self.isType = isType + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeId): + return NotImplemented + return self.typeOrExpr == other.typeOrExpr and self.isType == other.isType + + def __hash__(self) -> int: + return hash((self.typeOrExpr, self.isType)) + def _stringify(self, transform: StringifyTransform) -> str: return 'typeid(' + transform(self.typeOrExpr) + ')' @@ -723,6 +908,14 @@ class ASTUnaryOpExpr(ASTExpression): self.op = op self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnaryOpExpr): + return NotImplemented + return self.op == other.op and self.expr == other.expr + + def __hash__(self) -> int: + return hash((self.op, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: if self.op[0] in 'cn': return self.op + " " + transform(self.expr) @@ -746,6 +939,14 @@ class ASTSizeofParamPack(ASTExpression): def __init__(self, identifier: ASTIdentifier) -> None: self.identifier = identifier + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofParamPack): + return NotImplemented + return self.identifier == other.identifier + + def __hash__(self) -> int: + return hash(self.identifier) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof...(" + transform(self.identifier) + ")" @@ -766,6 +967,14 @@ class ASTSizeofType(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofType): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof(" + transform(self.typ) + ")" @@ -784,6 +993,14 @@ class ASTSizeofExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTSizeofExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return "sizeof " + transform(self.expr) @@ -801,6 +1018,14 @@ class ASTAlignofExpr(ASTExpression): def __init__(self, typ: ASTType) -> None: self.typ = typ + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAlignofExpr): + return NotImplemented + return self.typ == other.typ + + def __hash__(self) -> int: + return hash(self.typ) + def _stringify(self, transform: StringifyTransform) -> str: return "alignof(" + transform(self.typ) + ")" @@ -819,6 +1044,14 @@ class ASTNoexceptExpr(ASTExpression): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNoexceptExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'noexcept(' + transform(self.expr) + ')' @@ -841,6 +1074,19 @@ class ASTNewExpr(ASTExpression): self.typ = typ self.initList = initList + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNewExpr): + return NotImplemented + return ( + self.rooted == other.rooted + and self.isNewTypeId == other.isNewTypeId + and self.typ == other.typ + and self.initList == other.initList + ) + + def __hash__(self) -> int: + return hash((self.rooted, self.isNewTypeId, self.typ, self.initList)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.rooted: @@ -888,6 +1134,18 @@ class ASTDeleteExpr(ASTExpression): self.array = array self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeleteExpr): + return NotImplemented + return ( + self.rooted == other.rooted + and self.array == other.array + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.rooted, self.array, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.rooted: @@ -925,6 +1183,17 @@ class ASTCastExpr(ASTExpression): self.typ = typ self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCastExpr): + return NotImplemented + return ( + self.typ == other.typ + and self.expr == other.expr + ) + + def __hash__(self) -> int: + return hash((self.typ, self.expr)) + def _stringify(self, transform: StringifyTransform) -> str: res = ['('] res.append(transform(self.typ)) @@ -950,6 +1219,17 @@ class ASTBinOpExpr(ASTExpression): self.exprs = exprs self.ops = ops + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBinOpExpr): + return NotImplemented + return ( + self.exprs == other.exprs + and self.ops == other.ops + ) + + def __hash__(self) -> int: + return hash((self.exprs, self.ops)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.exprs[0])) @@ -990,6 +1270,18 @@ class ASTConditionalExpr(ASTExpression): self.thenExpr = thenExpr self.elseExpr = elseExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTConditionalExpr): + return NotImplemented + return ( + self.ifExpr == other.ifExpr + and self.thenExpr == other.thenExpr + and self.elseExpr == other.elseExpr + ) + + def __hash__(self) -> int: + return hash((self.ifExpr, self.thenExpr, self.elseExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.ifExpr)) @@ -1027,6 +1319,14 @@ class ASTBracedInitList(ASTBase): self.exprs = exprs self.trailingComma = trailingComma + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBracedInitList): + return NotImplemented + return self.exprs == other.exprs and self.trailingComma == other.trailingComma + + def __hash__(self) -> int: + return hash((self.exprs, self.trailingComma)) + def get_id(self, version: int) -> str: return "il%sE" % ''.join(e.get_id(version) for e in self.exprs) @@ -1059,6 +1359,18 @@ class ASTAssignmentExpr(ASTExpression): self.op = op self.rightExpr = rightExpr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTAssignmentExpr): + return NotImplemented + return ( + self.leftExpr == other.leftExpr + and self.op == other.op + and self.rightExpr == other.rightExpr + ) + + def __hash__(self) -> int: + return hash((self.leftExpr, self.op, self.rightExpr)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] res.append(transform(self.leftExpr)) @@ -1093,6 +1405,14 @@ class ASTCommaExpr(ASTExpression): assert len(exprs) > 0 self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCommaExpr): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def _stringify(self, transform: StringifyTransform) -> str: return ', '.join(transform(e) for e in self.exprs) @@ -1118,6 +1438,14 @@ class ASTFallbackExpr(ASTExpression): def __init__(self, expr: str) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFallbackExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return self.expr @@ -1137,11 +1465,16 @@ class ASTFallbackExpr(ASTExpression): ################################################################################ class ASTOperator(ASTBase): + is_anonymous: ClassVar[Literal[False]] = False + def __eq__(self, other: object) -> bool: raise NotImplementedError(repr(self)) + def __hash__(self) -> int: + raise NotImplementedError(repr(self)) + def is_anon(self) -> bool: - return False + return self.is_anonymous def is_operator(self) -> bool: return True @@ -1193,6 +1526,9 @@ class ASTOperatorBuildIn(ASTOperator): return NotImplemented return self.op == other.op + def __hash__(self) -> int: + return hash(self.op) + def get_id(self, version: int) -> str: if version == 1: ids = _id_operator_v1 @@ -1228,6 +1564,9 @@ class ASTOperatorLiteral(ASTOperator): return NotImplemented return self.identifier == other.identifier + def __hash__(self) -> int: + return hash(self.identifier) + def get_id(self, version: int) -> str: if version == 1: raise NoOldIdError @@ -1252,6 +1591,9 @@ class ASTOperatorType(ASTOperator): return NotImplemented return self.type == other.type + def __hash__(self) -> int: + return hash(self.type) + def get_id(self, version: int) -> str: if version == 1: return 'castto-%s-operator' % self.type.get_id(version) @@ -1275,6 +1617,14 @@ class ASTTemplateArgConstant(ASTBase): def __init__(self, value: ASTExpression) -> None: self.value = value + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateArgConstant): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.value) @@ -1298,6 +1648,14 @@ class ASTTemplateArgs(ASTBase): self.args = args self.packExpansion = packExpansion + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateArgs): + return NotImplemented + return self.args == other.args and self.packExpansion == other.packExpansion + + def __hash__(self) -> int: + return hash((self.args, self.packExpansion)) + def get_id(self, version: int) -> str: if version == 1: res = [] @@ -1361,6 +1719,14 @@ class ASTTrailingTypeSpecFundamental(ASTTrailingTypeSpec): # the canonical name list is for ID lookup self.canonNames = canonNames + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecFundamental): + return NotImplemented + return self.names == other.names and self.canonNames == other.canonNames + + def __hash__(self) -> int: + return hash((self.names, self.canonNames)) + def _stringify(self, transform: StringifyTransform) -> str: return ' '.join(self.names) @@ -1394,6 +1760,12 @@ class ASTTrailingTypeSpecFundamental(ASTTrailingTypeSpec): class ASTTrailingTypeSpecDecltypeAuto(ASTTrailingTypeSpec): + def __eq__(self, other: object) -> bool: + return isinstance(other, ASTTrailingTypeSpecDecltypeAuto) + + def __hash__(self) -> int: + return hash('decltype(auto)') + def _stringify(self, transform: StringifyTransform) -> str: return 'decltype(auto)' @@ -1414,6 +1786,14 @@ class ASTTrailingTypeSpecDecltype(ASTTrailingTypeSpec): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecDecltype): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'decltype(' + transform(self.expr) + ')' @@ -1437,6 +1817,18 @@ class ASTTrailingTypeSpecName(ASTTrailingTypeSpec): self.nestedName = nestedName self.placeholderType = placeholderType + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTrailingTypeSpecName): + return NotImplemented + return ( + self.prefix == other.prefix + and self.nestedName == other.nestedName + and self.placeholderType == other.placeholderType + ) + + def __hash__(self) -> int: + return hash((self.prefix, self.nestedName, self.placeholderType)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -1480,6 +1872,14 @@ class ASTFunctionParameter(ASTBase): self.arg = arg self.ellipsis = ellipsis + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTFunctionParameter): + return NotImplemented + return self.arg == other.arg and self.ellipsis == other.ellipsis + + def __hash__(self) -> int: + return hash((self.arg, self.ellipsis)) + def get_id( self, version: int, objectType: str | None = None, symbol: Symbol | None = None, ) -> str: @@ -1512,6 +1912,14 @@ class ASTNoexceptSpec(ASTBase): def __init__(self, expr: ASTExpression | None) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNoexceptSpec): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: if self.expr: return 'noexcept(' + transform(self.expr) + ')' @@ -1543,6 +1951,28 @@ class ASTParametersQualifiers(ASTBase): self.attrs = attrs self.initializer = initializer + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParametersQualifiers): + return NotImplemented + return ( + self.args == other.args + and self.volatile == other.volatile + and self.const == other.const + and self.refQual == other.refQual + and self.exceptionSpec == other.exceptionSpec + and self.trailingReturn == other.trailingReturn + and self.override == other.override + and self.final == other.final + and self.attrs == other.attrs + and self.initializer == other.initializer + ) + + def __hash__(self) -> int: + return hash(( + self.args, self.volatile, self.const, self.refQual, self.exceptionSpec, + self.trailingReturn, self.override, self.final, self.attrs, self.initializer + )) + @property def function_params(self) -> list[ASTFunctionParameter]: return self.args @@ -1681,6 +2111,14 @@ class ASTExplicitSpec(ASTBase): def __init__(self, expr: ASTExpression | None) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTExplicitSpec): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: res = ['explicit'] if self.expr is not None: @@ -1717,6 +2155,40 @@ class ASTDeclSpecsSimple(ASTBase): self.friend = friend self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecsSimple): + return NotImplemented + return ( + self.storage == other.storage + and self.threadLocal == other.threadLocal + and self.inline == other.inline + and self.virtual == other.virtual + and self.explicitSpec == other.explicitSpec + and self.consteval == other.consteval + and self.constexpr == other.constexpr + and self.constinit == other.constinit + and self.volatile == other.volatile + and self.const == other.const + and self.friend == other.friend + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash(( + self.storage, + self.threadLocal, + self.inline, + self.virtual, + self.explicitSpec, + self.consteval, + self.constexpr, + self.constinit, + self.volatile, + self.const, + self.friend, + self.attrs, + )) + def mergeWith(self, other: ASTDeclSpecsSimple) -> ASTDeclSpecsSimple: if not other: return self @@ -1811,6 +2283,24 @@ class ASTDeclSpecs(ASTBase): self.allSpecs = self.leftSpecs.mergeWith(self.rightSpecs) self.trailingTypeSpec = trailing + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclSpecs): + return NotImplemented + return ( + self.outer == other.outer + and self.leftSpecs == other.leftSpecs + and self.rightSpecs == other.rightSpecs + and self.trailingTypeSpec == other.trailingTypeSpec + ) + + def __hash__(self) -> int: + return hash(( + self.outer, + self.leftSpecs, + self.rightSpecs, + self.trailingTypeSpec, + )) + def get_id(self, version: int) -> str: if version == 1: res = [] @@ -1873,6 +2363,14 @@ class ASTArray(ASTBase): def __init__(self, size: ASTExpression) -> None: self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTArray): + return NotImplemented + return self.size == other.size + + def __hash__(self) -> int: + return hash(self.size) + def _stringify(self, transform: StringifyTransform) -> str: if self.size: return '[' + transform(self.size) + ']' @@ -1953,6 +2451,18 @@ class ASTDeclaratorNameParamQual(ASTDeclarator): self.arrayOps = arrayOps self.paramQual = paramQual + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameParamQual): + return NotImplemented + return ( + self.declId == other.declId + and self.arrayOps == other.arrayOps + and self.paramQual == other.paramQual + ) + + def __hash__(self) -> int: + return hash((self.declId, self.arrayOps, self.paramQual)) + @property def name(self) -> ASTNestedName: return self.declId @@ -2037,6 +2547,14 @@ class ASTDeclaratorNameBitField(ASTDeclarator): self.declId = declId self.size = size + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorNameBitField): + return NotImplemented + return self.declId == other.declId and self.size == other.size + + def __hash__(self) -> int: + return hash((self.declId, self.size)) + @property def name(self) -> ASTNestedName: return self.declId @@ -2087,6 +2605,19 @@ class ASTDeclaratorPtr(ASTDeclarator): self.const = const self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorPtr): + return NotImplemented + return ( + self.next == other.next + and self.volatile == other.volatile + and self.const == other.const + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.next, self.volatile, self.const, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2192,6 +2723,14 @@ class ASTDeclaratorRef(ASTDeclarator): self.next = next self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorRef): + return NotImplemented + return self.next == other.next and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.next, self.attrs)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2258,6 +2797,14 @@ class ASTDeclaratorParamPack(ASTDeclarator): assert next self.next = next + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParamPack): + return NotImplemented + return self.next == other.next + + def __hash__(self) -> int: + return hash(self.next) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2326,6 +2873,19 @@ class ASTDeclaratorMemPtr(ASTDeclarator): self.volatile = volatile self.next = next + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorMemPtr): + return NotImplemented + return ( + self.className == other.className + and self.const == other.const + and self.volatile == other.volatile + and self.next == other.next + ) + + def __hash__(self) -> int: + return hash((self.className, self.const, self.volatile, self.next)) + @property def name(self) -> ASTNestedName: return self.next.name @@ -2424,6 +2984,14 @@ class ASTDeclaratorParen(ASTDeclarator): self.next = next # TODO: we assume the name, params, and qualifiers are in inner + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaratorParen): + return NotImplemented + return self.inner == other.inner and self.next == other.next + + def __hash__(self) -> int: + return hash((self.inner, self.next)) + @property def name(self) -> ASTNestedName: return self.inner.name @@ -2493,6 +3061,14 @@ class ASTPackExpansionExpr(ASTExpression): def __init__(self, expr: ASTExpression | ASTBracedInitList) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTPackExpansionExpr): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return transform(self.expr) + '...' @@ -2510,6 +3086,14 @@ class ASTParenExprList(ASTBaseParenExprList): def __init__(self, exprs: list[ASTExpression | ASTBracedInitList]) -> None: self.exprs = exprs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenExprList): + return NotImplemented + return self.exprs == other.exprs + + def __hash__(self) -> int: + return hash(self.exprs) + def get_id(self, version: int) -> str: return "pi%sE" % ''.join(e.get_id(version) for e in self.exprs) @@ -2538,6 +3122,14 @@ class ASTInitializer(ASTBase): self.value = value self.hasAssign = hasAssign + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTInitializer): + return NotImplemented + return self.value == other.value and self.hasAssign == other.hasAssign + + def __hash__(self) -> int: + return hash((self.value, self.hasAssign)) + def _stringify(self, transform: StringifyTransform) -> str: val = transform(self.value) if self.hasAssign: @@ -2562,6 +3154,14 @@ class ASTType(ASTBase): self.declSpecs = declSpecs self.decl = decl + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTType): + return NotImplemented + return self.declSpecs == other.declSpecs and self.decl == other.decl + + def __hash__(self) -> int: + return hash((self.declSpecs, self.decl)) + @property def name(self) -> ASTNestedName: return self.decl.name @@ -2671,6 +3271,14 @@ class ASTTemplateParamConstrainedTypeWithInit(ASTBase): self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamConstrainedTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -2712,6 +3320,14 @@ class ASTTypeWithInit(ASTBase): self.type = type self.init = init + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeWithInit): + return NotImplemented + return self.type == other.type and self.init == other.init + + def __hash__(self) -> int: + return hash((self.type, self.init)) + @property def name(self) -> ASTNestedName: return self.type.name @@ -2749,6 +3365,14 @@ class ASTTypeUsing(ASTBase): self.name = name self.type = type + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTypeUsing): + return NotImplemented + return self.name == other.name and self.type == other.type + + def __hash__(self) -> int: + return hash((self.name, self.type)) + def get_id(self, version: int, objectType: str | None = None, symbol: Symbol | None = None) -> str: if version == 1: @@ -2785,6 +3409,14 @@ class ASTConcept(ASTBase): self.nestedName = nestedName self.initializer = initializer + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTConcept): + return NotImplemented + return self.nestedName == other.nestedName and self.initializer == other.initializer + + def __hash__(self) -> int: + return hash((self.nestedName, self.initializer)) + @property def name(self) -> ASTNestedName: return self.nestedName @@ -2816,6 +3448,19 @@ class ASTBaseClass(ASTBase): self.virtual = virtual self.pack = pack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTBaseClass): + return NotImplemented + return ( + self.name == other.name + and self.visibility == other.visibility + and self.virtual == other.virtual + and self.pack == other.pack + ) + + def __hash__(self) -> int: + return hash((self.name, self.visibility, self.virtual, self.pack)) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.visibility is not None: @@ -2851,6 +3496,19 @@ class ASTClass(ASTBase): self.bases = bases self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTClass): + return NotImplemented + return ( + self.name == other.name + and self.final == other.final + and self.bases == other.bases + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.final, self.bases, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: return symbol.get_full_nested_name().get_id(version) @@ -2899,6 +3557,14 @@ class ASTUnion(ASTBase): self.name = name self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTUnion): + return NotImplemented + return self.name == other.name and self.attrs == other.attrs + + def __hash__(self) -> int: + return hash((self.name, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -2929,6 +3595,19 @@ class ASTEnum(ASTBase): self.underlyingType = underlyingType self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnum): + return NotImplemented + return ( + self.name == other.name + and self.scoped == other.scoped + and self.underlyingType == other.underlyingType + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.scoped, self.underlyingType, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -2971,6 +3650,18 @@ class ASTEnumerator(ASTBase): self.init = init self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTEnumerator): + return NotImplemented + return ( + self.name == other.name + and self.init == other.init + and self.attrs == other.attrs + ) + + def __hash__(self) -> int: + return hash((self.name, self.init, self.attrs)) + def get_id(self, version: int, objectType: str, symbol: Symbol) -> str: if version == 1: raise NoOldIdError @@ -3035,6 +3726,19 @@ class ASTTemplateKeyParamPackIdDefault(ASTTemplateParam): self.parameterPack = parameterPack self.default = default + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateKeyParamPackIdDefault): + return NotImplemented + return ( + self.key == other.key + and self.identifier == other.identifier + and self.parameterPack == other.parameterPack + and self.default == other.default + ) + + def __hash__(self) -> int: + return hash((self.key, self.identifier, self.parameterPack, self.default)) + def get_identifier(self) -> ASTIdentifier: return self.identifier @@ -3086,6 +3790,14 @@ class ASTTemplateParamType(ASTTemplateParam): assert data self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamType): + return NotImplemented + return self.data == other.data + + def __hash__(self) -> int: + return hash(self.data) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3125,6 +3837,17 @@ class ASTTemplateParamTemplateType(ASTTemplateParam): self.nestedParams = nestedParams self.data = data + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamTemplateType): + return NotImplemented + return ( + self.nestedParams == other.nestedParams + and self.data == other.data + ) + + def __hash__(self) -> int: + return hash((self.nestedParams, self.data)) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3166,6 +3889,14 @@ class ASTTemplateParamNonType(ASTTemplateParam): self.param = param self.parameterPack = parameterPack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParamNonType): + return NotImplemented + return ( + self.param == other.param + and self.parameterPack == other.parameterPack + ) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3221,6 +3952,14 @@ class ASTTemplateParams(ASTBase): self.params = params self.requiresClause = requiresClause + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateParams): + return NotImplemented + return self.params == other.params and self.requiresClause == other.requiresClause + + def __hash__(self) -> int: + return hash((self.params, self.requiresClause)) + def get_id(self, version: int, excludeRequires: bool = False) -> str: assert version >= 2 res = [] @@ -3295,6 +4034,17 @@ class ASTTemplateIntroductionParameter(ASTBase): self.identifier = identifier self.parameterPack = parameterPack + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateIntroductionParameter): + return NotImplemented + return ( + self.identifier == other.identifier + and self.parameterPack == other.parameterPack + ) + + def __hash__(self) -> int: + return hash((self.identifier, self.parameterPack)) + @property def name(self) -> ASTNestedName: id = self.get_identifier() @@ -3351,6 +4101,14 @@ class ASTTemplateIntroduction(ASTBase): self.concept = concept self.params = params + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateIntroduction): + return NotImplemented + return self.concept == other.concept and self.params == other.params + + def __hash__(self) -> int: + return hash((self.concept, self.params)) + def get_id(self, version: int) -> str: assert version >= 2 return ''.join([ @@ -3402,6 +4160,14 @@ class ASTTemplateDeclarationPrefix(ASTBase): # templates is None means it's an explicit instantiation of a variable self.templates = templates + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTTemplateDeclarationPrefix): + return NotImplemented + return self.templates == other.templates + + def __hash__(self) -> int: + return hash(self.templates) + def get_requires_clause_in_last(self) -> ASTRequiresClause | None: if self.templates is None: return None @@ -3436,6 +4202,14 @@ class ASTRequiresClause(ASTBase): def __init__(self, expr: ASTExpression) -> None: self.expr = expr + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTRequiresClause): + return NotImplemented + return self.expr == other.expr + + def __hash__(self) -> int: + return hash(self.expr) + def _stringify(self, transform: StringifyTransform) -> str: return 'requires ' + transform(self.expr) @@ -3472,6 +4246,21 @@ class ASTDeclaration(ASTBase): # further changes will be made to this object self._newest_id_cache: str | None = None + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTDeclaration): + return NotImplemented + return ( + self.objectType == other.objectType + and self.directiveType == other.directiveType + and self.visibility == other.visibility + and self.templatePrefix == other.templatePrefix + and self.declaration == other.declaration + and self.trailingRequiresClause == other.trailingRequiresClause + and self.semicolon == other.semicolon + and self.symbol == other.symbol + and self.enumeratorScopedSymbol == other.enumeratorScopedSymbol + ) + def clone(self) -> ASTDeclaration: templatePrefixClone = self.templatePrefix.clone() if self.templatePrefix else None trailingRequiresClasueClone = self.trailingRequiresClause.clone() \ @@ -3627,6 +4416,14 @@ class ASTNamespace(ASTBase): self.nestedName = nestedName self.templatePrefix = templatePrefix + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTNamespace): + return NotImplemented + return ( + self.nestedName == other.nestedName + and self.templatePrefix == other.templatePrefix + ) + def _stringify(self, transform: StringifyTransform) -> str: res = [] if self.templatePrefix: diff --git a/sphinx/domains/cpp/_symbol.py b/sphinx/domains/cpp/_symbol.py index 4caa43070..14c8f5fe6 100644 --- a/sphinx/domains/cpp/_symbol.py +++ b/sphinx/domains/cpp/_symbol.py @@ -155,6 +155,9 @@ class Symbol: # Do symbol addition after self._children has been initialised. self._add_template_and_function_params() + def __repr__(self) -> str: + return f'' + def _fill_empty(self, declaration: ASTDeclaration, docname: str, line: int) -> None: self._assert_invariants() assert self.declaration is None diff --git a/sphinx/util/cfamily.py b/sphinx/util/cfamily.py index c8879839e..53f38685a 100644 --- a/sphinx/util/cfamily.py +++ b/sphinx/util/cfamily.py @@ -90,17 +90,11 @@ class NoOldIdError(Exception): class ASTBaseBase: def __eq__(self, other: object) -> bool: if type(self) is not type(other): - return False + return NotImplemented try: - for key, value in self.__dict__.items(): - if value != getattr(other, key): - return False + return self.__dict__ == other.__dict__ except AttributeError: return False - return True - - # Defining __hash__ = None is not strictly needed when __eq__ is defined. - __hash__ = None # type: ignore[assignment] def clone(self) -> Any: return deepcopy(self) @@ -115,7 +109,7 @@ class ASTBaseBase: return self._stringify(lambda ast: ast.get_display_string()) def __repr__(self) -> str: - return '<%s>' % self.__class__.__name__ + return f'<{self.__class__.__name__}: {self._stringify(repr)}>' ################################################################################ @@ -131,8 +125,16 @@ class ASTCPPAttribute(ASTAttribute): def __init__(self, arg: str) -> None: self.arg = arg + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTCPPAttribute): + return NotImplemented + return self.arg == other.arg + + def __hash__(self) -> int: + return hash(self.arg) + def _stringify(self, transform: StringifyTransform) -> str: - return "[[" + self.arg + "]]" + return f"[[{self.arg}]]" def describe_signature(self, signode: TextElement) -> None: signode.append(addnodes.desc_sig_punctuation('[[', '[[')) @@ -146,35 +148,37 @@ class ASTGnuAttribute(ASTBaseBase): self.args = args def __eq__(self, other: object) -> bool: - if type(other) is not ASTGnuAttribute: + if not isinstance(other, ASTGnuAttribute): return NotImplemented return self.name == other.name and self.args == other.args + def __hash__(self) -> int: + return hash((self.name, self.args)) + def _stringify(self, transform: StringifyTransform) -> str: - res = [self.name] if self.args: - res.append(transform(self.args)) - return ''.join(res) + return self.name + transform(self.args) + return self.name class ASTGnuAttributeList(ASTAttribute): def __init__(self, attrs: list[ASTGnuAttribute]) -> None: self.attrs = attrs + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTGnuAttributeList): + return NotImplemented + return self.attrs == other.attrs + + def __hash__(self) -> int: + return hash(self.attrs) + def _stringify(self, transform: StringifyTransform) -> str: - res = ['__attribute__(('] - first = True - for attr in self.attrs: - if not first: - res.append(', ') - first = False - res.append(transform(attr)) - res.append('))') - return ''.join(res) + attrs = ', '.join(map(transform, self.attrs)) + return f'__attribute__(({attrs}))' def describe_signature(self, signode: TextElement) -> None: - txt = str(self) - signode.append(nodes.Text(txt)) + signode.append(nodes.Text(str(self))) class ASTIdAttribute(ASTAttribute): @@ -183,6 +187,14 @@ class ASTIdAttribute(ASTAttribute): def __init__(self, id: str) -> None: self.id = id + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTIdAttribute): + return NotImplemented + return self.id == other.id + + def __hash__(self) -> int: + return hash(self.id) + def _stringify(self, transform: StringifyTransform) -> str: return self.id @@ -197,12 +209,19 @@ class ASTParenAttribute(ASTAttribute): self.id = id self.arg = arg + def __eq__(self, other: object) -> bool: + if not isinstance(other, ASTParenAttribute): + return NotImplemented + return self.id == other.id and self.arg == other.arg + + def __hash__(self) -> int: + return hash((self.id, self.arg)) + def _stringify(self, transform: StringifyTransform) -> str: - return self.id + '(' + self.arg + ')' + return f'{self.id}({self.arg})' def describe_signature(self, signode: TextElement) -> None: - txt = str(self) - signode.append(nodes.Text(txt)) + signode.append(nodes.Text(str(self))) class ASTAttributeList(ASTBaseBase): @@ -210,10 +229,13 @@ class ASTAttributeList(ASTBaseBase): self.attrs = attrs def __eq__(self, other: object) -> bool: - if type(other) is not ASTAttributeList: + if not isinstance(other, ASTAttributeList): return NotImplemented return self.attrs == other.attrs + def __hash__(self) -> int: + return hash(self.attrs) + def __len__(self) -> int: return len(self.attrs) @@ -221,7 +243,7 @@ class ASTAttributeList(ASTBaseBase): return ASTAttributeList(self.attrs + other.attrs) def _stringify(self, transform: StringifyTransform) -> str: - return ' '.join(transform(attr) for attr in self.attrs) + return ' '.join(map(transform, self.attrs)) def describe_signature(self, signode: TextElement) -> None: if len(self.attrs) == 0: