mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2025-02-09 23:15:58 -06:00
1. Use concise character class syntax 2. Add a "break" statement or remove this "else" clause. 3. Replace this generic exception class with a more specific one. 4. Use a regular expression literal instead of the 'RegExp' constructor. 5. Use the opposite operator ("not in") instead.
143 lines
4.7 KiB
Python
143 lines
4.7 KiB
Python
from sqlparse import parse
|
|
from sqlparse.tokens import Keyword, CTE, DML
|
|
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
|
|
from collections import namedtuple
|
|
from .meta import TableMetadata, ColumnMetadata
|
|
|
|
|
|
# TableExpression describes one common table expression (CTE) found in a
# query; it is only used internally by this module.  Fields:
#   name    - the alias the query assigns to the CTE
#   columns - sequence of the column names the CTE produces
#   start   - index into the original string of the "(" that opens the body
#   stop    - index into the original string just past the ")" that closes it
TableExpression = namedtuple(
    "TableExpression", ["name", "columns", "start", "stop"]
)
|
|
|
|
|
|
def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects.

    :param full_text: the complete SQL text being edited
    :param text_before_cursor: the prefix of *full_text* up to the cursor
    :returns: ``(full_text, text_before_cursor, meta)``.  The two texts are
        narrowed to the region the cursor is in: either the body of the CTE
        being edited, or the main query after the last CTE.  *meta* is a
        tuple of TableMetadata for the CTEs defined before that region.
        If there are no CTEs, the inputs are returned unchanged with an
        empty tuple.
    """
    if not full_text or not full_text.strip():
        return full_text, text_before_cursor, tuple()

    ctes, _ = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current
            # full_text. Only the CTEs defined before it are visible here.
            text_before_cursor = full_text[cte.start:current_position]
            full_text = full_text[cte.start:cte.stop]
            # Return a tuple, consistent with every other return path.
            return full_text, text_before_cursor, tuple(meta)

        # Append this cte to the list of available table metadata.
        # Materialise the columns: a generator expression can only be
        # iterated once, so any second consumer would see no columns.
        cols = tuple(ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    last_stop = ctes[-1].stop
    full_text = full_text[last_stop:]
    text_before_cursor = text_before_cursor[last_stop:current_position]

    return full_text, text_before_cursor, tuple(meta)
|
|
|
|
|
|
def extract_ctes(sql):
    """Extract common table expressions (CTEs) from a query.

    Returns tuple (ctes, remainder_sql)

    ctes is a list of TableExpression namedtuples
    remainder_sql is the text from the original query after the CTEs have
    been stripped.

    If *sql* yields no statement, or does not start with a WITH keyword,
    ``([], sql)`` is returned.  If nothing follows the WITH keyword,
    ``([], "")`` is returned.
    """
    statements = parse(sql)
    if not statements:
        # Empty input (and, depending on the sqlparse version, whitespace-only
        # input) parses to no statements; indexing [0] would raise IndexError.
        return [], sql
    p = statements[0]

    # Make sure the first meaningful token is "WITH" which is necessary to
    # define CTEs
    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
    if not (tok and tok.ttype == CTE):
        return [], sql

    # Get the next (meaningful) token, which should be the first CTE
    idx, tok = p.token_next(idx)
    if not tok:
        return ([], "")

    # Character offset of the CTE token(s) within the original string.
    start_pos = token_start_pos(p.tokens, idx)
    ctes = []

    if isinstance(tok, IdentifierList):
        # Multiple ctes: each entry's offset is relative to the list token.
        for t in tok.get_identifiers():
            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
            cte = get_cte_from_token(t, start_pos + cte_start_offset)
            if cte:
                ctes.append(cte)
    elif isinstance(tok, Identifier):
        # A single CTE
        cte = get_cte_from_token(tok, start_pos)
        if cte:
            ctes.append(cte)

    # Collapse everything after the ctes into a remainder query
    remainder = "".join(str(t) for t in p.tokens[idx + 1:])

    return ctes, remainder
|
|
|
|
|
|
def get_cte_from_token(tok, pos0):
    """Build a TableExpression from one ``name AS (...)`` CTE token.

    :param tok: a token from the WITH clause (normally an sqlparse
        Identifier) describing a single CTE
    :param pos0: character offset of *tok* within the original SQL string
    :returns: a TableExpression whose start/stop bracket the parenthesised
        body, or None if the token has no real name or no parenthesised body
    """
    name = tok.get_real_name()
    if not name:
        return None

    # The body is the first Parenthesis inside the token; its offset is
    # relative to tok, so shift it by pos0.
    paren_idx, body = tok.token_next_by(Parenthesis)
    if not body:
        return None

    left = pos0 + token_start_pos(tok.tokens, paren_idx)
    right = left + len(str(body))  # the length counts both parens

    return TableExpression(name, extract_column_names(body), left, right)
|
|
|
|
|
|
def extract_column_names(parsed):
    """Return the names of the columns a CTE body produces.

    :param parsed: sqlparse token list for the statement or CTE body
        (e.g. the Parenthesis enclosing a CTE)
    :returns: tuple of column names.  It is empty if the body is not a
        SELECT/INSERT/UPDATE/DELETE, or if no identifier (list) follows
        the SELECT / RETURNING keyword.
    """
    # Find the first DML token to check if it's a
    # SELECT or INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ("insert", "update", "delete"):
        # Jump ahead to the RETURNING clause, which lists the column names.
        # Pass the match spec and start index by keyword: the first
        # positional parameter of token_next_by is the instance filter `i`.
        # Passing the (possibly non-zero) index there ends up in isinstance()
        # and fails for CTE bodies, where the DML token is not at index 0.
        idx, tok = parsed.token_next_by(m=(Keyword, "returning"), idx=idx)
    elif tok_val != "select":
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column
    # names.  If no RETURNING was found, idx is None and nothing is yielded.
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))
|
|
|
|
|
|
def token_start_pos(tokens, idx):
    """Return the character offset at which ``tokens[idx]`` begins.

    The offset is the combined string length of every token that precedes
    position *idx* in *tokens*.
    """
    offset = 0
    for preceding in tokens[:idx]:
        offset += len(str(preceding))
    return offset
|
|
|
|
|
|
def _identifiers(tok):
    """Yield the Identifier tokens held by *tok*.

    If *tok* is an IdentifierList, each of its elements that is an
    Identifier is yielded.  If *tok* is itself an Identifier, it is yielded
    alone.  Anything else (including None) yields nothing.
    """
    if isinstance(tok, Identifier):
        yield tok
        return
    if not isinstance(tok, IdentifierList):
        return
    # NB: IdentifierList.get_identifiers() can return non-identifiers!
    yield from (
        item for item in tok.get_identifiers() if isinstance(item, Identifier)
    )
|