Merged the latest code of 'pgcli' used for the autocomplete feature. Fixes #5497

This commit is contained in:
Akshay Joshi 2020-10-01 12:53:45 +05:30
parent 3f817494f8
commit 300de05a20
13 changed files with 574 additions and 818 deletions

View File

@ -16,6 +16,7 @@ Housekeeping
************
| `Issue #5330 <https://redmine.postgresql.org/issues/5330>`_ - Improve code coverage and API test cases for Functions.
| `Issue #5497 <https://redmine.postgresql.org/issues/5497>`_ - Merged the latest code of 'pgcli' used for the autocomplete feature.
Bug fixes
*********

View File

@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
p.proretset is_set_returning,
d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
AND n.nspname IN ({{schema_names}})
ORDER BY 1, 2

View File

@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
p.proisagg is_aggregate,
p.proiswindow is_window,
p.proretset is_set_returning,
d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype
AND n.nspname IN ({{schema_names}})
ORDER BY 1, 2

View File

@ -11,8 +11,7 @@
import re
import operator
import sys
from itertools import count, repeat, chain
from itertools import count
from .completion import Completion
from collections import namedtuple, defaultdict, OrderedDict
@ -28,9 +27,9 @@ from pgadmin.utils.driver import get_driver
from config import PG_DEFAULT_DRIVER
from pgadmin.utils.preferences import Preferences
Match = namedtuple('Match', ['completion', 'priority'])
Match = namedtuple("Match", ["completion", "priority"])
_SchemaObject = namedtuple('SchemaObject', 'name schema meta')
_SchemaObject = namedtuple("SchemaObject", "name schema meta")
def SchemaObject(name, schema=None, meta=None):
@ -41,14 +40,12 @@ def SchemaObject(name, schema=None, meta=None):
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
_Candidate = namedtuple(
'Candidate', 'completion prio meta synonyms prio2 display'
)
_Candidate = namedtuple("Candidate",
"completion prio meta synonyms prio2 display")
def Candidate(
completion, prio=None, meta=None, synonyms=None, prio2=None,
display=None
completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
):
return _Candidate(
completion, prio, meta, synonyms or [completion], prio2,
@ -57,7 +54,7 @@ def Candidate(
# Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$')
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
def normalize_ref(ref):
@ -65,15 +62,15 @@ def normalize_ref(ref):
def generate_alias(tbl):
""" Generate a table alias, consisting of all upper-case letters in
"""Generate a table alias, consisting of all upper-case letters in
the table name, or, if there are no upper-case letters, the first letter +
all letters preceded by _
param tbl - unescaped name of the table to alias
"""
return ''.join(
return "".join(
[letter for letter in tbl if letter.isupper()] or
[letter for letter, prev in zip(tbl, '_' + tbl)
if prev == '_' and letter != '_']
[letter for letter, prev in zip(tbl, "_" + tbl)
if prev == "_" and letter != "_"]
)
@ -97,13 +94,14 @@ class SQLAutoComplete(object):
self.sid = kwargs['sid'] if 'sid' in kwargs else None
self.conn = kwargs['conn'] if 'conn' in kwargs else None
self.keywords = []
self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")
self.databases = []
self.functions = []
self.datatypes = []
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
'datatypes': {}}
self.dbmetadata = \
{"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
self.text_before_cursor = None
self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)
@ -182,7 +180,8 @@ class SQLAutoComplete(object):
def escape_name(self, name):
if name and (
(not self.name_pattern.match(name)) or
(name.upper() in self.reserved_words)
(name.upper() in self.reserved_words) or
(name.upper() in self.functions)
):
name = '"%s"' % name
@ -212,7 +211,7 @@ class SQLAutoComplete(object):
# schemata is a list of schema names
schemata = self.escaped_names(schemata)
metadata = self.dbmetadata['tables']
metadata = self.dbmetadata["tables"]
for schema in schemata:
metadata[schema] = {}
@ -224,7 +223,7 @@ class SQLAutoComplete(object):
self.all_completions.update(schemata)
def extend_casing(self, words):
""" extend casing data
"""extend casing data
:return:
"""
@ -274,7 +273,7 @@ class SQLAutoComplete(object):
name=colname,
datatype=datatype,
has_default=has_default,
default=default
default=default,
)
metadata[schema][relname][colname] = column
self.all_completions.add(colname)
@ -285,7 +284,7 @@ class SQLAutoComplete(object):
# dbmetadata['schema_name']['functions']['function_name'] should return
# the function metadata namedtuple for the corresponding function
metadata = self.dbmetadata['functions']
metadata = self.dbmetadata["functions"]
for f in func_data:
schema, func = self.escaped_names([f.schema_name, f.func_name])
@ -309,10 +308,10 @@ class SQLAutoComplete(object):
self._arg_list_cache = \
dict((usage,
dict((meta, self._arg_list(meta, usage))
for sch, funcs in self.dbmetadata['functions'].items()
for sch, funcs in self.dbmetadata["functions"].items()
for func, metas in funcs.items()
for meta in metas))
for usage in ('call', 'call_display', 'signature'))
for usage in ("call", "call_display", "signature"))
def extend_foreignkeys(self, fk_data):
@ -322,7 +321,7 @@ class SQLAutoComplete(object):
# These are added as a list of ForeignKey namedtuples to the
# ColumnMetadata namedtuple for both the child and parent
meta = self.dbmetadata['tables']
meta = self.dbmetadata["tables"]
for fk in fk_data:
e = self.escaped_names
@ -350,7 +349,7 @@ class SQLAutoComplete(object):
# dbmetadata['datatypes'][schema_name][type_name] should store type
# metadata, such as composite type field names. Currently, we're not
# storing any metadata beyond typename, so just store None
meta = self.dbmetadata['datatypes']
meta = self.dbmetadata["datatypes"]
for t in type_data:
schema, type_name = self.escaped_names(t)
@ -364,11 +363,11 @@ class SQLAutoComplete(object):
self.databases = []
self.special_commands = []
self.search_path = []
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
'datatypes': {}}
self.dbmetadata = \
{"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
self.all_completions = set(self.keywords + self.functions)
def find_matches(self, text, collection, mode='fuzzy', meta=None):
def find_matches(self, text, collection, mode="strict", meta=None):
"""Find completion matches for the given text.
Given the user's input text and a collection of available
@ -389,17 +388,26 @@ class SQLAutoComplete(object):
collection:
mode:
meta:
meta_collection:
"""
if not collection:
return []
prio_order = [
'keyword', 'function', 'view', 'table', 'datatype', 'database',
'schema', 'column', 'table alias', 'join', 'name join', 'fk join',
'table format'
"keyword",
"function",
"view",
"table",
"datatype",
"database",
"schema",
"column",
"table alias",
"join",
"name join",
"fk join",
"table format",
]
type_priority = prio_order.index(meta) if meta in prio_order else -1
text = last_word(text, include='most_punctuations').lower()
text = last_word(text, include="most_punctuations").lower()
text_len = len(text)
if text and text[0] == '"':
@ -409,7 +417,7 @@ class SQLAutoComplete(object):
# Completion.position value is correct
text = text[1:]
if mode == 'fuzzy':
if mode == "fuzzy":
fuzzy = True
priority_func = self.prioritizer.name_count
else:
@ -422,19 +430,20 @@ class SQLAutoComplete(object):
# Note: higher priority values mean more important, so use negative
# signs to flip the direction of the tuple
if fuzzy:
regex = '.*?'.join(map(re.escape, text))
pat = re.compile('(%s)' % regex)
regex = ".*?".join(map(re.escape, text))
pat = re.compile("(%s)" % regex)
def _match(item):
if item.lower()[:len(text) + 1] in (text, text + ' '):
if item.lower()[: len(text) + 1] in (text, text + " "):
# Exact match of first word in suggestion
# This is to get exact alias matches to the top
# E.g. for input `e`, 'Entries E' should be on top
# (before e.g. `EndUsers EU`)
return float('Infinity'), -1
return float("Infinity"), -1
r = pat.search(self.unescape_name(item.lower()))
if r:
return -len(r.group()), -r.start()
else:
match_end_limit = len(text)
@ -446,7 +455,7 @@ class SQLAutoComplete(object):
if match_point >= 0:
# Use negative infinity to force keywords to sort after all
# fuzzy matches
return -float('Infinity'), -match_point
return -float("Infinity"), -match_point
matches = []
for cand in collection:
@ -466,7 +475,7 @@ class SQLAutoComplete(object):
if sort_key:
if display_meta and len(display_meta) > 50:
# Truncate meta-text to 50 characters, if necessary
display_meta = display_meta[:47] + '...'
display_meta = display_meta[:47] + "..."
# Lexical order of items in the collection, used for
# tiebreaking items with the same match group length and start
@ -478,14 +487,18 @@ class SQLAutoComplete(object):
# We also use the unescape_name to make sure quoted names have
# the same priority as unquoted names.
lexical_priority = (
tuple(0 if c in (' _') else -ord(c)
tuple(0 if c in (" _") else -ord(c)
for c in self.unescape_name(item.lower())) + (1,) +
tuple(c for c in item)
)
priority = (
sort_key, type_priority, prio, priority_func(item),
prio2, lexical_priority
sort_key,
type_priority,
prio,
priority_func(item),
prio2,
lexical_priority,
)
matches.append(
Match(
@ -493,9 +506,9 @@ class SQLAutoComplete(object):
text=item,
start_position=-text_len,
display_meta=display_meta,
display=display
display=display,
),
priority=priority
priority=priority,
)
)
return matches
@ -516,8 +529,8 @@ class SQLAutoComplete(object):
matches.extend(matcher(self, suggestion, word_before_cursor))
# Sort matches so highest priorities are first
matches = sorted(matches, key=operator.attrgetter('priority'),
reverse=True)
matches = \
sorted(matches, key=operator.attrgetter("priority"), reverse=True)
result = dict()
for m in matches:
@ -539,23 +552,28 @@ class SQLAutoComplete(object):
tables = suggestion.table_refs
do_qualify = suggestion.qualifiable and {
'always': True, 'never': False,
'if_more_than_one_table': len(tables) > 1}[self.qualify_columns]
"always": True,
"never": False,
"if_more_than_one_table": len(tables) > 1,
}[self.qualify_columns]
def qualify(col, tbl):
return (tbl + '.' + col) if do_qualify else col
scoped_cols = self.populate_scoped_cols(
tables, suggestion.local_tables
)
scoped_cols = \
self.populate_scoped_cols(tables, suggestion.local_tables)
def make_cand(name, ref):
synonyms = (name, generate_alias(name))
return Candidate(qualify(name, ref), 0, 'column', synonyms)
return Candidate(qualify(name, ref), 0, "column", synonyms)
def flat_cols():
return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items()
for c in cols]
return [
make_cand(c.name, t.ref)
for t, cols in scoped_cols.items()
for c in cols
]
if suggestion.require_last_table:
# require_last_table is used for 'tb11 JOIN tbl2 USING
# (...' which should
@ -569,10 +587,11 @@ class SQLAutoComplete(object):
dict((t, [col for col in cols if col.name in other_tbl_cols])
for t, cols in scoped_cols.items() if t.ref == ltbl)
lastword = last_word(word_before_cursor, include='most_punctuations')
if lastword == '*':
if suggestion.context == 'insert':
def is_scoped(col):
lastword = last_word(word_before_cursor, include="most_punctuations")
if lastword == "*":
if suggestion.context == "insert":
def filter(col):
if not col.has_default:
return True
return not any(
@ -580,40 +599,39 @@ class SQLAutoComplete(object):
for p in self.insert_col_skip_patterns
)
scoped_cols = \
dict((t, [col for col in cols if is_scoped(col)])
dict((t, [col for col in cols if filter(col)])
for t, cols in scoped_cols.items())
if self.asterisk_column_order == 'alphabetic':
if self.asterisk_column_order == "alphabetic":
for cols in scoped_cols.values():
cols.sort(key=operator.attrgetter('name'))
cols.sort(key=operator.attrgetter("name"))
if (
lastword != word_before_cursor and
len(tables) == 1 and
word_before_cursor[-len(lastword) - 1] == '.'
word_before_cursor[-len(lastword) - 1] == "."
):
# User typed x.*; replicate "x." for all columns except the
# first, which gets the original (as we only replace the "*"")
sep = ', ' + word_before_cursor[:-1]
sep = ", " + word_before_cursor[:-1]
collist = sep.join(c.completion for c in flat_cols())
else:
collist = ', '.join(qualify(c.name, t.ref)
collist = ", ".join(qualify(c.name, t.ref)
for t, cs in scoped_cols.items()
for c in cs)
return [Match(
completion=Completion(
collist,
-1,
display_meta='columns',
display='*'
),
priority=(1, 1, 1)
)]
return [
Match(
completion=Completion(
collist, -1, display_meta="columns", display="*"
),
priority=(1, 1, 1),
)
]
return self.find_matches(word_before_cursor, flat_cols(),
mode='strict', meta='column')
meta="column")
def alias(self, tbl, tbls):
""" Generate a unique table alias
"""Generate a unique table alias
tbl - name of the table to alias, quoted if it needs to be
tbls - TableReference iterable of tables already in query
"""
@ -628,25 +646,6 @@ class SQLAutoComplete(object):
aliases = (tbl + str(i) for i in count(2))
return next(a for a in aliases if normalize_ref(a) not in tbls)
def _check_for_aliases(self, left, refs, rtbl, suggestion, right):
"""
Check for generate aliases and return join value
:param left:
:param refs:
:param rtbl:
:param suggestion:
:param right:
:return: return join string.
"""
if self.generate_aliases or normalize_ref(left.tbl) in refs:
lref = self.alias(left.tbl, suggestion.table_refs)
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
left.tbl, left.col, rtbl.ref, right.col, lref)
else:
join = '{0} ON {0}.{1} = {2}.{3}'.format(
left.tbl, left.col, rtbl.ref, right.col)
return join
def get_join_matches(self, suggestion, word_before_cursor):
tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls)
@ -658,10 +657,12 @@ class SQLAutoComplete(object):
joins = []
# Iterate over FKs in existing tables to find potential joins
fks = (
(fk, rtbl, rcol) for rtbl, rcols in cols.items()
for rcol in rcols for fk in rcol.foreignkeys
(fk, rtbl, rcol)
for rtbl, rcols in cols.items()
for rcol in rcols
for fk in rcol.foreignkeys
)
col = namedtuple('col', 'schema tbl col')
col = namedtuple("col", "schema tbl col")
for fk, rtbl, rcol in fks:
right = col(rtbl.schema, rtbl.name, rcol.name)
child = col(fk.childschema, fk.childtable, fk.childcolumn)
@ -670,54 +671,38 @@ class SQLAutoComplete(object):
if suggestion.schema and left.schema != suggestion.schema:
continue
join = self._check_for_aliases(left, refs, rtbl, suggestion, right)
if self.generate_aliases or normalize_ref(left.tbl) in refs:
lref = self.alias(left.tbl, suggestion.table_refs)
join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
left.tbl, left.col, rtbl.ref, right.col, lref
)
else:
join = "{0} ON {0}.{1} = {2}.{3}".format(
left.tbl, left.col, rtbl.ref, right.col
)
alias = generate_alias(left.tbl)
synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format(
alias, left.col, rtbl.ref, right.col)]
synonyms = [
join,
"{0} ON {0}.{1} = {2}.{3}".format(
alias, left.col, rtbl.ref, right.col
),
]
# Schema-qualify if (1) new table in same schema as old, and old
# is schema-qualified, or (2) new in other schema, except public
if not suggestion.schema and \
(qualified[normalize_ref(rtbl.ref)] and
left.schema == right.schema or
left.schema not in (right.schema, 'public')):
join = left.schema + '.' + join
left.schema not in (right.schema, "public")):
join = left.schema + "." + join
prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
0 if (left.schema, left.tbl) in other_tbls else 1)
joins.append(Candidate(join, prio, 'join', synonyms=synonyms))
0 if (left.schema, left.tbl) in other_tbls else 1
)
joins.append(Candidate(join, prio, "join", synonyms=synonyms))
return self.find_matches(word_before_cursor, joins,
mode='strict', meta='join')
def list_dict(self, pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
d = defaultdict(list)
for pair in pairs:
d[pair[0]].append(pair[1])
return d
def add_cond(self, lcol, rcol, rref, prio, meta, **kwargs):
"""
Add Condition in join
:param lcol:
:param rcol:
:param rref:
:param prio:
:param meta:
:param kwargs:
:return:
"""
suggestion = kwargs['suggestion']
found_conds = kwargs['found_conds']
ltbl = kwargs['ltbl']
ref_prio = kwargs['ref_prio']
conds = kwargs['conds']
prefix = '' if suggestion.parent else ltbl.ref + '.'
cond = prefix + lcol + ' = ' + rref + '.' + rcol
if cond not in found_conds:
found_conds.add(cond)
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
return self.find_matches(word_before_cursor, joins, meta="join")
def get_join_condition_matches(self, suggestion, word_before_cursor):
col = namedtuple('col', 'schema tbl col')
col = namedtuple("col", "schema tbl col")
tbls = self.populate_scoped_cols(suggestion.table_refs).items
cols = [(t, c) for t, cs in tbls() for c in cs]
try:
@ -727,11 +712,24 @@ class SQLAutoComplete(object):
return []
conds, found_conds = [], set()
def add_cond(lcol, rcol, rref, prio, meta):
prefix = "" if suggestion.parent else ltbl.ref + "."
cond = prefix + lcol + " = " + rref + "." + rcol
if cond not in found_conds:
found_conds.add(cond)
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
d = defaultdict(list)
for pair in pairs:
d[pair[0]].append(pair[1])
return d
# Tables that are closer to the cursor get higher prio
ref_prio = dict((tbl.ref, num)
for num, tbl in enumerate(suggestion.table_refs))
# Map (schema, table, col) to tables
coldict = self.list_dict(
coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
)
# For each fk from the left table, generate a join condition if
@ -742,89 +740,76 @@ class SQLAutoComplete(object):
child = col(fk.childschema, fk.childtable, fk.childcolumn)
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
left, right = (child, par) if left == child else (par, child)
for rtbl in coldict[right]:
kwargs = {
"suggestion": suggestion,
"found_conds": found_conds,
"ltbl": ltbl,
"conds": conds,
"ref_prio": ref_prio
}
self.add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join',
**kwargs)
add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
# For name matching, use a {(colname, coltype): TableReference} dict
coltyp = namedtuple('coltyp', 'name datatype')
col_table = self.list_dict(
(coltyp(c.name, c.datatype), t) for t, c in cols)
coltyp = namedtuple("coltyp", "name datatype")
col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
# Find all name-match join conditions
for c in (coltyp(c.name, c.datatype) for c in lcols):
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
kwargs = {
"suggestion": suggestion,
"found_conds": found_conds,
"ltbl": ltbl,
"conds": conds,
"ref_prio": ref_prio
}
prio = 1000 if c.datatype in (
'integer', 'bigint', 'smallint') else 0
self.add_cond(c.name, c.name, rtbl.ref, prio, 'name join',
**kwargs)
"integer", "bigint", "smallint") else 0
add_cond(c.name, c.name, rtbl.ref, prio, "name join")
return self.find_matches(word_before_cursor, conds,
mode='strict', meta='join')
return self.find_matches(word_before_cursor, conds, meta="join")
def get_function_matches(self, suggestion, word_before_cursor,
alias=False):
if suggestion.usage == 'from':
if suggestion.usage == "from":
# Only suggest functions allowed in FROM clause
def filt(f):
return not f.is_aggregate and not f.is_window
return (
not f.is_aggregate and not f.is_window and
not f.is_extension and
(f.is_public or f.schema_name == suggestion.schema)
)
else:
alias = False
def filt(f):
return True
return not f.is_extension and (
f.is_public or f.schema_name == suggestion.schema
)
arg_mode = {
'signature': 'signature',
'special': None,
}.get(suggestion.usage, 'call')
# Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only
funcs = set(
self._make_cand(f, alias, suggestion, arg_mode)
for f in self.populate_functions(suggestion.schema, filt)
arg_mode = {"signature": "signature", "special": None}.get(
suggestion.usage, "call"
)
matches = self.find_matches(word_before_cursor, funcs,
mode='strict', meta='function')
# Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt)
funcs = set(
self._make_cand(f, alias, suggestion, arg_mode)
for f in all_functions
)
matches = self.find_matches(word_before_cursor, funcs, meta="function")
return matches
def get_schema_matches(self, suggestion, word_before_cursor):
schema_names = self.dbmetadata['tables'].keys()
schema_names = self.dbmetadata["tables"].keys()
# Unless we're sure the user really wants them, hide schema names
# starting with pg_, which are mostly temporary schemas
if not word_before_cursor.startswith('pg_'):
schema_names = [s
for s in schema_names
if not s.startswith('pg_')]
if not word_before_cursor.startswith("pg_"):
schema_names = [s for s in schema_names if not s.startswith("pg_")]
if suggestion.quoted:
schema_names = [self.escape_schema(s) for s in schema_names]
return self.find_matches(word_before_cursor, schema_names,
mode='strict', meta='schema')
meta="schema")
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
alias = self.generate_aliases
s = suggestion
t_sug = Table(s.schema, s.table_refs, s.local_tables)
v_sug = View(s.schema, s.table_refs)
f_sug = Function(s.schema, s.table_refs, usage='from')
f_sug = Function(s.schema, s.table_refs, usage="from")
return (
self.get_table_matches(t_sug, word_before_cursor, alias) +
self.get_view_matches(v_sug, word_before_cursor, alias) +
@ -839,42 +824,43 @@ class SQLAutoComplete(object):
"""
template = {
'call': self.call_arg_style,
'call_display': self.call_arg_display_style,
'signature': self.signature_arg_style
"call": self.call_arg_style,
"call_display": self.call_arg_display_style,
"signature": self.signature_arg_style,
}[usage]
args = func.args()
if not template or (
usage == 'call' and (
len(args) < 2 or func.has_variadic())):
return '()'
multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max
if not template:
return "()"
elif usage == "call" and len(args) < 2:
return "()"
elif usage == "call" and func.has_variadic():
return "()"
multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
max_arg_len = max(len(a.name) for a in args) if multiline else 0
args = (
self._format_arg(template, arg, arg_num + 1, max_arg_len)
for arg_num, arg in enumerate(args)
)
if multiline:
return '(' + ','.join('\n ' + a for a in args if a) + '\n)'
return "(" + ",".join("\n " + a for a in args if a) + "\n)"
else:
return '(' + ', '.join(a for a in args if a) + ')'
return "(" + ", ".join(a for a in args if a) + ")"
def _format_arg(self, template, arg, arg_num, max_arg_len):
if not template:
return None
if arg.has_default:
arg_default = 'NULL' if arg.default is None else arg.default
arg_default = "NULL" if arg.default is None else arg.default
# Remove trailing ::(schema.)type
arg_default = arg_default_type_strip_regex.sub('', arg_default)
arg_default = arg_default_type_strip_regex.sub("", arg_default)
else:
arg_default = ''
arg_default = ""
return template.format(
max_arg_len=max_arg_len,
arg_name=arg.name,
arg_num=arg_num,
arg_type=arg.datatype,
arg_default=arg_default
arg_default=arg_default,
)
def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
@ -890,63 +876,60 @@ class SQLAutoComplete(object):
if do_alias:
alias = self.alias(cased_tbl, suggestion.table_refs)
synonyms = (cased_tbl, generate_alias(cased_tbl))
maybe_alias = (' ' + alias) if do_alias else ''
maybe_schema = (tbl.schema + '.') if tbl.schema else ''
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ''
if arg_mode == 'call':
display_suffix = self._arg_list_cache['call_display'][tbl.meta]
elif arg_mode == 'signature':
display_suffix = self._arg_list_cache['signature'][tbl.meta]
maybe_alias = (" " + alias) if do_alias else ""
maybe_schema = (tbl.schema + ".") if tbl.schema else ""
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
if arg_mode == "call":
display_suffix = self._arg_list_cache["call_display"][tbl.meta]
elif arg_mode == "signature":
display_suffix = self._arg_list_cache["signature"][tbl.meta]
else:
display_suffix = ''
display_suffix = ""
item = maybe_schema + cased_tbl + suffix + maybe_alias
display = maybe_schema + cased_tbl + display_suffix + maybe_alias
prio2 = 0 if tbl.schema else 1
return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
tables = self.populate_schema_objects(suggestion.schema, 'tables')
tables = self.populate_schema_objects(suggestion.schema, "tables")
tables.extend(
SchemaObject(tbl.name) for tbl in suggestion.local_tables)
# Unless we're sure the user really wants them, don't suggest the
# pg_catalog tables that are implicitly on the search path
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
tables = [t for t in tables if not t.name.startswith('pg_')]
if not suggestion.schema and \
(not word_before_cursor.startswith("pg_")):
tables = [t for t in tables if not t.name.startswith("pg_")]
tables = [self._make_cand(t, alias, suggestion) for t in tables]
return self.find_matches(word_before_cursor, tables,
mode='strict', meta='table')
return self.find_matches(word_before_cursor, tables, meta="table")
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
views = self.populate_schema_objects(suggestion.schema, 'views')
views = self.populate_schema_objects(suggestion.schema, "views")
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
views = [v for v in views if not v.name.startswith('pg_')]
not word_before_cursor.startswith("pg_")):
views = [v for v in views if not v.name.startswith("pg_")]
views = [self._make_cand(v, alias, suggestion) for v in views]
return self.find_matches(word_before_cursor, views,
mode='strict', meta='view')
return self.find_matches(word_before_cursor, views, meta="view")
def get_alias_matches(self, suggestion, word_before_cursor):
aliases = suggestion.aliases
return self.find_matches(word_before_cursor, aliases,
mode='strict', meta='table alias')
meta="table alias")
def get_database_matches(self, _, word_before_cursor):
return self.find_matches(word_before_cursor, self.databases,
mode='strict', meta='database')
meta="database")
def get_keyword_matches(self, suggestion, word_before_cursor):
return self.find_matches(word_before_cursor, self.keywords,
mode='strict', meta='keyword')
meta="keyword")
def get_datatype_matches(self, suggestion, word_before_cursor):
# suggest custom datatypes
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
types = self.populate_schema_objects(suggestion.schema, "datatypes")
types = [self._make_cand(t, False, suggestion) for t in types]
matches = self.find_matches(word_before_cursor, types,
mode='strict', meta='datatype')
matches = self.find_matches(word_before_cursor, types, meta="datatype")
return matches
def get_word_before_cursor(self, word=False):
@ -1004,52 +987,6 @@ class SQLAutoComplete(object):
Datatype: get_datatype_matches,
}
def addcols(self, schema, rel, alias, reltype, cols, columns):
"""
Add columns in schema column list.
:param schema: Schema for reference.
:param rel:
:param alias:
:param reltype:
:param cols:
:param columns:
:return:
"""
tbl = TableReference(schema, rel, alias, reltype == 'functions')
if tbl not in columns:
columns[tbl] = []
columns[tbl].extend(cols)
def _get_schema_columns(self, schemas, tbl, meta, columns):
"""
Check and add schema table columns as per table.
:param schemas: Schema
:param tbl:
:param meta:
:param columns: column list
:return:
"""
for schema in schemas:
relname = self.escape_name(tbl.name)
schema = self.escape_name(schema)
if tbl.is_function:
# Return column names from a set-returning function
# Get an array of FunctionMetadata objects
functions = meta['functions'].get(schema, {}).get(relname)
for func in (functions or []):
# func is a FunctionMetadata object
cols = func.fields()
self.addcols(schema, relname, tbl.alias, 'functions', cols,
columns)
else:
for reltype in ('tables', 'views'):
cols = meta[reltype].get(schema, {}).get(relname)
if cols:
cols = cols.values()
self.addcols(schema, relname, tbl.alias, reltype, cols,
columns)
break
def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
"""Find all columns in a set of scoped_tables.
@ -1062,14 +999,37 @@ class SQLAutoComplete(object):
columns = OrderedDict()
meta = self.dbmetadata
def addcols(schema, rel, alias, reltype, cols):
tbl = TableReference(schema, rel, alias, reltype == "functions")
if tbl not in columns:
columns[tbl] = []
columns[tbl].extend(cols)
for tbl in scoped_tbls:
# Local tables should shadow database tables
if tbl.schema is None and normalize_ref(tbl.name) in ctes:
cols = ctes[normalize_ref(tbl.name)]
self.addcols(None, tbl.name, 'CTE', tbl.alias, cols, columns)
addcols(None, tbl.name, "CTE", tbl.alias, cols)
continue
schemas = [tbl.schema] if tbl.schema else self.search_path
self._get_schema_columns(schemas, tbl, meta, columns)
for schema in schemas:
relname = self.escape_name(tbl.name)
schema = self.escape_name(schema)
if tbl.is_function:
# Return column names from a set-returning function
# Get an array of FunctionMetadata objects
functions = meta["functions"].get(schema, {}).get(relname)
for func in functions or []:
# func is a FunctionMetadata object
cols = func.fields()
addcols(schema, relname, tbl.alias, "functions", cols)
else:
for reltype in ("tables", "views"):
cols = meta[reltype].get(schema, {}).get(relname)
if cols:
cols = cols.values()
addcols(schema, relname, tbl.alias, reltype, cols)
break
return columns
@ -1125,10 +1085,10 @@ class SQLAutoComplete(object):
SchemaObject(
name=func,
schema=(self._maybe_schema(schema=sch, parent=schema)),
meta=meta
meta=meta,
)
for sch in self._get_schemas('functions', schema)
for (func, metas) in self.dbmetadata['functions'][sch].items()
for sch in self._get_schemas("functions", schema)
for (func, metas) in self.dbmetadata["functions"][sch].items()
for meta in metas
if filter_func(meta)
]
@ -1234,6 +1194,7 @@ class SQLAutoComplete(object):
row['is_aggregate'],
row['is_window'],
row['is_set_returning'],
row['is_extension'],
row['arg_defaults'].strip('{}').split(',')
if row['arg_defaults'] is not None
else row['arg_defaults']

View File

@ -1,7 +1,6 @@
"""
Using Completion class from
https://github.com/jonathanslenders/python-prompt-toolkit/
blob/master/prompt_toolkit/completion.py
https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/completion/base.py
"""
__all__ = (
@ -38,7 +37,7 @@ class Completion(object):
assert self.start_position <= 0
def __repr__(self):
return '%s(text=%r, start_position=%r)' % (
return "%s(text=%r, start_position=%r)" % (
self.__class__.__name__, self.text, self.start_position)
def __eq__(self, other):

View File

@ -1,295 +0,0 @@
import re
from collections import namedtuple
import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error
# Regexes used by last_word() to pick the trailing "word" off a string,
# with varying notions of which punctuation characters break a word.
cleanup_regex = {
    # This matches only alphanumerics and underscores.
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # This matches everything except spaces, parens, colon, and comma
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # This matches everything except spaces, parens, colon, comma, and period
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # This matches everything except a space.
    # NOTE: must be a raw string -- '([^\s]+)$' without the r-prefix
    # triggers an invalid-escape-sequence warning on modern Pythons.
    'all_punctuations': re.compile(r'([^\s]+)$'),
}


def last_word(text, include='alphanum_underscore'):
    r"""Find the last word in a sentence.

    :param text: the text to inspect; only the trailing characters matter
    :param include: key into ``cleanup_regex`` selecting which punctuation
        characters may be part of a "word"
    :return: the trailing word, or ``''`` when *text* is empty or ends
        with whitespace

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word('bac \def', include='most_punctuations')
    '\\def'
    >>> last_word('bac \def;', include='most_punctuations')
    '\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    >>> last_word('"foo*bar', include='most_punctuations')
    '"foo*bar'
    """
    if not text:  # Empty string
        return ''

    if text[-1].isspace():
        return ''
    else:
        regex = cleanup_regex[include]
        matches = regex.search(text)
        if matches:
            return matches.group(0)
        else:
            return ''
# A single table-like source referenced in a query's FROM clause.
# `schema` and `alias` may be None; `is_function` marks references that
# are function calls used as table sources (see extract_table_identifiers).
TableReference = namedtuple(
    'TableReference', ['schema', 'name', 'alias', 'is_function']
)
# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
    """Return True if *parsed* is a token group containing a DML keyword,
    i.e. it looks like a (sub)query rather than a plain identifier."""
    # sqlparse >= 0.2 exposes ``is_group`` as a plain attribute; calling
    # it as a method raises TypeError ("'bool' object is not callable").
    if not parsed.is_group:
        return False
    sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
    for item in parsed.tokens:
        if item.ttype is DML and item.value.upper() in sql_type:
            return True
    return False
def _identifier_is_function(identifier):
return any(isinstance(t, Function) for t in identifier.tokens)
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens making up the table-reference part of *parsed*.

    :param parsed: a sqlparse statement
    :param stop_at_punctuation: stop scanning at the first punctuation
        token. Needed for INSERT statements, where the parenthesized
        column list would otherwise be mistaken for table names.
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                # PEP 479: raising StopIteration inside a generator is a
                # RuntimeError on Python 3.7+; a plain return ends the
                # generator cleanly instead.
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This
            # causes the second FROM to trigger this elif condition,
            # ending the scan early. So we need to ignore the keyword
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its
            # variants INNER JOIN, FULL OUTER JOIN, etc.
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and (
                    not item.value.upper().endswith('JOIN')):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
                    item_val.endswith('JOIN')):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
def extract_table_identifiers(token_stream, allow_functions=True):
    """Yield a TableReference namedtuple for each table-like token in
    *token_stream* (as produced by extract_from_part)."""
    for tok in token_stream:
        if isinstance(tok, IdentifierList):
            for ident in tok.get_identifiers():
                # Sometimes Keywords (such as FROM ) are classified as
                # identifiers which don't have the get_real_name() method.
                try:
                    schema = ident.get_parent_name()
                    name = ident.get_real_name()
                    is_func = (allow_functions and
                               _identifier_is_function(ident))
                except AttributeError:
                    continue
                if name:
                    yield TableReference(schema, name, ident.get_alias(),
                                         is_func)
        elif isinstance(tok, Identifier):
            name = tok.get_real_name()
            schema = tok.get_parent_name()
            is_func = allow_functions and _identifier_is_function(tok)
            if name:
                yield TableReference(schema, name, tok.get_alias(), is_func)
            else:
                # No parseable real name; fall back to the bare token name.
                fallback = tok.get_name()
                yield TableReference(None, fallback,
                                     tok.get_alias() or fallback, is_func)
        elif isinstance(tok, Function):
            yield TableReference(None, tok.get_real_name(), tok.get_alias(),
                                 allow_functions)
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    :param sql: the SQL text to parse
    :return: a tuple of TableReference namedtuples
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return ()

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    is_insert = statements[0].token_first().value.lower() == 'insert'
    token_stream = extract_from_part(statements[0],
                                     stop_at_punctuation=is_insert)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    return tuple(
        extract_table_identifiers(token_stream, allow_functions=not is_insert)
    )
def find_prev_keyword(sql):
    """Find the last SQL keyword in an SQL statement.

    Returns the last keyword token (or None), and the text of the query
    with everything after that keyword stripped.
    """
    if not sql.strip():
        return None, ''

    flattened = list(sqlparse.parse(sql)[0].flatten())
    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')

    # Walk the flattened token list backwards looking for an open paren
    # or a non-logical-operator keyword.
    for idx in range(len(flattened) - 1, -1, -1):
        tok = flattened[idx]
        is_plain_keyword = (tok.is_keyword and
                            tok.value.upper() not in logical_operators)
        if tok.value == '(' or is_plain_keyword:
            # We can't use parsed.token_index(tok) to find the position,
            # because tok may be a child token inside a TokenList, in
            # which case token_index throws an error. Minimal example:
            #   p = sqlparse.parse('select * from foo where bar')
            #   t = list(p.flatten())[-3]  # The "Where" token
            #   p.token_index(t)  # Throws ValueError: not in list
            # Combine the string values of all tokens up to and including
            # the target keyword token, to produce a query string with
            # everything after the keyword removed.
            text = ''.join(t.value for t in flattened[:idx + 1])
            return tok, text
    return None, ''
# Postgresql dollar quote signs look like `$$` or `$tag$`
# (full-token match: `$`, an optional tag containing no `$`, then `$`).
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
def is_open_quote(sql):
    """Return True if the query contains an unclosed quote."""
    # sql can contain one or more semicolon-separated commands; the text
    # is considered "open" if any one of them has an unclosed quote.
    for statement in sqlparse.parse(sql):
        if _parsed_is_open_quote(statement):
            return True
    return False
def _parsed_is_open_quote(parsed):
    """Return True if the parsed statement contains an unclosed single
    quote or an unterminated dollar-quoted string."""
    tokens = list(parsed.flatten())
    # Manual index so we can jump past a matched dollar-quote pair.
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.match(Token.Error, "'"):
            # An unmatched single quote
            return True
        elif (tok.ttype in Token.Name.Builtin and
                dollar_quote_regex.match(tok.value)):
            # Find the matching closing dollar quote sign
            for (j, tok2) in enumerate(tokens[i + 1:], i + 1):
                if tok2.match(Token.Name.Builtin, tok.value):
                    # Found the matching closing quote - continue our scan
                    # for open quotes thereafter. (The i += 1 below then
                    # advances past the closing delimiter itself.)
                    i = j
                    break
            else:
                # No matching dollar sign quote
                return True
        i += 1
    return False
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier.

    word may include a schema qualification, like `schema_name.partial_name`
    or `schema_name.` There may also be unclosed quotation marks, like
    `"schema`, or `schema."partial_name`

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """
    parsed = sqlparse.parse(word)[0]
    tokens = parsed.tokens
    if len(tokens) == 1 and isinstance(tokens[0], Identifier):
        return tokens[0]
    if hasattr(parsed, 'token_next_match') and \
            parsed.token_next_match(0, Error, '"'):
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
        # Close the double quote, then reparse
        return parse_partial_identifier(word + '"')
    return None
if __name__ == '__main__':
    # Ad-hoc smoke test: print the table references parsed out of a
    # deliberately incomplete query.
    sql = 'select * from (select t. from tabl t'
    print(extract_tables(sql))

View File

@ -4,7 +4,7 @@ import sqlparse
def query_starts_with(query, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
@ -18,5 +18,5 @@ def queries_start_with(queries, prefixes):
def is_destructive(queries):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
return queries_start_with(queries, keywords)

View File

@ -10,12 +10,11 @@ from .meta import TableMetadata, ColumnMetadata
# columns: list of column names
# start: index into the original string of the left parens starting the CTE
# stop: index into the original string of the right parens ending the CTE
TableExpression = namedtuple('TableExpression', 'name columns start stop')
TableExpression = namedtuple("TableExpression", "name columns start stop")
def isolate_query_ctes(full_text, text_before_cursor):
"""Simplify a query by converting CTEs into table metadata objects
"""
"""Simplify a query by converting CTEs into table metadata objects"""
if not full_text:
return full_text, text_before_cursor, tuple()
@ -30,8 +29,8 @@ def isolate_query_ctes(full_text, text_before_cursor):
for cte in ctes:
if cte.start < current_position < cte.stop:
# Currently editing a cte - treat its body as the current full_text
text_before_cursor = full_text[cte.start:current_position]
full_text = full_text[cte.start:cte.stop]
text_before_cursor = full_text[cte.start: current_position]
full_text = full_text[cte.start: cte.stop]
return full_text, text_before_cursor, meta
# Append this cte to the list of available table metadata
@ -40,19 +39,19 @@ def isolate_query_ctes(full_text, text_before_cursor):
# Editing past the last cte (ie the main body of the query)
full_text = full_text[ctes[-1].stop:]
text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
text_before_cursor = text_before_cursor[ctes[-1].stop: current_position]
return full_text, text_before_cursor, tuple(meta)
def extract_ctes(sql):
""" Extract constant table expressions from a query
"""Extract constant table expressions from a query
Returns tuple (ctes, remainder_sql)
Returns tuple (ctes, remainder_sql)
ctes is a list of TableExpression namedtuples
remainder_sql is the text from the original query after the CTEs have
been stripped.
ctes is a list of TableExpression namedtuples
remainder_sql is the text from the original query after the CTEs have
been stripped.
"""
p = parse(sql)[0]
@ -66,7 +65,7 @@ def extract_ctes(sql):
# Get the next (meaningful) token, which should be the first CTE
idx, tok = p.token_next(idx)
if not tok:
return ([], '')
return ([], "")
start_pos = token_start_pos(p.tokens, idx)
ctes = []
@ -87,7 +86,7 @@ def extract_ctes(sql):
idx = p.token_index(tok) + 1
# Collapse everything after the ctes into a remainder query
remainder = ''.join(str(tok) for tok in p.tokens[idx:])
remainder = "".join(str(tok) for tok in p.tokens[idx:])
return ctes, remainder
@ -112,15 +111,15 @@ def get_cte_from_token(tok, pos0):
def extract_column_names(parsed):
# Find the first DML token to check if it's a SELECT or
# INSERT/UPDATE/DELETE
# Find the first DML token to check if it's a
# SELECT or INSERT/UPDATE/DELETE
idx, tok = parsed.token_next_by(t=DML)
tok_val = tok and tok.value.lower()
if tok_val in ('insert', 'update', 'delete'):
if tok_val in ("insert", "update", "delete"):
# Jump ahead to the RETURNING clause where the list of column names is
idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
elif not tok_val == 'select':
idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
elif not tok_val == "select":
# Must be invalid CTE
return ()

View File

@ -1,23 +1,29 @@
from collections import namedtuple
_ColumnMetadata = namedtuple(
'ColumnMetadata',
['name', 'datatype', 'foreignkeys', 'default', 'has_default']
"ColumnMetadata", ["name", "datatype", "foreignkeys", "default",
"has_default"]
)
def ColumnMetadata(
name, datatype, foreignkeys=None, default=None, has_default=False
):
return _ColumnMetadata(
name, datatype, foreignkeys or [], default, has_default
)
def ColumnMetadata(name, datatype, foreignkeys=None, default=None,
has_default=False):
return _ColumnMetadata(name, datatype, foreignkeys or [], default,
has_default)
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
'parentcolumn', 'childschema',
'childtable', 'childcolumn'])
TableMetadata = namedtuple('TableMetadata', 'name columns')
ForeignKey = namedtuple(
"ForeignKey",
[
"parentschema",
"parenttable",
"parentcolumn",
"childschema",
"childtable",
"childcolumn",
],
)
TableMetadata = namedtuple("TableMetadata", "name columns")
def parse_defaults(defaults_string):
@ -25,34 +31,42 @@ def parse_defaults(defaults_string):
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
if not defaults_string:
return
current = ''
current = ""
in_quote = None
for char in defaults_string:
if current == '' and char == ' ':
if current == "" and char == " ":
# Skip space after comma separating default expressions
continue
if char == '"' or char == '\'':
if char == '"' or char == "'":
if in_quote and char == in_quote:
# End quote
in_quote = None
elif not in_quote:
# Begin quote
in_quote = char
elif char == ',' and not in_quote:
elif char == "," and not in_quote:
# End of expression
yield current
current = ''
current = ""
continue
current += char
yield current
class FunctionMetadata(object):
def __init__(
self, schema_name, func_name, arg_names, arg_types, arg_modes,
return_type, is_aggregate, is_window, is_set_returning,
arg_defaults
self,
schema_name,
func_name,
arg_names,
arg_types,
arg_modes,
return_type,
is_aggregate,
is_window,
is_set_returning,
is_extension,
arg_defaults,
):
"""Class for describing a postgresql function"""
@ -80,19 +94,29 @@ class FunctionMetadata(object):
self.is_aggregate = is_aggregate
self.is_window = is_window
self.is_set_returning = is_set_returning
self.is_extension = bool(is_extension)
self.is_public = self.schema_name and self.schema_name == "public"
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
return isinstance(other, self.__class__) and \
self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
def _signature(self):
return (
self.schema_name, self.func_name, self.arg_names, self.arg_types,
self.arg_modes, self.return_type, self.is_aggregate,
self.is_window, self.is_set_returning, self.arg_defaults
self.schema_name,
self.func_name,
self.arg_names,
self.arg_types,
self.arg_modes,
self.return_type,
self.is_aggregate,
self.is_window,
self.is_set_returning,
self.is_extension,
self.arg_defaults,
)
def __hash__(self):
@ -100,26 +124,25 @@ class FunctionMetadata(object):
def __repr__(self):
return (
(
'%s(schema_name=%r, func_name=%r, arg_names=%r, '
'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
'is_window=%r, is_set_returning=%r, arg_defaults=%r)'
) % (self.__class__.__name__,) + self._signature()
)
"%s(schema_name=%r, func_name=%r, arg_names=%r, "
"arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
"is_window=%r, is_set_returning=%r, is_extension=%r, "
"arg_defaults=%r)"
) % ((self.__class__.__name__,) + self._signature())
def has_variadic(self):
return self.arg_modes and any(
arg_mode == 'v' for arg_mode in self.arg_modes)
return self.arg_modes and \
any(arg_mode == "v" for arg_mode in self.arg_modes)
def args(self):
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
if not self.arg_names:
return []
modes = self.arg_modes or ['i'] * len(self.arg_names)
modes = self.arg_modes or ["i"] * len(self.arg_names)
args = [
(name, typ)
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC
if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
]
def arg(name, typ, num):
@ -127,7 +150,8 @@ class FunctionMetadata(object):
num_defaults = len(self.arg_defaults)
has_default = num + num_defaults >= num_args
default = (
self.arg_defaults[num - num_args + num_defaults] if has_default
self.arg_defaults[num - num_args + num_defaults]
if has_default
else None
)
return ColumnMetadata(name, typ, [], default, has_default)
@ -137,7 +161,7 @@ class FunctionMetadata(object):
def fields(self):
"""Returns a list of output-field ColumnMetadata namedtuples"""
if self.return_type.lower() == 'void':
if self.return_type.lower() == "void":
return []
elif not self.arg_modes:
# For functions without output parameters, the function name
@ -145,7 +169,9 @@ class FunctionMetadata(object):
# E.g. 'SELECT unnest FROM unnest(...);'
return [ColumnMetadata(self.func_name, self.return_type, [])]
return [ColumnMetadata(name, typ, [])
for name, typ, mode in zip(
self.arg_names, self.arg_types, self.arg_modes)
if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE
return [
ColumnMetadata(name, typ, [])
for name, typ, mode in zip(self.arg_names, self.arg_types,
self.arg_modes)
if mode in ("o", "b", "t")
] # OUT, INOUT, TABLE

View File

@ -3,12 +3,15 @@ from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
TableReference = namedtuple(
"TableReference", ["schema", "name", "alias", "is_function"]
)
TableReference.ref = property(
lambda self: self.alias or (
self.name if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"')
self.name
if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"'
)
)
@ -18,9 +21,13 @@ def is_subselect(parsed):
if not parsed.is_group:
return False
for item in parsed.tokens:
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
'UPDATE', 'CREATE',
'DELETE'):
if item.ttype is DML and item.value.upper() in (
"SELECT",
"INSERT",
"UPDATE",
"CREATE",
"DELETE",
):
return True
return False
@ -37,32 +44,42 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for x in extract_from_part(item, stop_at_punctuation):
yield x
elif stop_at_punctuation and item.ttype is Punctuation:
raise StopIteration
return
# An incomplete nested select won't be recognized correctly as a
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
# the second FROM to trigger this elif condition resulting in a
# StopIteration. So we need to ignore the keyword if the keyword
# `return`. So we need to ignore the keyword if the keyword
# FROM.
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
elif item.ttype is Keyword and (
not item.value.upper() == 'FROM') and (
not item.value.upper().endswith('JOIN')):
elif (
item.ttype is Keyword and
(not item.value.upper() == "FROM") and
(not item.value.upper().endswith("JOIN"))
):
tbl_prefix_seen = False
else:
yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper()
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
item_val.endswith('JOIN')):
if (
item_val
in (
"COPY",
"FROM",
"INTO",
"UPDATE",
"TABLE",
) or item_val.endswith("JOIN")
):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
elif isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
if (identifier.ttype is Keyword and
identifier.value.upper() == 'FROM'):
if identifier.ttype is Keyword and \
identifier.value.upper() == "FROM":
tbl_prefix_seen = True
break
@ -94,29 +111,35 @@ def extract_table_identifiers(token_stream, allow_functions=True):
name = name.lower()
return schema_name, name, alias
for item in token_stream:
if isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
# Sometimes Keywords (such as FROM ) are classified as
# identifiers which don't have the get_real_name() method.
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
is_function = (allow_functions and
_identifier_is_function(identifier))
except AttributeError:
continue
if real_name:
yield TableReference(schema_name, real_name,
identifier.get_alias(), is_function)
elif isinstance(item, Identifier):
schema_name, real_name, alias = parse_identifier(item)
is_function = allow_functions and _identifier_is_function(item)
try:
for item in token_stream:
if isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
# Sometimes Keywords (such as FROM ) are classified as
# identifiers which don't have the get_real_name() method.
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
is_function = allow_functions and \
_identifier_is_function(identifier)
except AttributeError:
continue
if real_name:
yield TableReference(
schema_name, real_name, identifier.get_alias(),
is_function
)
elif isinstance(item, Identifier):
schema_name, real_name, alias = parse_identifier(item)
is_function = allow_functions and _identifier_is_function(item)
yield TableReference(schema_name, real_name, alias, is_function)
elif isinstance(item, Function):
schema_name, real_name, alias = parse_identifier(item)
yield TableReference(None, real_name, alias, allow_functions)
yield TableReference(schema_name, real_name, alias,
is_function)
elif isinstance(item, Function):
schema_name, real_name, alias = parse_identifier(item)
yield TableReference(None, real_name, alias, allow_functions)
except StopIteration:
return
# extract_tables is inspired from examples in the sqlparse lib.
@ -134,7 +157,7 @@ def extract_tables(sql):
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
# abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
insert_stmt = parsed[0].token_first().value.lower() == "insert"
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
# Kludge: sqlparse mistakenly identifies insert statements as

View File

@ -5,17 +5,17 @@ from sqlparse.tokens import Token, Error
cleanup_regex = {
# This matches only alphanumerics and underscores.
'alphanum_underscore': re.compile(r'(\w+)$'),
"alphanum_underscore": re.compile(r"(\w+)$"),
# This matches everything except spaces, parens, colon, and comma
'many_punctuations': re.compile(r'([^():,\s]+)$'),
"many_punctuations": re.compile(r"([^():,\s]+)$"),
# This matches everything except spaces, parens, colon, comma, and period
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
"most_punctuations": re.compile(r"([^\.():,\s]+)$"),
# This matches everything except a space.
'all_punctuations': re.compile(r'([^\s]+)$'),
"all_punctuations": re.compile(r"([^\s]+)$"),
}
def last_word(text, include='alphanum_underscore'):
def last_word(text, include="alphanum_underscore"):
r"""
Find the last word in a sentence.
@ -49,41 +49,42 @@ def last_word(text, include='alphanum_underscore'):
'"foo*bar'
"""
if not text: # Empty string
return ''
if not text: # Empty string
return ""
if text[-1].isspace():
return ''
return ""
else:
regex = cleanup_regex[include]
matches = regex.search(text)
if matches:
return matches.group(0)
else:
return ''
return ""
def find_prev_keyword(sql, n_skip=0):
""" Find the last sql keyword in an SQL statement
"""Find the last sql keyword in an SQL statement
Returns the value of the last keyword, and the text of the query with
everything after the last keyword stripped
"""
if not sql.strip():
return None, ''
return None, ""
parsed = sqlparse.parse(sql)[0]
flattened = list(parsed.flatten())
flattened = flattened[:len(flattened) - n_skip]
flattened = flattened[: len(flattened) - n_skip]
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
logical_operators = ("AND", "OR", "NOT", "BETWEEN")
for t in reversed(flattened):
if t.value == '(' or (t.is_keyword and (
t.value.upper() not in logical_operators)):
if t.value == "(" or (
t.is_keyword and (t.value.upper() not in logical_operators)
):
# Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index thows an error
# inside a TokenList, in which case token_index throws an error
# Minimal example:
# p = sqlparse.parse('select * from foo where bar')
# t = list(p.flatten())[-3] # The "Where" token
@ -93,14 +94,14 @@ def find_prev_keyword(sql, n_skip=0):
# Combine the string values of all tokens in the original list
# up to and including the target keyword token t, to produce a
# query string with everything after the keyword token removed
text = ''.join(tok.value for tok in flattened[:idx + 1])
text = "".join(tok.value for tok in flattened[: idx + 1])
return t, text
return None, ''
return None, ""
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
def is_open_quote(sql):

View File

@ -4,13 +4,13 @@ from sqlparse.tokens import Name
from collections import defaultdict
white_space_regex = re.compile(r'\\s+', re.MULTILINE)
white_space_regex = re.compile("\\s+", re.MULTILINE)
def _compile_regex(keyword):
# Surround the keyword with word boundaries and replace interior whitespace
# with whitespace wildcards
pattern = r'\\b' + white_space_regex.sub(r'\\s+', keyword) + r'\\b'
pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)

View File

@ -3,28 +3,29 @@ import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where
from .parseutils.utils import (
last_word, find_prev_keyword, parse_partial_identifier)
from .parseutils.utils import last_word, find_prev_keyword,\
parse_partial_identifier
from .parseutils.tables import extract_tables
from .parseutils.ctes import isolate_query_ctes
Special = namedtuple('Special', [])
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', ['quoted'])
Special = namedtuple("Special", [])
Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
TableFormat = namedtuple('TableFormat', [])
View = namedtuple('View', ['schema', 'table_refs'])
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple("TableFormat", [])
View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple('Join', ['table_refs', 'schema'])
Join = namedtuple("Join", ["table_refs", "schema"])
Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
@ -32,31 +33,33 @@ View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
Column = namedtuple(
'Column',
['table_refs', 'require_last_table', 'local_tables',
'qualifiable', 'context']
"Column",
["table_refs", "require_last_table", "local_tables", "qualifiable",
"context"],
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
Keyword = namedtuple('Keyword', ['last_token'])
Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple('NamedQuery', [])
Datatype = namedtuple('Datatype', ['schema'])
Alias = namedtuple('Alias', ['aliases'])
NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple('Path', [])
Path = namedtuple("Path", [])
class SqlStatement(object):
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(
text_before_cursor, include='many_punctuations')
text_before_cursor, include="many_punctuations"
)
full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor)
full_text, text_before_cursor, self.local_tables = \
isolate_query_ctes(full_text, text_before_cursor)
full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
full_text, text_before_cursor
)
self.text_before_cursor_including_last_word = text_before_cursor
@ -67,39 +70,41 @@ class SqlStatement(object):
# completion useless because it will always return the list of
# keywords as completion.
if self.word_before_cursor:
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
parsed = sqlparse.parse(text_before_cursor)
else:
text_before_cursor = \
text_before_cursor[:-len(word_before_cursor)]
text_before_cursor[: -len(word_before_cursor)]
parsed = sqlparse.parse(text_before_cursor)
self.identifier = parse_partial_identifier(word_before_cursor)
else:
parsed = sqlparse.parse(text_before_cursor)
full_text, text_before_cursor, parsed = \
_split_multiple_statements(full_text, text_before_cursor, parsed)
full_text, text_before_cursor, parsed = _split_multiple_statements(
full_text, text_before_cursor, parsed
)
self.full_text = full_text
self.text_before_cursor = text_before_cursor
self.parsed = parsed
self.last_token = \
parsed and parsed.token_prev(len(parsed.tokens))[1] or ''
self.last_token = parsed and \
parsed.token_prev(len(parsed.tokens))[1] or ""
def is_insert(self):
return self.parsed.token_first().value.lower() == 'insert'
return self.parsed.token_first().value.lower() == "insert"
def get_tables(self, scope='full'):
""" Gets the tables available in the statement.
def get_tables(self, scope="full"):
"""Gets the tables available in the statement.
param `scope:` possible values: 'full', 'insert', 'before'
If 'insert', only the first table is returned.
If 'before', only tables before the cursor are returned.
If not 'insert' and the stmt is an insert, the first table is skipped.
"""
tables = extract_tables(
self.full_text if scope == 'full' else self.text_before_cursor)
if scope == 'insert':
self.full_text if scope == "full" else self.text_before_cursor
)
if scope == "insert":
tables = tables[:1]
elif self.is_insert():
tables = tables[1:]
@ -118,8 +123,9 @@ class SqlStatement(object):
return schema
def reduce_to_prev_keyword(self, n_skip=0):
prev_keyword, self.text_before_cursor = \
find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
prev_keyword, self.text_before_cursor = find_prev_keyword(
self.text_before_cursor, n_skip=n_skip
)
return prev_keyword
@ -131,7 +137,7 @@ def suggest_type(full_text, text_before_cursor):
A scope for a column category will be a list of tables.
"""
if full_text.startswith('\\i '):
if full_text.startswith("\\i "):
return (Path(),)
# This is a temporary hack; the exception handling
@ -144,7 +150,7 @@ def suggest_type(full_text, text_before_cursor):
return suggest_based_on_last_token(stmt.last_token, stmt)
named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
def _strip_named_query(txt):
@ -155,11 +161,11 @@ def _strip_named_query(txt):
"""
if named_query_regex.match(txt):
txt = named_query_regex.sub('', txt)
txt = named_query_regex.sub("", txt)
return txt
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
def _find_function_body(text):
@ -205,12 +211,12 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
return full_text, text_before_cursor, None
token2 = None
if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
token1 = statement.token_first()
if token1:
token1_idx = statement.token_index(token1)
token2 = statement.token_next(token1_idx)[1]
if token2 and token2.value.upper() == 'FUNCTION':
if token2 and token2.value.upper() == "FUNCTION":
full_text, text_before_cursor, statement = _statement_from_function(
full_text, text_before_cursor, statement
)
@ -246,9 +252,9 @@ def suggest_based_on_last_token(token, stmt):
# SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR>
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
if prev_keyword and prev_keyword.value == '(':
if prev_keyword and prev_keyword.value == "(":
# Suggest datatypes
return suggest_based_on_last_token('type', stmt)
return suggest_based_on_last_token("type", stmt)
else:
return (Keyword(),)
else:
@ -256,7 +262,7 @@ def suggest_based_on_last_token(token, stmt):
if not token:
return (Keyword(),)
elif token_v.endswith('('):
elif token_v.endswith("("):
p = sqlparse.parse(stmt.text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where):
@ -268,10 +274,10 @@ def suggest_based_on_last_token(token, stmt):
# 3 - Subquery expression like "WHERE EXISTS ("
# Suggest keywords, in order to do a subquery
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
# Suggest columns/functions AND keywords. (If we wanted to be
# really fancy, we could suggest only array-typed columns)
# Suggest columns/functions AND keywords. (If we wanted to
# be really fancy, we could suggest only array-typed columns)
column_suggestions = suggest_based_on_last_token('where', stmt)
column_suggestions = suggest_based_on_last_token("where", stmt)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
@ -282,7 +288,7 @@ def suggest_based_on_last_token(token, stmt):
prev_tok = prev_tok.tokens[-1]
prev_tok = prev_tok.value.lower()
if prev_tok == 'exists':
if prev_tok == "exists":
return (Keyword(),)
else:
return column_suggestions
@ -292,59 +298,47 @@ def suggest_based_on_last_token(token, stmt):
if (
prev_tok and prev_tok.value and
prev_tok.value.lower().split(' ')[-1] == 'using'
prev_tok.value.lower().split(" ")[-1] == "using"
):
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = stmt.get_tables('before')
tables = stmt.get_tables("before")
# suggest columns that are present in more than one table
return (Column(table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables),)
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
elif p.token_first().value.lower() == 'select' and \
last_word(stmt.text_before_cursor,
'all_punctuations').startswith('('):
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
return (
Column(table_refs=stmt.get_tables('insert'), context='insert'),
Column(
table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables,
),
)
elif p.token_first().value.lower() == "select":
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
if last_word(stmt.text_before_cursor,
"all_punctuations").startswith("("):
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == "INTO":
return (Column(table_refs=stmt.get_tables("insert"),
context="insert"),)
# We're probably in a function argument list
return (Column(table_refs=extract_tables(stmt.full_text),
local_tables=stmt.local_tables, qualifiable=True),)
elif token_v == 'set':
return _suggest_expression(token_v, stmt)
elif token_v == "set":
return (Column(table_refs=stmt.get_tables(),
local_tables=stmt.local_tables),)
elif token_v in ('select', 'where', 'having', 'by', 'distinct'):
# Check for a table alias or schema qualification
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
return (Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
else:
return (Column(table_refs=tables, local_tables=stmt.local_tables,
qualifiable=True),
Function(schema=None),
Keyword(token_v.upper()),)
elif token_v == 'as':
elif token_v in ("select", "where", "having", "order by", "distinct"):
return _suggest_expression(token_v, stmt)
elif token_v == "as":
# Don't suggest anything for aliases
return ()
elif (
(token_v.endswith('join') and token.is_keyword) or
(token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'))
elif (token_v.endswith("join") and token.is_keyword) or (
token_v in ("copy", "from", "update", "into", "describe", "truncate")
):
schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith('join') and token.is_keyword
is_join = token_v.endswith("join") and token.is_keyword
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
@ -354,60 +348,77 @@ def suggest_based_on_last_token(token, stmt):
# Suggest schemas
suggest.insert(0, Schema())
if token_v == 'from' or is_join:
suggest.append(FromClauseItem(schema=schema,
table_refs=tables,
local_tables=stmt.local_tables))
elif token_v == 'truncate':
if token_v == "from" or is_join:
suggest.append(
FromClauseItem(
schema=schema, table_refs=tables,
local_tables=stmt.local_tables
)
)
elif token_v == "truncate":
suggest.append(Table(schema))
else:
suggest.extend((Table(schema), View(schema)))
if is_join and _allow_join(stmt.parsed):
tables = stmt.get_tables('before')
tables = stmt.get_tables("before")
suggest.append(Join(table_refs=tables, schema=schema))
return tuple(suggest)
elif token_v == 'function':
elif token_v == "function":
schema = stmt.get_identifier_schema()
# stmt.get_previous_token will fail for e.g.
# `SELECT 1 FROM functions WHERE function:`
try:
prev = stmt.get_previous_token(token).value.lower()
if prev in ('drop', 'alter', 'create', 'create or replace'):
return (Function(schema=schema, usage='signature'),)
if prev in ("drop", "alter", "create", "create or replace"):
# Suggest functions from either the currently-selected schema
# or the public schema if no schema has been specified
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
suggest.append(Function(schema=schema, usage="signature"))
return tuple(suggest)
except ValueError:
pass
return tuple()
elif token_v in ('table', 'view'):
elif token_v in ("table", "view"):
# E.g. 'ALTER TABLE <tablname>'
rel_type = \
{'table': Table, 'view': View, 'function': Function}[token_v]
{"table": Table, "view": View, "function": Function}[token_v]
schema = stmt.get_identifier_schema()
if schema:
return (rel_type(schema=schema),)
else:
return (Schema(), rel_type(schema=schema))
elif token_v == 'column':
elif token_v == "column":
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
return (Column(table_refs=stmt.get_tables()),)
elif token_v == 'on':
tables = stmt.get_tables('before')
elif token_v == "on":
tables = stmt.get_tables("before")
parent = \
(stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t))
sugs = [Column(table_refs=filteredtables,
local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent)]
sugs = [
Column(table_refs=filteredtables,
local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),
]
if filteredtables and _allow_join_condition(stmt.parsed):
sugs.append(JoinCondition(table_refs=tables,
parent=filteredtables[-1]))
@ -417,38 +428,39 @@ def suggest_based_on_last_token(token, stmt):
# Use table alias if there is one, otherwise the table name
aliases = tuple(t.ref for t in tables)
if _allow_join_condition(stmt.parsed):
return (Alias(aliases=aliases), JoinCondition(
table_refs=tables, parent=None))
return (
Alias(aliases=aliases),
JoinCondition(table_refs=tables, parent=None),
)
else:
return (Alias(aliases=aliases),)
elif token_v in ('c', 'use', 'database', 'template'):
elif token_v in ("c", "use", "database", "template"):
# "\c <db", "use <db>", "DROP DATABASE <db>",
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
return (Database(),)
elif token_v == 'schema':
elif token_v == "schema":
# DROP SCHEMA schema_name, SET SCHEMA schema name
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
quoted = prev_keyword and prev_keyword.value.lower() == 'set'
quoted = prev_keyword and prev_keyword.value.lower() == "set"
return (Schema(quoted),)
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
elif token_v.endswith(",") or token_v in ("=", "and", "or"):
prev_keyword = stmt.reduce_to_prev_keyword()
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
else:
return ()
elif token_v in ('type', '::'):
elif token_v in ("type", "::"):
# ALTER TABLE foo SET DATA TYPE bar
# SELECT foo::bar
# Note that tables are a form of composite type in postgresql, so
# they're suggested here as well
schema = stmt.get_identifier_schema()
suggestions = [Datatype(schema=schema),
Table(schema=schema)]
suggestions = [Datatype(schema=schema), Table(schema=schema)]
if not schema:
suggestions.append(Schema())
return tuple(suggestions)
elif token_v in ['alter', 'create', 'drop']:
elif token_v in {"alter", "create", "drop"}:
return (Keyword(token_v.upper()),)
elif token.is_keyword:
# token is a keyword we haven't implemented any special handling for
@ -462,11 +474,38 @@ def suggest_based_on_last_token(token, stmt):
return (Keyword(),)
def _suggest_expression(token_v, stmt):
"""
Return suggestions for an expression, taking account of any partially-typed
identifier's parent, which may be a table alias or schema name.
"""
parent = stmt.identifier.get_parent_name() if stmt.identifier else []
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
return (
Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),
)
return (
Column(table_refs=tables, local_tables=stmt.local_tables,
qualifiable=True),
Function(schema=None),
Keyword(token_v.upper()),
)
def identifies(id, ref):
"""Returns true if string `id` matches TableReference `ref`"""
return id == ref.alias or id == ref.name or (
ref.schema and (id == ref.schema + '.' + ref.name))
return (
id == ref.alias or id == ref.name or
(ref.schema and (id == ref.schema + "." + ref.name))
)
def _allow_join_condition(statement):
@ -486,7 +525,7 @@ def _allow_join_condition(statement):
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower() in ('on', 'and', 'or')
return last_tok.value.lower() in ("on", "and", "or")
def _allow_join(statement):
@ -505,7 +544,5 @@ def _allow_join(statement):
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return (
last_tok.value.lower().endswith('join') and
last_tok.value.lower() not in ('cross join', 'natural join')
)
return last_tok.value.lower().endswith("join") and \
last_tok.value.lower() not in ("cross join", "natural join",)