[C++] Support conditional operator "?"

This commit is contained in:
Jeremy Maitin-Shepard 2022-03-17 18:13:47 -07:00 committed by Jakob Lykke Andersen
parent 128f0ccc77
commit 2d2e0ac01a
2 changed files with 63 additions and 7 deletions

View File

@ -529,7 +529,8 @@ _id_operator_v2 = {
'->': 'pt', '->': 'pt',
'()': 'cl', '()': 'cl',
'[]': 'ix', '[]': 'ix',
'.*': 'ds' # this one is not overloadable, but we need it for expressions '.*': 'ds', # this one is not overloadable, but we need it for expressions
'?': 'cn',
} }
_id_operator_unary_v2 = { _id_operator_unary_v2 = {
'++': 'pp_', '++': 'pp_',
@ -1518,6 +1519,44 @@ class ASTBinOpExpr(ASTExpression):
self.exprs[i].describe_signature(signode, mode, env, symbol) self.exprs[i].describe_signature(signode, mode, env, symbol)
class ASTConditionalExpr(ASTExpression):
def __init__(self, if_expr: ASTExpression, then_expr: ASTExpression,
else_expr: ASTExpression):
self.if_expr = if_expr
self.then_expr = then_expr
self.else_expr = else_expr
def _stringify(self, transform: StringifyTransform) -> str:
res = []
res.append(transform(self.if_expr))
res.append(' ? ')
res.append(transform(self.then_expr))
res.append(' : ')
res.append(transform(self.else_expr))
return ''.join(res)
def get_id(self, version: int) -> str:
assert version >= 2
res = []
res.append(_id_operator_v2['?'])
res.append(self.if_expr.get_id(version))
res.append(self.then_expr.get_id(version))
res.append(self.else_expr.get_id(version))
return ''.join(res)
def describe_signature(self, signode: TextElement, mode: str,
env: "BuildEnvironment", symbol: "Symbol") -> None:
self.if_expr.describe_signature(signode, mode, env, symbol)
signode += addnodes.desc_sig_space()
signode += addnodes.desc_sig_operator('?', '?')
signode += addnodes.desc_sig_space()
self.then_expr.describe_signature(signode, mode, env, symbol)
signode += addnodes.desc_sig_space()
signode += addnodes.desc_sig_operator(':', ':')
signode += addnodes.desc_sig_space()
self.else_expr.describe_signature(signode, mode, env, symbol)
class ASTBracedInitList(ASTBase): class ASTBracedInitList(ASTBase):
def __init__(self, exprs: List[Union[ASTExpression, "ASTBracedInitList"]], def __init__(self, exprs: List[Union[ASTExpression, "ASTBracedInitList"]],
trailingComma: bool) -> None: trailingComma: bool) -> None:
@ -5613,9 +5652,17 @@ class DefinitionParser(BaseParser):
return ASTBinOpExpr(exprs, ops) return ASTBinOpExpr(exprs, ops)
return _parse_bin_op_expr(self, 0, inTemplate=inTemplate) return _parse_bin_op_expr(self, 0, inTemplate=inTemplate)
def _parse_conditional_expression_tail(self, orExprHead: Any) -> None: def _parse_conditional_expression_tail(self, orExprHead: ASTExpression,
inTemplate: bool) -> Optional[ASTConditionalExpr]:
# -> "?" expression ":" assignment-expression # -> "?" expression ":" assignment-expression
if not self.skip_string("?"):
return None return None
then_expr = self._parse_expression()
self.skip_ws()
if not self.skip_string(":"):
self.fail('Expected ":" after "?"')
else_expr = self._parse_assignment_expression(inTemplate)
return ASTConditionalExpr(orExprHead, then_expr, else_expr)
def _parse_assignment_expression(self, inTemplate: bool) -> ASTExpression: def _parse_assignment_expression(self, inTemplate: bool) -> ASTExpression:
# -> conditional-expression # -> conditional-expression
@ -5631,10 +5678,15 @@ class DefinitionParser(BaseParser):
ops = [] ops = []
orExpr = self._parse_logical_or_expression(inTemplate=inTemplate) orExpr = self._parse_logical_or_expression(inTemplate=inTemplate)
exprs.append(orExpr) exprs.append(orExpr)
# TODO: handle ternary with _parse_conditional_expression_tail
while True: while True:
oneMore = False oneMore = False
self.skip_ws() self.skip_ws()
prev_expr = exprs[-1]
if isinstance(prev_expr, ASTExpression):
cond_expr = self._parse_conditional_expression_tail(prev_expr, inTemplate)
if cond_expr is not None:
exprs[-1] = cond_expr
continue
for op in _expression_assignment_ops: for op in _expression_assignment_ops:
if op[0] in 'anox': if op[0] in 'anox':
if not self.skip_word(op): if not self.skip_word(op):
@ -5649,14 +5701,17 @@ class DefinitionParser(BaseParser):
if not oneMore: if not oneMore:
break break
if len(ops) == 0: if len(ops) == 0:
return orExpr return cast(ASTExpression, exprs[-1])
else: else:
return ASTAssignmentExpr(exprs, ops) return ASTAssignmentExpr(exprs, ops)
def _parse_constant_expression(self, inTemplate: bool) -> ASTExpression: def _parse_constant_expression(self, inTemplate: bool) -> ASTExpression:
# -> conditional-expression # -> conditional-expression
orExpr = self._parse_logical_or_expression(inTemplate=inTemplate) orExpr = self._parse_logical_or_expression(inTemplate=inTemplate)
# TODO: use _parse_conditional_expression_tail self.skip_ws()
cond_expr = self._parse_conditional_expression_tail(orExpr, inTemplate)
if cond_expr is not None:
return cond_expr
return orExpr return orExpr
def _parse_expression(self) -> ASTExpression: def _parse_expression(self) -> ASTExpression:

View File

@ -326,7 +326,8 @@ def test_domain_cpp_ast_expressions():
exprCheck('5 .* 42', 'dsL5EL42E') exprCheck('5 .* 42', 'dsL5EL42E')
exprCheck('5 ->* 42', 'pmL5EL42E') exprCheck('5 ->* 42', 'pmL5EL42E')
# conditional # conditional
# TODO exprCheck('5 ? 7 : 3', 'cnL5EL7EL3E')
exprCheck('5 = 6 ? 7 = 8 : 3', 'aSL5EcnL6EaSL7EL8EL3E')
# assignment # assignment
exprCheck('a = 5', 'aS1aL5E') exprCheck('a = 5', 'aS1aL5E')
exprCheck('a *= 5', 'mL1aL5E') exprCheck('a *= 5', 'mL1aL5E')