
549 lines
20 KiB

import re
import sys
from collections import namedtuple

import sqlparse
from sqlparse.sql import Comparison, Identifier, Where

from .parseutils.ctes import isolate_query_ctes
from .parseutils.tables import extract_tables
from .parseutils.utils import (
    find_prev_keyword,
    last_word,
    parse_partial_identifier,
)
Special = namedtuple("Special", [])
Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple("TableFormat", [])
View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple("Join", ["table_refs", "schema"])
Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
Column = namedtuple(
["table_refs", "require_last_table", "local_tables", "qualifiable",
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple("Path", [])
class SqlStatement(object):
    """The statement under the cursor, split into the parts completion needs.

    Attributes (all set by ``__init__``):
        word_before_cursor: the partially typed word returned by
            ``last_word(..., include="many_punctuations")``; may be empty.
        identifier: ``parse_partial_identifier(word_before_cursor)`` when the
            partial word was removed before parsing, otherwise ``None``.
        local_tables: third value returned by ``isolate_query_ctes``.
        text_before_cursor_including_last_word: text before the cursor after
            named-query stripping and CTE isolation, with the partial word
            still attached.
        full_text, text_before_cursor, parsed: the values returned by
            ``_split_multiple_statements``; ``parsed`` is ``None`` when there
            is nothing to parse.  ``text_before_cursor`` is later shortened by
            ``reduce_to_prev_keyword``.
        last_token: the token ``parsed.token_prev`` reports before the end of
            the statement, or ``""`` when there is none.
    """

    def __init__(self, full_text, text_before_cursor):
        self.identifier = None
        self.word_before_cursor = word_before_cursor = last_word(
            text_before_cursor, include="many_punctuations"
        )
        full_text = _strip_named_query(full_text)
        text_before_cursor = _strip_named_query(text_before_cursor)

        full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
            full_text, text_before_cursor
        )

        self.text_before_cursor_including_last_word = text_before_cursor

        # If we've partially typed a word then word_before_cursor won't be an
        # empty string. In that case we want to remove the partially typed
        # string before sending it to the sqlparser. Otherwise the last token
        # will always be the partially typed string which renders the smart
        # completion useless because it will always return the list of
        # keywords as completion.
        # Words ending in "(" or starting with a backslash are kept as typed.
        keep_word = bool(word_before_cursor) and (
            word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\"
        )
        if word_before_cursor and not keep_word:
            text_before_cursor = text_before_cursor[: -len(word_before_cursor)]
            self.identifier = parse_partial_identifier(word_before_cursor)
        parsed = sqlparse.parse(text_before_cursor)

        full_text, text_before_cursor, parsed = _split_multiple_statements(
            full_text, text_before_cursor, parsed
        )

        self.full_text = full_text
        self.text_before_cursor = text_before_cursor
        self.parsed = parsed

        # The `or ""` fallback matters: suggest_based_on_last_token treats a
        # plain (empty) string as "no token" rather than dereferencing None.
        self.last_token = (
            parsed and parsed.token_prev(len(parsed.tokens))[1] or ""
        )

    def is_insert(self):
        """Return True if the statement's first token is INSERT.

        Returns False (instead of raising AttributeError) when there is no
        parsed statement or it has no first token.
        """
        if self.parsed is None:
            return False
        first = self.parsed.token_first()
        return first is not None and first.value.lower() == "insert"

    def get_tables(self, scope="full"):
        """Gets the tables available in the statement.

        param `scope:` possible values: 'full', 'insert', 'before'
        If 'insert', only the first table is returned.
        If 'before', only tables before the cursor are returned.
        If not 'insert' and the stmt is an insert, the first table is skipped.
        """
        tables = extract_tables(
            self.full_text if scope == "full" else self.text_before_cursor
        )
        if scope == "insert":
            tables = tables[:1]
        elif self.is_insert():
            tables = tables[1:]
        return tables

    def get_previous_token(self, token):
        """Return the token that ``token_prev`` reports before ``token``."""
        return self.parsed.token_prev(self.parsed.token_index(token))[1]

    def get_identifier_schema(self):
        """Return the parent name of the partial identifier, or None.

        The name is lower-cased unless the identifier text starts with a
        double quote.
        """
        schema = (
            (self.identifier and self.identifier.get_parent_name()) or None
        )
        # If schema name is unquoted, lower-case it
        if schema and not self.identifier.value.startswith('"'):
            schema = schema.lower()
        return schema

    def reduce_to_prev_keyword(self, n_skip=0):
        """Cut ``text_before_cursor`` back using ``find_prev_keyword``.

        Stores the shortened text on the instance and returns the keyword
        that ``find_prev_keyword`` found.
        """
        prev_keyword, self.text_before_cursor = find_prev_keyword(
            self.text_before_cursor, n_skip=n_skip
        )
        return prev_keyword
def suggest_type(full_text, text_before_cursor):
    """Takes the full_text that is typed so far and also the text before the
    cursor to suggest completion type and scope.

    Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
    A scope for a column category will be a list of tables.

    An empty list is returned when building the SqlStatement raises
    TypeError or AttributeError.
    """
    if full_text.startswith("\\i "):
        return (Path(),)

    # This is a temporary hack; the exception handling
    # here should be removed once sqlparse has been fixed
    try:
        stmt = SqlStatement(full_text, text_before_cursor)
    except (TypeError, AttributeError):
        return []

    return suggest_based_on_last_token(stmt.last_token, stmt)
named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
def _strip_named_query(txt):
This will strip "save named query" command in the beginning of the line:
'\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
if named_query_regex.match(txt):
txt = named_query_regex.sub("", txt)
return txt
function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
def _find_function_body(text):
split =
return (split.start(2), split.end(2)) if split else (None, None)
def _statement_from_function(full_text, text_before_cursor, statement):
    """Narrow the texts to a dollar-quoted function body around the cursor.

    If `full_text` has no dollar-quoted body, or the cursor position
    (``len(text_before_cursor)``) is outside it, the three arguments are
    returned unchanged.  Otherwise both texts are cut down to the body, the
    part before the cursor is re-parsed, and the result of
    ``_split_multiple_statements`` on that is returned.
    """
    cursor = len(text_before_cursor)
    start, end = _find_function_body(full_text)

    cursor_in_body = start is not None and start <= cursor < end
    if not cursor_in_body:
        return full_text, text_before_cursor, statement

    body = full_text[start:end]
    body_before_cursor = text_before_cursor[start:]
    reparsed = sqlparse.parse(body_before_cursor)
    return _split_multiple_statements(body, body_before_cursor, reparsed)
def _split_multiple_statements(full_text, text_before_cursor, parsed):
if len(parsed) > 1:
# Multiple statements being edited -- isolate the current one by
# cumulatively summing statement lengths to find the one that bounds
# the current position
current_pos = len(text_before_cursor)
stmt_start, stmt_end = 0, 0
for statement in parsed:
stmt_len = len(str(statement))
stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
if stmt_end >= current_pos:
text_before_cursor = full_text[stmt_start:current_pos]
full_text = full_text[stmt_start:]
elif parsed:
# A single statement
statement = parsed[0]
# The empty string
return full_text, text_before_cursor, None
token2 = None
if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
token1 = statement.token_first()
if token1:
token1_idx = statement.token_index(token1)
token2 = statement.token_next(token1_idx)[1]
if token2 and token2.value.upper() == "FUNCTION":
full_text, text_before_cursor, statement = _statement_from_function(
full_text, text_before_cursor, statement
return full_text, text_before_cursor, statement
def suggest_based_on_last_token(token, stmt):
# Map the last significant token before the cursor to a tuple of suggestion
# namedtuples (Keyword, Column, Table, ...).
#
# `token` is either a plain string or a sqlparse token; the isinstance checks
# below treat str, Comparison, Where and Identifier specially.  `stmt` is the
# SqlStatement being completed.  Several branches rewind `stmt` to an earlier
# keyword (reduce_to_prev_keyword) and recurse with that keyword instead.
#
# NOTE(review): in this copy of the file a number of statements in this
# function stop mid-expression (for example the `return (` in the USING
# branch and `sugs = [` in the ON branch), and some `else:` / `try:` lines
# appear to be missing.  The code is left untouched here; restore the lost
# lines from version control before relying on this function.
if isinstance(token, str):
token_v = token.lower()
elif isinstance(token, Comparison):
# If 'token' is a Comparison type such as
# 'select * FROM abc a JOIN def d ON = d.', then calling
# token.value on the comparison type will only return the lhs of the
# comparison. So in this case we need to use token.tokens to get
# both sides of the comparison and pick the last token out of that
# list.
token_v = token.tokens[-1].value.lower()
elif isinstance(token, Where):
# sqlparse groups all tokens from the where clause into a single token
# list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
prev_keyword = stmt.reduce_to_prev_keyword()
return suggest_based_on_last_token(prev_keyword, stmt)
elif isinstance(token, Identifier):
# If the previous token is an identifier, we can suggest datatypes if
# we're in a parenthesized column/field list, e.g.:
# CREATE TABLE foo (Identifier <CURSOR>
# CREATE FUNCTION foo (Identifier <CURSOR>
# If we're not in a parenthesized list, the most likely scenario is the
# user is about to specify an alias, e.g.:
# SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR>
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
if prev_keyword and prev_keyword.value == "(":
# Suggest datatypes
return suggest_based_on_last_token("type", stmt)
return (Keyword(),)
token_v = token.value.lower()
# From here on the decision is made on `token_v`, the lower-cased token text.
if not token:
return (Keyword(),)
elif token_v.endswith("("):
p = sqlparse.parse(stmt.text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where):
# Four possibilities:
# 1 - Parenthesized clause like "WHERE foo AND ("
# Suggest columns/functions
# 2 - Function call like "WHERE foo("
# Suggest columns/functions
# 3 - Subquery expression like "WHERE EXISTS ("
# Suggest keywords, in order to do a subquery
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
# Suggest columns/functions AND keywords. (If we wanted to
# be really fancy, we could suggest only array-typed columns)
column_suggestions = suggest_based_on_last_token("where", stmt)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
prev_tok = where.token_prev(len(where.tokens) - 1)[1]
if isinstance(prev_tok, Comparison):
# e.g. "SELECT foo FROM bar WHERE foo = ANY("
prev_tok = prev_tok.tokens[-1]
prev_tok = prev_tok.value.lower()
if prev_tok == "exists":
return (Keyword(),)
return column_suggestions
# Get the token before the parens
prev_tok = p.token_prev(len(p.tokens) - 1)[1]
if (
prev_tok and prev_tok.value and
prev_tok.value.lower().split(" ")[-1] == "using"
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = stmt.get_tables("before")
# suggest columns that are present in more than one table
# NOTE(review): the returned tuple is cut off after the opening parenthesis.
return (
elif p.token_first().value.lower() == "select":
# If the lparen is preceded by a space chances are we're about to
# do a sub-select.
# NOTE(review): the condition below is cut off after its first argument.
if last_word(stmt.text_before_cursor,
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == "INTO":
# NOTE(review): the Column(...) call below is cut off.
return (Column(table_refs=stmt.get_tables("insert"),
# We're probably in a function argument list
return _suggest_expression(token_v, stmt)
elif token_v == "set":
# NOTE(review): the Column(...) call below is cut off.
return (Column(table_refs=stmt.get_tables(),
elif token_v in ("select", "where", "having", "order by", "distinct"):
return _suggest_expression(token_v, stmt)
elif token_v == "as":
# Don't suggest anything for aliases
return ()
elif (token_v.endswith("join") and token.is_keyword) or (
token_v in ("copy", "from", "update", "into", "describe", "truncate")
schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith("join") and token.is_keyword
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
if token_v == "from" or is_join:
# NOTE(review): the call these keyword arguments belong to is missing.
schema=schema, table_refs=tables,
elif token_v == "truncate":
suggest.extend((Table(schema), View(schema)))
# A Join suggestion is added only when _allow_join accepts the statement.
if is_join and _allow_join(stmt.parsed):
tables = stmt.get_tables("before")
suggest.append(Join(table_refs=tables, schema=schema))
return tuple(suggest)
elif token_v == "function":
schema = stmt.get_identifier_schema()
# stmt.get_previous_token will fail for e.g.
# `SELECT 1 FROM functions WHERE function:`
# NOTE(review): the `try:` matching the `except ValueError:` below is missing.
prev = stmt.get_previous_token(token).value.lower()
if prev in ("drop", "alter", "create", "create or replace"):
# Suggest functions from either the currently-selected schema
# or the public schema if no schema has been specified
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
suggest.append(Function(schema=schema, usage="signature"))
return tuple(suggest)
except ValueError:
return tuple()
elif token_v in ("table", "view"):
# E.g. 'ALTER TABLE <tablename>'
rel_type = \
{"table": Table, "view": View, "function": Function}[token_v]
schema = stmt.get_identifier_schema()
if schema:
return (rel_type(schema=schema),)
return (Schema(), rel_type(schema=schema))
elif token_v == "column":
return (Column(table_refs=stmt.get_tables()),)
elif token_v == "on":
tables = stmt.get_tables("before")
parent = \
(stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t))
# NOTE(review): the list contents and the body of the `if` below are missing.
sugs = [
if filteredtables and _allow_join_condition(stmt.parsed):
return tuple(sugs)
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = tuple(t.ref for t in tables)
if _allow_join_condition(stmt.parsed):
# NOTE(review): this tuple looks incomplete (no closing parenthesis).
return (
JoinCondition(table_refs=tables, parent=None),
return (Alias(aliases=aliases),)
elif token_v in ("c", "use", "database", "template"):
# "\c <db", "use <db>", "DROP DATABASE <db>",
return (Database(),)
elif token_v == "schema":
# DROP SCHEMA schema_name, SET SCHEMA schema name
# `quoted` is truthy only when the keyword found two skips back is SET.
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
quoted = prev_keyword and prev_keyword.value.lower() == "set"
return (Schema(quoted),)
elif token_v.endswith(",") or token_v in ("=", "and", "or"):
# Fall back to whatever the previous keyword would have suggested.
prev_keyword = stmt.reduce_to_prev_keyword()
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
return ()
elif token_v in ("type", "::"):
# SELECT foo::bar
# Note that tables are a form of composite type in postgresql, so
# they're suggested here as well
schema = stmt.get_identifier_schema()
suggestions = [Datatype(schema=schema), Table(schema=schema)]
# NOTE(review): the body of this `if` is missing.
if not schema:
return tuple(suggestions)
elif token_v in {"alter", "create", "drop"}:
return (Keyword(token_v.upper()),)
elif token.is_keyword:
# token is a keyword we haven't implemented any special handling for
# go backwards in the query until we find one we do recognize
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
return (Keyword(token_v.upper()),)
return (Keyword(),)
def _suggest_expression(token_v, stmt):
# NOTE(review): the two lines below are docstring text whose surrounding
# triple quotes are missing in this copy, and both returned tuples further
# down are cut off.  The code is left untouched; restore the lost lines from
# version control.
Return suggestions for an expression, taking account of any partially-typed
identifier's parent, which may be a table alias or schema name.
# `parent` is the partially typed identifier's parent name, or an empty list
# when there is no partial identifier.
parent = stmt.identifier.get_parent_name() if stmt.identifier else []
tables = stmt.get_tables()
if parent:
# Keep only the tables that `parent` identifies (see identifies()).
tables = tuple(t for t in tables if identifies(parent, t))
return (
Column(table_refs=tables, local_tables=stmt.local_tables),
# No parent: columns are suggested from every table in the statement.
return (
Column(table_refs=tables, local_tables=stmt.local_tables,
def identifies(id, ref):
    """Returns true if string `id` matches TableReference `ref`

    `id` matches when it equals the reference's alias, its name, or its
    schema-qualified name ("schema.name").  The result is truthy/falsy
    rather than a strict bool.
    """
    # NOTE(review): both right-hand operands were missing in this copy;
    # `ref.name` is assumed to be the table-name attribute of the
    # reference -- confirm against the TableReference definition.
    return (
        id == ref.alias
        or id == ref.name
        or (ref.schema and (id == ref.schema + "." + ref.name))
    )
def _allow_join_condition(statement):
Tests if a join condition should be suggested
We need this to avoid bad suggestions when entering e.g.
select * from tbl1 a join tbl2 b on = <cursor>
So check that the preceding token is a ON, AND, or OR keyword, instead of
e.g. an equals sign.
:param statement: an sqlparse.sql.Statement
:return: boolean
if not statement or not statement.tokens:
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower() in ("on", "and", "or")
def _allow_join(statement):
Tests if a join should be suggested
We need this to avoid bad suggestions when entering e.g.
select * from tbl1 a join tbl2 b <cursor>
So check that the preceding token is a JOIN keyword
:param statement: an sqlparse.sql.Statement
:return: boolean
if not statement or not statement.tokens:
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower().endswith("join") and \
last_tok.value.lower() not in ("cross join", "natural join",)