[C++] Support requires-clause in more places

Previously a C++20 requires-clause was only supported on `function`
declarations.  However, the C++ standard allows a require-clause on
class/union templates, alias templates, and variable templates, and
also allows a requires clause after each template parameter list, not
just the final one.

This moves the requiresClause to be a property of `ASTTemplateParams`
rather than `ASTDeclaration` to better match the C++ grammar and
allows requires clauses in many places that are supported by C++20 but
were not previously allowed by Sphinx, namely:

- On class templates, alias templates, and variable templates

- After each template parameter list, not just the last one.

- After the template parameter list in template template parameters.

When encoding the id, the requires clause of the last template
parameter list is treated specially in order to preserve compatibility
with existing v4 ids.
This commit is contained in:
Jeremy Maitin-Shepard 2022-03-21 17:28:20 -07:00 committed by Jakob Lykke Andersen
parent 3c469c4258
commit ac1b0d490c
2 changed files with 84 additions and 33 deletions

View File

@ -3687,17 +3687,29 @@ class ASTTemplateParamNonType(ASTTemplateParam):
class ASTTemplateParams(ASTBase):
def __init__(self, params: List[ASTTemplateParam]) -> None:
def __init__(self, params: List[ASTTemplateParam],
requiresClause: Optional["ASTRequiresClause"]) -> None:
assert params is not None
self.params = params
self.requiresClause = requiresClause
def get_id(self, version: int) -> str:
def get_id(self, version: int, exclude_requires: bool = False) -> str:
# Note: For `version==4`, `exclude_requires` is set to `True` when
# encoding the id of the last template parameter list of a declaration,
# as that requires-clause, if any, is instead encoded by
# `ASTDeclaration.get_id` after encoding the template prefix, for
# consistency with the existing v4 format used when only a single
# requires-clause was supported.
assert version >= 2
res = []
res.append("I")
for param in self.params:
res.append(param.get_id(version))
res.append("E")
if not exclude_requires and self.requiresClause:
res.append('IQ')
res.append(self.requiresClause.expr.get_id(version))
res.append('E')
return ''.join(res)
def _stringify(self, transform: StringifyTransform) -> str:
@ -3705,6 +3717,9 @@ class ASTTemplateParams(ASTBase):
res.append("template<")
res.append(", ".join(transform(a) for a in self.params))
res.append("> ")
if self.requiresClause is not None:
res.append(transform(self.requiresClause))
res.append(" ")
return ''.join(res)
def describe_signature(self, signode: TextElement, mode: str,
@ -3719,6 +3734,9 @@ class ASTTemplateParams(ASTBase):
first = False
param.describe_signature(signode, mode, env, symbol)
signode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause:
signode += addnodes.desc_sig_space()
self.requiresClause.describe_signature(signode, mode, env, symbol)
def describe_signature_as_introducer(
self, parentNode: desc_signature, mode: str, env: "BuildEnvironment",
@ -3743,6 +3761,11 @@ class ASTTemplateParams(ASTBase):
if lineSpec and not first:
lineNode = makeLine(parentNode)
lineNode += addnodes.desc_sig_punctuation('>', '>')
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
parentNode += reqNode
self.requiresClause.describe_signature(reqNode, 'markType', env, symbol)
# Template introducers
@ -3865,8 +3888,12 @@ class ASTTemplateDeclarationPrefix(ASTBase):
assert version >= 2
# this is not part of a normal name mangling system
res = []
for t in self.templates:
res.append(t.get_id(version))
last_index = len(self.templates) - 1
for i, t in enumerate(self.templates):
if isinstance(t, ASTTemplateParams):
res.append(t.get_id(version, exclude_requires=(i == last_index)))
else:
res.append(t.get_id(version))
return ''.join(res)
def _stringify(self, transform: StringifyTransform) -> str:
@ -3889,7 +3916,7 @@ class ASTRequiresClause(ASTBase):
def _stringify(self, transform: StringifyTransform) -> str:
return 'requires ' + transform(self.expr)
def describe_signature(self, signode: addnodes.desc_signature_line, mode: str,
def describe_signature(self, signode: nodes.TextElement, mode: str,
env: "BuildEnvironment", symbol: "Symbol") -> None:
signode += addnodes.desc_sig_keyword('requires', 'requires')
signode += addnodes.desc_sig_space()
@ -3900,16 +3927,16 @@ class ASTRequiresClause(ASTBase):
################################################################################
class ASTDeclaration(ASTBase):
def __init__(self, objectType: str, directiveType: str, visibility: str,
templatePrefix: ASTTemplateDeclarationPrefix,
requiresClause: ASTRequiresClause, declaration: Any,
trailingRequiresClause: ASTRequiresClause,
def __init__(self, objectType: str, directiveType: Optional[str] = None,
visibility: Optional[str] = None,
templatePrefix: Optional[ASTTemplateDeclarationPrefix] = None,
declaration: Any = None,
trailingRequiresClause: Optional[ASTRequiresClause] = None,
semicolon: bool = False) -> None:
self.objectType = objectType
self.directiveType = directiveType
self.visibility = visibility
self.templatePrefix = templatePrefix
self.requiresClause = requiresClause
self.declaration = declaration
self.trailingRequiresClause = trailingRequiresClause
self.semicolon = semicolon
@ -3920,11 +3947,10 @@ class ASTDeclaration(ASTBase):
def clone(self) -> "ASTDeclaration":
templatePrefixClone = self.templatePrefix.clone() if self.templatePrefix else None
requiresClasueClone = self.requiresClause.clone() if self.requiresClause else None
trailingRequiresClasueClone = self.trailingRequiresClause.clone() \
if self.trailingRequiresClause else None
return ASTDeclaration(self.objectType, self.directiveType, self.visibility,
templatePrefixClone, requiresClasueClone,
templatePrefixClone,
self.declaration.clone(), trailingRequiresClasueClone,
self.semicolon)
@ -3932,6 +3958,18 @@ class ASTDeclaration(ASTBase):
def name(self) -> ASTNestedName:
return self.declaration.name
@property
def requiresClause(self) -> Optional[ASTRequiresClause]:
templatePrefix = self.templatePrefix
if templatePrefix is None:
return None
if not templatePrefix.templates:
return None
last_template = templatePrefix.templates[-1]
if not isinstance(last_template, ASTTemplateParams):
return None
return last_template.requiresClause
@property
def function_params(self) -> List[ASTFunctionParameter]:
if self.objectType != 'function':
@ -3940,7 +3978,7 @@ class ASTDeclaration(ASTBase):
def get_id(self, version: int, prefixed: bool = True) -> str:
if version == 1:
if self.templatePrefix:
if self.templatePrefix or self.trailingRequiresClause:
raise NoOldIdError()
if self.objectType == 'enumerator' and self.enumeratorScopedSymbol:
return self.enumeratorScopedSymbol.declaration.get_id(version)
@ -3954,14 +3992,17 @@ class ASTDeclaration(ASTBase):
res = []
if self.templatePrefix:
res.append(self.templatePrefix.get_id(version))
if self.requiresClause or self.trailingRequiresClause:
# Encode the last requires clause specially to avoid introducing a new
# id version number.
requiresClause = self.requiresClause
if requiresClause or self.trailingRequiresClause:
if version < 4:
raise NoOldIdError()
res.append('IQ')
if self.requiresClause and self.trailingRequiresClause:
if requiresClause and self.trailingRequiresClause:
res.append('aa')
if self.requiresClause:
res.append(self.requiresClause.expr.get_id(version))
if requiresClause:
res.append(requiresClause.expr.get_id(version))
if self.trailingRequiresClause:
res.append(self.trailingRequiresClause.expr.get_id(version))
res.append('E')
@ -3978,9 +4019,6 @@ class ASTDeclaration(ASTBase):
res.append(' ')
if self.templatePrefix:
res.append(transform(self.templatePrefix))
if self.requiresClause:
res.append(transform(self.requiresClause))
res.append(' ')
res.append(transform(self.declaration))
if self.trailingRequiresClause:
res.append(' ')
@ -4005,11 +4043,6 @@ class ASTDeclaration(ASTBase):
self.templatePrefix.describe_signature(signode, mode, env,
symbol=self.symbol,
lineSpec=options.get('tparam-line-spec'))
if self.requiresClause:
reqNode = addnodes.desc_signature_line()
reqNode.sphinx_line_type = 'requiresClause'
signode.append(reqNode)
self.requiresClause.describe_signature(reqNode, 'markType', env, self.symbol)
signode += mainDeclNode
if self.visibility and self.visibility != "public":
mainDeclNode += addnodes.desc_sig_keyword(self.visibility, self.visibility)
@ -4192,7 +4225,7 @@ class Symbol:
continue
# only add a declaration if we our self are from a declaration
if self.declaration:
decl = ASTDeclaration('templateParam', None, None, None, None, tp, None)
decl = ASTDeclaration(objectType='templateParam', declaration=tp)
else:
decl = None
nne = ASTNestedNameElement(tp.get_identifier(), None)
@ -4207,7 +4240,7 @@ class Symbol:
if nn is None:
continue
# (comparing to the template params: we have checked that we are a declaration)
decl = ASTDeclaration('functionParam', None, None, None, None, fp, None)
decl = ASTDeclaration(objectType='functionParam', declaration=fp)
assert not nn.rooted
assert len(nn.names) == 1
self._add_symbols(nn, [], decl, self.docname, self.line)
@ -6761,7 +6794,8 @@ class DefinitionParser(BaseParser):
err = eParam
self.skip_ws()
if self.skip_string('>'):
return ASTTemplateParams(templateParams)
requiresClause = self._parse_requires_clause()
return ASTTemplateParams(templateParams, requiresClause)
elif self.skip_string(','):
continue
else:
@ -6883,6 +6917,8 @@ class DefinitionParser(BaseParser):
return ASTTemplateDeclarationPrefix(None)
else:
raise e
if objectType == 'concept' and params.requiresClause is not None:
self.fail('requires-clause not allowed for concept')
else:
params = self._parse_template_introduction()
if not params:
@ -6931,7 +6967,7 @@ class DefinitionParser(BaseParser):
newTemplates: List[Union[ASTTemplateParams, ASTTemplateIntroduction]] = []
for _i in range(numExtra):
newTemplates.append(ASTTemplateParams([]))
newTemplates.append(ASTTemplateParams([], requiresClause=None))
if templatePrefix and not isMemberInstantiation:
newTemplates.extend(templatePrefix.templates)
templatePrefix = ASTTemplateDeclarationPrefix(newTemplates)
@ -6947,7 +6983,6 @@ class DefinitionParser(BaseParser):
raise Exception('Internal error, unknown directiveType "%s".' % directiveType)
visibility = None
templatePrefix = None
requiresClause = None
trailingRequiresClause = None
declaration: Any = None
@ -6957,8 +6992,6 @@ class DefinitionParser(BaseParser):
if objectType in ('type', 'concept', 'member', 'function', 'class'):
templatePrefix = self._parse_template_declaration_prefix(objectType)
if objectType == 'function' and templatePrefix is not None:
requiresClause = self._parse_requires_clause()
if objectType == 'type':
prevErrors = []
@ -7003,7 +7036,7 @@ class DefinitionParser(BaseParser):
self.skip_ws()
semicolon = self.skip_string(';')
return ASTDeclaration(objectType, directiveType, visibility,
templatePrefix, requiresClause, declaration,
templatePrefix, declaration,
trailingRequiresClause, semicolon)
def parse_namespace_object(self) -> ASTNamespace:

View File

@ -893,6 +893,24 @@ def test_domain_cpp_ast_requires_clauses():
{4: 'I0EIQoo1Aoo1B1CE1fvv'})
check('function', 'template<typename T> requires A && B || C and D void f()',
{4: 'I0EIQooaa1A1Baa1C1DE1fvv'})
check('function',
'template<typename T> requires R<T> ' +
'template<typename U> requires S<T> ' +
'void A<T>::f() requires B',
{4: 'I0EIQ1RI1TEEI0EIQaa1SI1TE1BEN1AI1TE1fEvv'})
check('function',
'template<template<typename T> requires R<T> typename X> ' +
'void f()',
{2: 'II0EIQ1RI1TEE0E1fv', 4: 'II0EIQ1RI1TEE0E1fvv'})
check('type',
'template<typename T> requires IsValid<T> {key}T = true_type',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='using')
check('class',
'template<typename T> requires IsValid<T> {key}T : Base',
{4: 'I0EIQ7IsValidI1TEE1T'}, key='class')
check('member',
'template<typename T> requires IsValid<T> int Val = 7',
{4: 'I0EIQ7IsValidI1TEE3Val'})
def test_domain_cpp_ast_template_args():