Merged the latest code of 'pgcli' used for the autocomplete feature. Fixes #5497

This commit is contained in:
Akshay Joshi 2020-10-01 12:53:45 +05:30
parent 3f817494f8
commit 300de05a20
13 changed files with 574 additions and 818 deletions

View File

@ -16,6 +16,7 @@ Housekeeping
************ ************
| `Issue #5330 <https://redmine.postgresql.org/issues/5330>`_ - Improve code coverage and API test cases for Functions. | `Issue #5330 <https://redmine.postgresql.org/issues/5330>`_ - Improve code coverage and API test cases for Functions.
| `Issue #5497 <https://redmine.postgresql.org/issues/5497>`_ - Merged the latest code of 'pgcli' used for the autocomplete feature.
Bug fixes Bug fixes
********* *********

View File

@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate, CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window, CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
p.proretset is_set_returning, p.proretset is_set_returning,
d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype WHERE p.prorettype::regtype != 'trigger'::regtype
AND n.nspname IN ({{schema_names}}) AND n.nspname IN ({{schema_names}})
ORDER BY 1, 2 ORDER BY 1, 2

View File

@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
p.proisagg is_aggregate, p.proisagg is_aggregate,
p.proiswindow is_window, p.proiswindow is_window,
p.proretset is_set_returning, p.proretset is_set_returning,
d.deptype = 'e' is_extension,
pg_get_expr(proargdefaults, 0) AS arg_defaults pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
WHERE p.prorettype::regtype != 'trigger'::regtype WHERE p.prorettype::regtype != 'trigger'::regtype
AND n.nspname IN ({{schema_names}}) AND n.nspname IN ({{schema_names}})
ORDER BY 1, 2 ORDER BY 1, 2

View File

@ -11,8 +11,7 @@
import re import re
import operator import operator
import sys from itertools import count
from itertools import count, repeat, chain
from .completion import Completion from .completion import Completion
from collections import namedtuple, defaultdict, OrderedDict from collections import namedtuple, defaultdict, OrderedDict
@ -28,9 +27,9 @@ from pgadmin.utils.driver import get_driver
from config import PG_DEFAULT_DRIVER from config import PG_DEFAULT_DRIVER
from pgadmin.utils.preferences import Preferences from pgadmin.utils.preferences import Preferences
Match = namedtuple('Match', ['completion', 'priority']) Match = namedtuple("Match", ["completion", "priority"])
_SchemaObject = namedtuple('SchemaObject', 'name schema meta') _SchemaObject = namedtuple("SchemaObject", "name schema meta")
def SchemaObject(name, schema=None, meta=None): def SchemaObject(name, schema=None, meta=None):
@ -41,14 +40,12 @@ def SchemaObject(name, schema=None, meta=None):
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)') _FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)') _FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
_Candidate = namedtuple( _Candidate = namedtuple("Candidate",
'Candidate', 'completion prio meta synonyms prio2 display' "completion prio meta synonyms prio2 display")
)
def Candidate( def Candidate(
completion, prio=None, meta=None, synonyms=None, prio2=None, completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
display=None
): ):
return _Candidate( return _Candidate(
completion, prio, meta, synonyms or [completion], prio2, completion, prio, meta, synonyms or [completion], prio2,
@ -57,7 +54,7 @@ def Candidate(
# Used to strip trailing '::some_type' from default-value expressions # Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$') arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
def normalize_ref(ref): def normalize_ref(ref):
@ -65,15 +62,15 @@ def normalize_ref(ref):
def generate_alias(tbl): def generate_alias(tbl):
""" Generate a table alias, consisting of all upper-case letters in """Generate a table alias, consisting of all upper-case letters in
the table name, or, if there are no upper-case letters, the first letter + the table name, or, if there are no upper-case letters, the first letter +
all letters preceded by _ all letters preceded by _
param tbl - unescaped name of the table to alias param tbl - unescaped name of the table to alias
""" """
return ''.join( return "".join(
[letter for letter in tbl if letter.isupper()] or [letter for letter in tbl if letter.isupper()] or
[letter for letter, prev in zip(tbl, '_' + tbl) [letter for letter, prev in zip(tbl, "_" + tbl)
if prev == '_' and letter != '_'] if prev == "_" and letter != "_"]
) )
@ -97,13 +94,14 @@ class SQLAutoComplete(object):
self.sid = kwargs['sid'] if 'sid' in kwargs else None self.sid = kwargs['sid'] if 'sid' in kwargs else None
self.conn = kwargs['conn'] if 'conn' in kwargs else None self.conn = kwargs['conn'] if 'conn' in kwargs else None
self.keywords = [] self.keywords = []
self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")
self.databases = [] self.databases = []
self.functions = [] self.functions = []
self.datatypes = [] self.datatypes = []
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}, self.dbmetadata = \
'datatypes': {}} {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
self.text_before_cursor = None self.text_before_cursor = None
self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid) manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)
@ -182,7 +180,8 @@ class SQLAutoComplete(object):
def escape_name(self, name): def escape_name(self, name):
if name and ( if name and (
(not self.name_pattern.match(name)) or (not self.name_pattern.match(name)) or
(name.upper() in self.reserved_words) (name.upper() in self.reserved_words) or
(name.upper() in self.functions)
): ):
name = '"%s"' % name name = '"%s"' % name
@ -212,7 +211,7 @@ class SQLAutoComplete(object):
# schemata is a list of schema names # schemata is a list of schema names
schemata = self.escaped_names(schemata) schemata = self.escaped_names(schemata)
metadata = self.dbmetadata['tables'] metadata = self.dbmetadata["tables"]
for schema in schemata: for schema in schemata:
metadata[schema] = {} metadata[schema] = {}
@ -224,7 +223,7 @@ class SQLAutoComplete(object):
self.all_completions.update(schemata) self.all_completions.update(schemata)
def extend_casing(self, words): def extend_casing(self, words):
""" extend casing data """extend casing data
:return: :return:
""" """
@ -274,7 +273,7 @@ class SQLAutoComplete(object):
name=colname, name=colname,
datatype=datatype, datatype=datatype,
has_default=has_default, has_default=has_default,
default=default default=default,
) )
metadata[schema][relname][colname] = column metadata[schema][relname][colname] = column
self.all_completions.add(colname) self.all_completions.add(colname)
@ -285,7 +284,7 @@ class SQLAutoComplete(object):
# dbmetadata['schema_name']['functions']['function_name'] should return # dbmetadata['schema_name']['functions']['function_name'] should return
# the function metadata namedtuple for the corresponding function # the function metadata namedtuple for the corresponding function
metadata = self.dbmetadata['functions'] metadata = self.dbmetadata["functions"]
for f in func_data: for f in func_data:
schema, func = self.escaped_names([f.schema_name, f.func_name]) schema, func = self.escaped_names([f.schema_name, f.func_name])
@ -309,10 +308,10 @@ class SQLAutoComplete(object):
self._arg_list_cache = \ self._arg_list_cache = \
dict((usage, dict((usage,
dict((meta, self._arg_list(meta, usage)) dict((meta, self._arg_list(meta, usage))
for sch, funcs in self.dbmetadata['functions'].items() for sch, funcs in self.dbmetadata["functions"].items()
for func, metas in funcs.items() for func, metas in funcs.items()
for meta in metas)) for meta in metas))
for usage in ('call', 'call_display', 'signature')) for usage in ("call", "call_display", "signature"))
def extend_foreignkeys(self, fk_data): def extend_foreignkeys(self, fk_data):
@ -322,7 +321,7 @@ class SQLAutoComplete(object):
# These are added as a list of ForeignKey namedtuples to the # These are added as a list of ForeignKey namedtuples to the
# ColumnMetadata namedtuple for both the child and parent # ColumnMetadata namedtuple for both the child and parent
meta = self.dbmetadata['tables'] meta = self.dbmetadata["tables"]
for fk in fk_data: for fk in fk_data:
e = self.escaped_names e = self.escaped_names
@ -350,7 +349,7 @@ class SQLAutoComplete(object):
# dbmetadata['datatypes'][schema_name][type_name] should store type # dbmetadata['datatypes'][schema_name][type_name] should store type
# metadata, such as composite type field names. Currently, we're not # metadata, such as composite type field names. Currently, we're not
# storing any metadata beyond typename, so just store None # storing any metadata beyond typename, so just store None
meta = self.dbmetadata['datatypes'] meta = self.dbmetadata["datatypes"]
for t in type_data: for t in type_data:
schema, type_name = self.escaped_names(t) schema, type_name = self.escaped_names(t)
@ -364,11 +363,11 @@ class SQLAutoComplete(object):
self.databases = [] self.databases = []
self.special_commands = [] self.special_commands = []
self.search_path = [] self.search_path = []
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}, self.dbmetadata = \
'datatypes': {}} {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
self.all_completions = set(self.keywords + self.functions) self.all_completions = set(self.keywords + self.functions)
def find_matches(self, text, collection, mode='fuzzy', meta=None): def find_matches(self, text, collection, mode="strict", meta=None):
"""Find completion matches for the given text. """Find completion matches for the given text.
Given the user's input text and a collection of available Given the user's input text and a collection of available
@ -389,17 +388,26 @@ class SQLAutoComplete(object):
collection: collection:
mode: mode:
meta: meta:
meta_collection:
""" """
if not collection: if not collection:
return [] return []
prio_order = [ prio_order = [
'keyword', 'function', 'view', 'table', 'datatype', 'database', "keyword",
'schema', 'column', 'table alias', 'join', 'name join', 'fk join', "function",
'table format' "view",
"table",
"datatype",
"database",
"schema",
"column",
"table alias",
"join",
"name join",
"fk join",
"table format",
] ]
type_priority = prio_order.index(meta) if meta in prio_order else -1 type_priority = prio_order.index(meta) if meta in prio_order else -1
text = last_word(text, include='most_punctuations').lower() text = last_word(text, include="most_punctuations").lower()
text_len = len(text) text_len = len(text)
if text and text[0] == '"': if text and text[0] == '"':
@ -409,7 +417,7 @@ class SQLAutoComplete(object):
# Completion.position value is correct # Completion.position value is correct
text = text[1:] text = text[1:]
if mode == 'fuzzy': if mode == "fuzzy":
fuzzy = True fuzzy = True
priority_func = self.prioritizer.name_count priority_func = self.prioritizer.name_count
else: else:
@ -422,19 +430,20 @@ class SQLAutoComplete(object):
# Note: higher priority values mean more important, so use negative # Note: higher priority values mean more important, so use negative
# signs to flip the direction of the tuple # signs to flip the direction of the tuple
if fuzzy: if fuzzy:
regex = '.*?'.join(map(re.escape, text)) regex = ".*?".join(map(re.escape, text))
pat = re.compile('(%s)' % regex) pat = re.compile("(%s)" % regex)
def _match(item): def _match(item):
if item.lower()[:len(text) + 1] in (text, text + ' '): if item.lower()[: len(text) + 1] in (text, text + " "):
# Exact match of first word in suggestion # Exact match of first word in suggestion
# This is to get exact alias matches to the top # This is to get exact alias matches to the top
# E.g. for input `e`, 'Entries E' should be on top # E.g. for input `e`, 'Entries E' should be on top
# (before e.g. `EndUsers EU`) # (before e.g. `EndUsers EU`)
return float('Infinity'), -1 return float("Infinity"), -1
r = pat.search(self.unescape_name(item.lower())) r = pat.search(self.unescape_name(item.lower()))
if r: if r:
return -len(r.group()), -r.start() return -len(r.group()), -r.start()
else: else:
match_end_limit = len(text) match_end_limit = len(text)
@ -446,7 +455,7 @@ class SQLAutoComplete(object):
if match_point >= 0: if match_point >= 0:
# Use negative infinity to force keywords to sort after all # Use negative infinity to force keywords to sort after all
# fuzzy matches # fuzzy matches
return -float('Infinity'), -match_point return -float("Infinity"), -match_point
matches = [] matches = []
for cand in collection: for cand in collection:
@ -466,7 +475,7 @@ class SQLAutoComplete(object):
if sort_key: if sort_key:
if display_meta and len(display_meta) > 50: if display_meta and len(display_meta) > 50:
# Truncate meta-text to 50 characters, if necessary # Truncate meta-text to 50 characters, if necessary
display_meta = display_meta[:47] + '...' display_meta = display_meta[:47] + "..."
# Lexical order of items in the collection, used for # Lexical order of items in the collection, used for
# tiebreaking items with the same match group length and start # tiebreaking items with the same match group length and start
@ -478,14 +487,18 @@ class SQLAutoComplete(object):
# We also use the unescape_name to make sure quoted names have # We also use the unescape_name to make sure quoted names have
# the same priority as unquoted names. # the same priority as unquoted names.
lexical_priority = ( lexical_priority = (
tuple(0 if c in (' _') else -ord(c) tuple(0 if c in (" _") else -ord(c)
for c in self.unescape_name(item.lower())) + (1,) + for c in self.unescape_name(item.lower())) + (1,) +
tuple(c for c in item) tuple(c for c in item)
) )
priority = ( priority = (
sort_key, type_priority, prio, priority_func(item), sort_key,
prio2, lexical_priority type_priority,
prio,
priority_func(item),
prio2,
lexical_priority,
) )
matches.append( matches.append(
Match( Match(
@ -493,9 +506,9 @@ class SQLAutoComplete(object):
text=item, text=item,
start_position=-text_len, start_position=-text_len,
display_meta=display_meta, display_meta=display_meta,
display=display display=display,
), ),
priority=priority priority=priority,
) )
) )
return matches return matches
@ -516,8 +529,8 @@ class SQLAutoComplete(object):
matches.extend(matcher(self, suggestion, word_before_cursor)) matches.extend(matcher(self, suggestion, word_before_cursor))
# Sort matches so highest priorities are first # Sort matches so highest priorities are first
matches = sorted(matches, key=operator.attrgetter('priority'), matches = \
reverse=True) sorted(matches, key=operator.attrgetter("priority"), reverse=True)
result = dict() result = dict()
for m in matches: for m in matches:
@ -539,23 +552,28 @@ class SQLAutoComplete(object):
tables = suggestion.table_refs tables = suggestion.table_refs
do_qualify = suggestion.qualifiable and { do_qualify = suggestion.qualifiable and {
'always': True, 'never': False, "always": True,
'if_more_than_one_table': len(tables) > 1}[self.qualify_columns] "never": False,
"if_more_than_one_table": len(tables) > 1,
}[self.qualify_columns]
def qualify(col, tbl): def qualify(col, tbl):
return (tbl + '.' + col) if do_qualify else col return (tbl + '.' + col) if do_qualify else col
scoped_cols = self.populate_scoped_cols( scoped_cols = \
tables, suggestion.local_tables self.populate_scoped_cols(tables, suggestion.local_tables)
)
def make_cand(name, ref): def make_cand(name, ref):
synonyms = (name, generate_alias(name)) synonyms = (name, generate_alias(name))
return Candidate(qualify(name, ref), 0, 'column', synonyms) return Candidate(qualify(name, ref), 0, "column", synonyms)
def flat_cols(): def flat_cols():
return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items() return [
for c in cols] make_cand(c.name, t.ref)
for t, cols in scoped_cols.items()
for c in cols
]
if suggestion.require_last_table: if suggestion.require_last_table:
# require_last_table is used for 'tb11 JOIN tbl2 USING # require_last_table is used for 'tb11 JOIN tbl2 USING
# (...' which should # (...' which should
@ -569,10 +587,11 @@ class SQLAutoComplete(object):
dict((t, [col for col in cols if col.name in other_tbl_cols]) dict((t, [col for col in cols if col.name in other_tbl_cols])
for t, cols in scoped_cols.items() if t.ref == ltbl) for t, cols in scoped_cols.items() if t.ref == ltbl)
lastword = last_word(word_before_cursor, include='most_punctuations') lastword = last_word(word_before_cursor, include="most_punctuations")
if lastword == '*': if lastword == "*":
if suggestion.context == 'insert': if suggestion.context == "insert":
def is_scoped(col):
def filter(col):
if not col.has_default: if not col.has_default:
return True return True
return not any( return not any(
@ -580,40 +599,39 @@ class SQLAutoComplete(object):
for p in self.insert_col_skip_patterns for p in self.insert_col_skip_patterns
) )
scoped_cols = \ scoped_cols = \
dict((t, [col for col in cols if is_scoped(col)]) dict((t, [col for col in cols if filter(col)])
for t, cols in scoped_cols.items()) for t, cols in scoped_cols.items())
if self.asterisk_column_order == 'alphabetic': if self.asterisk_column_order == "alphabetic":
for cols in scoped_cols.values(): for cols in scoped_cols.values():
cols.sort(key=operator.attrgetter('name')) cols.sort(key=operator.attrgetter("name"))
if ( if (
lastword != word_before_cursor and lastword != word_before_cursor and
len(tables) == 1 and len(tables) == 1 and
word_before_cursor[-len(lastword) - 1] == '.' word_before_cursor[-len(lastword) - 1] == "."
): ):
# User typed x.*; replicate "x." for all columns except the # User typed x.*; replicate "x." for all columns except the
# first, which gets the original (as we only replace the "*") # first, which gets the original (as we only replace the "*")
sep = ', ' + word_before_cursor[:-1] sep = ", " + word_before_cursor[:-1]
collist = sep.join(c.completion for c in flat_cols()) collist = sep.join(c.completion for c in flat_cols())
else: else:
collist = ', '.join(qualify(c.name, t.ref) collist = ", ".join(qualify(c.name, t.ref)
for t, cs in scoped_cols.items() for t, cs in scoped_cols.items()
for c in cs) for c in cs)
return [Match( return [
completion=Completion( Match(
collist, completion=Completion(
-1, collist, -1, display_meta="columns", display="*"
display_meta='columns', ),
display='*' priority=(1, 1, 1),
), )
priority=(1, 1, 1) ]
)]
return self.find_matches(word_before_cursor, flat_cols(), return self.find_matches(word_before_cursor, flat_cols(),
mode='strict', meta='column') meta="column")
def alias(self, tbl, tbls): def alias(self, tbl, tbls):
""" Generate a unique table alias """Generate a unique table alias
tbl - name of the table to alias, quoted if it needs to be tbl - name of the table to alias, quoted if it needs to be
tbls - TableReference iterable of tables already in query tbls - TableReference iterable of tables already in query
""" """
@ -628,25 +646,6 @@ class SQLAutoComplete(object):
aliases = (tbl + str(i) for i in count(2)) aliases = (tbl + str(i) for i in count(2))
return next(a for a in aliases if normalize_ref(a) not in tbls) return next(a for a in aliases if normalize_ref(a) not in tbls)
def _check_for_aliases(self, left, refs, rtbl, suggestion, right):
"""
Check for generated aliases and return the join value
:param left:
:param refs:
:param rtbl:
:param suggestion:
:param right:
:return: return join string.
"""
if self.generate_aliases or normalize_ref(left.tbl) in refs:
lref = self.alias(left.tbl, suggestion.table_refs)
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
left.tbl, left.col, rtbl.ref, right.col, lref)
else:
join = '{0} ON {0}.{1} = {2}.{3}'.format(
left.tbl, left.col, rtbl.ref, right.col)
return join
def get_join_matches(self, suggestion, word_before_cursor): def get_join_matches(self, suggestion, word_before_cursor):
tbls = suggestion.table_refs tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls) cols = self.populate_scoped_cols(tbls)
@ -658,10 +657,12 @@ class SQLAutoComplete(object):
joins = [] joins = []
# Iterate over FKs in existing tables to find potential joins # Iterate over FKs in existing tables to find potential joins
fks = ( fks = (
(fk, rtbl, rcol) for rtbl, rcols in cols.items() (fk, rtbl, rcol)
for rcol in rcols for fk in rcol.foreignkeys for rtbl, rcols in cols.items()
for rcol in rcols
for fk in rcol.foreignkeys
) )
col = namedtuple('col', 'schema tbl col') col = namedtuple("col", "schema tbl col")
for fk, rtbl, rcol in fks: for fk, rtbl, rcol in fks:
right = col(rtbl.schema, rtbl.name, rcol.name) right = col(rtbl.schema, rtbl.name, rcol.name)
child = col(fk.childschema, fk.childtable, fk.childcolumn) child = col(fk.childschema, fk.childtable, fk.childcolumn)
@ -670,54 +671,38 @@ class SQLAutoComplete(object):
if suggestion.schema and left.schema != suggestion.schema: if suggestion.schema and left.schema != suggestion.schema:
continue continue
join = self._check_for_aliases(left, refs, rtbl, suggestion, right) if self.generate_aliases or normalize_ref(left.tbl) in refs:
lref = self.alias(left.tbl, suggestion.table_refs)
join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
left.tbl, left.col, rtbl.ref, right.col, lref
)
else:
join = "{0} ON {0}.{1} = {2}.{3}".format(
left.tbl, left.col, rtbl.ref, right.col
)
alias = generate_alias(left.tbl) alias = generate_alias(left.tbl)
synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format( synonyms = [
alias, left.col, rtbl.ref, right.col)] join,
"{0} ON {0}.{1} = {2}.{3}".format(
alias, left.col, rtbl.ref, right.col
),
]
# Schema-qualify if (1) new table in same schema as old, and old # Schema-qualify if (1) new table in same schema as old, and old
# is schema-qualified, or (2) new in other schema, except public # is schema-qualified, or (2) new in other schema, except public
if not suggestion.schema and \ if not suggestion.schema and \
(qualified[normalize_ref(rtbl.ref)] and (qualified[normalize_ref(rtbl.ref)] and
left.schema == right.schema or left.schema == right.schema or
left.schema not in (right.schema, 'public')): left.schema not in (right.schema, "public")):
join = left.schema + '.' + join join = left.schema + "." + join
prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
0 if (left.schema, left.tbl) in other_tbls else 1) 0 if (left.schema, left.tbl) in other_tbls else 1
joins.append(Candidate(join, prio, 'join', synonyms=synonyms)) )
joins.append(Candidate(join, prio, "join", synonyms=synonyms))
return self.find_matches(word_before_cursor, joins, return self.find_matches(word_before_cursor, joins, meta="join")
mode='strict', meta='join')
def list_dict(self, pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
d = defaultdict(list)
for pair in pairs:
d[pair[0]].append(pair[1])
return d
def add_cond(self, lcol, rcol, rref, prio, meta, **kwargs):
"""
Add Condition in join
:param lcol:
:param rcol:
:param rref:
:param prio:
:param meta:
:param kwargs:
:return:
"""
suggestion = kwargs['suggestion']
found_conds = kwargs['found_conds']
ltbl = kwargs['ltbl']
ref_prio = kwargs['ref_prio']
conds = kwargs['conds']
prefix = '' if suggestion.parent else ltbl.ref + '.'
cond = prefix + lcol + ' = ' + rref + '.' + rcol
if cond not in found_conds:
found_conds.add(cond)
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
def get_join_condition_matches(self, suggestion, word_before_cursor): def get_join_condition_matches(self, suggestion, word_before_cursor):
col = namedtuple('col', 'schema tbl col') col = namedtuple("col", "schema tbl col")
tbls = self.populate_scoped_cols(suggestion.table_refs).items tbls = self.populate_scoped_cols(suggestion.table_refs).items
cols = [(t, c) for t, cs in tbls() for c in cs] cols = [(t, c) for t, cs in tbls() for c in cs]
try: try:
@ -727,11 +712,24 @@ class SQLAutoComplete(object):
return [] return []
conds, found_conds = [], set() conds, found_conds = [], set()
def add_cond(lcol, rcol, rref, prio, meta):
prefix = "" if suggestion.parent else ltbl.ref + "."
cond = prefix + lcol + " = " + rref + "." + rcol
if cond not in found_conds:
found_conds.add(cond)
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
d = defaultdict(list)
for pair in pairs:
d[pair[0]].append(pair[1])
return d
# Tables that are closer to the cursor get higher prio # Tables that are closer to the cursor get higher prio
ref_prio = dict((tbl.ref, num) ref_prio = dict((tbl.ref, num)
for num, tbl in enumerate(suggestion.table_refs)) for num, tbl in enumerate(suggestion.table_refs))
# Map (schema, table, col) to tables # Map (schema, table, col) to tables
coldict = self.list_dict( coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
) )
# For each fk from the left table, generate a join condition if # For each fk from the left table, generate a join condition if
@ -742,89 +740,76 @@ class SQLAutoComplete(object):
child = col(fk.childschema, fk.childtable, fk.childcolumn) child = col(fk.childschema, fk.childtable, fk.childcolumn)
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
left, right = (child, par) if left == child else (par, child) left, right = (child, par) if left == child else (par, child)
for rtbl in coldict[right]: for rtbl in coldict[right]:
kwargs = { add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
"suggestion": suggestion,
"found_conds": found_conds,
"ltbl": ltbl,
"conds": conds,
"ref_prio": ref_prio
}
self.add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join',
**kwargs)
# For name matching, use a {(colname, coltype): TableReference} dict # For name matching, use a {(colname, coltype): TableReference} dict
coltyp = namedtuple('coltyp', 'name datatype') coltyp = namedtuple("coltyp", "name datatype")
col_table = self.list_dict( col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
(coltyp(c.name, c.datatype), t) for t, c in cols)
# Find all name-match join conditions # Find all name-match join conditions
for c in (coltyp(c.name, c.datatype) for c in lcols): for c in (coltyp(c.name, c.datatype) for c in lcols):
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
kwargs = {
"suggestion": suggestion,
"found_conds": found_conds,
"ltbl": ltbl,
"conds": conds,
"ref_prio": ref_prio
}
prio = 1000 if c.datatype in ( prio = 1000 if c.datatype in (
'integer', 'bigint', 'smallint') else 0 "integer", "bigint", "smallint") else 0
self.add_cond(c.name, c.name, rtbl.ref, prio, 'name join', add_cond(c.name, c.name, rtbl.ref, prio, "name join")
**kwargs)
return self.find_matches(word_before_cursor, conds, return self.find_matches(word_before_cursor, conds, meta="join")
mode='strict', meta='join')
def get_function_matches(self, suggestion, word_before_cursor, def get_function_matches(self, suggestion, word_before_cursor,
alias=False): alias=False):
if suggestion.usage == 'from': if suggestion.usage == "from":
# Only suggest functions allowed in FROM clause # Only suggest functions allowed in FROM clause
def filt(f): def filt(f):
return not f.is_aggregate and not f.is_window return (
not f.is_aggregate and not f.is_window and
not f.is_extension and
(f.is_public or f.schema_name == suggestion.schema)
)
else: else:
alias = False alias = False
def filt(f): def filt(f):
return True return not f.is_extension and (
f.is_public or f.schema_name == suggestion.schema
)
arg_mode = { arg_mode = {"signature": "signature", "special": None}.get(
'signature': 'signature', suggestion.usage, "call"
'special': None,
}.get(suggestion.usage, 'call')
# Function overloading means we may have multiple functions of the same
# name at this point, so keep unique names only
funcs = set(
self._make_cand(f, alias, suggestion, arg_mode)
for f in self.populate_functions(suggestion.schema, filt)
) )
matches = self.find_matches(word_before_cursor, funcs, # Function overloading means we way have multiple functions of the same
mode='strict', meta='function') # name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt)
funcs = set(
self._make_cand(f, alias, suggestion, arg_mode)
for f in all_functions
)
matches = self.find_matches(word_before_cursor, funcs, meta="function")
return matches return matches
def get_schema_matches(self, suggestion, word_before_cursor): def get_schema_matches(self, suggestion, word_before_cursor):
schema_names = self.dbmetadata['tables'].keys() schema_names = self.dbmetadata["tables"].keys()
# Unless we're sure the user really wants them, hide schema names # Unless we're sure the user really wants them, hide schema names
# starting with pg_, which are mostly temporary schemas # starting with pg_, which are mostly temporary schemas
if not word_before_cursor.startswith('pg_'): if not word_before_cursor.startswith("pg_"):
schema_names = [s schema_names = [s for s in schema_names if not s.startswith("pg_")]
for s in schema_names
if not s.startswith('pg_')]
if suggestion.quoted: if suggestion.quoted:
schema_names = [self.escape_schema(s) for s in schema_names] schema_names = [self.escape_schema(s) for s in schema_names]
return self.find_matches(word_before_cursor, schema_names, return self.find_matches(word_before_cursor, schema_names,
mode='strict', meta='schema') meta="schema")
def get_from_clause_item_matches(self, suggestion, word_before_cursor): def get_from_clause_item_matches(self, suggestion, word_before_cursor):
alias = self.generate_aliases alias = self.generate_aliases
s = suggestion s = suggestion
t_sug = Table(s.schema, s.table_refs, s.local_tables) t_sug = Table(s.schema, s.table_refs, s.local_tables)
v_sug = View(s.schema, s.table_refs) v_sug = View(s.schema, s.table_refs)
f_sug = Function(s.schema, s.table_refs, usage='from') f_sug = Function(s.schema, s.table_refs, usage="from")
return ( return (
self.get_table_matches(t_sug, word_before_cursor, alias) + self.get_table_matches(t_sug, word_before_cursor, alias) +
self.get_view_matches(v_sug, word_before_cursor, alias) + self.get_view_matches(v_sug, word_before_cursor, alias) +
@ -839,42 +824,43 @@ class SQLAutoComplete(object):
""" """
template = { template = {
'call': self.call_arg_style, "call": self.call_arg_style,
'call_display': self.call_arg_display_style, "call_display": self.call_arg_display_style,
'signature': self.signature_arg_style "signature": self.signature_arg_style,
}[usage] }[usage]
args = func.args() args = func.args()
if not template or ( if not template:
usage == 'call' and ( return "()"
len(args) < 2 or func.has_variadic())): elif usage == "call" and len(args) < 2:
return '()' return "()"
elif usage == "call" and func.has_variadic():
multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max return "()"
multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
max_arg_len = max(len(a.name) for a in args) if multiline else 0 max_arg_len = max(len(a.name) for a in args) if multiline else 0
args = ( args = (
self._format_arg(template, arg, arg_num + 1, max_arg_len) self._format_arg(template, arg, arg_num + 1, max_arg_len)
for arg_num, arg in enumerate(args) for arg_num, arg in enumerate(args)
) )
if multiline: if multiline:
return '(' + ','.join('\n ' + a for a in args if a) + '\n)' return "(" + ",".join("\n " + a for a in args if a) + "\n)"
else: else:
return '(' + ', '.join(a for a in args if a) + ')' return "(" + ", ".join(a for a in args if a) + ")"
def _format_arg(self, template, arg, arg_num, max_arg_len): def _format_arg(self, template, arg, arg_num, max_arg_len):
if not template: if not template:
return None return None
if arg.has_default: if arg.has_default:
arg_default = 'NULL' if arg.default is None else arg.default arg_default = "NULL" if arg.default is None else arg.default
# Remove trailing ::(schema.)type # Remove trailing ::(schema.)type
arg_default = arg_default_type_strip_regex.sub('', arg_default) arg_default = arg_default_type_strip_regex.sub("", arg_default)
else: else:
arg_default = '' arg_default = ""
return template.format( return template.format(
max_arg_len=max_arg_len, max_arg_len=max_arg_len,
arg_name=arg.name, arg_name=arg.name,
arg_num=arg_num, arg_num=arg_num,
arg_type=arg.datatype, arg_type=arg.datatype,
arg_default=arg_default arg_default=arg_default,
) )
def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
@ -890,63 +876,60 @@ class SQLAutoComplete(object):
if do_alias: if do_alias:
alias = self.alias(cased_tbl, suggestion.table_refs) alias = self.alias(cased_tbl, suggestion.table_refs)
synonyms = (cased_tbl, generate_alias(cased_tbl)) synonyms = (cased_tbl, generate_alias(cased_tbl))
maybe_alias = (' ' + alias) if do_alias else '' maybe_alias = (" " + alias) if do_alias else ""
maybe_schema = (tbl.schema + '.') if tbl.schema else '' maybe_schema = (tbl.schema + ".") if tbl.schema else ""
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else '' suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
if arg_mode == 'call': if arg_mode == "call":
display_suffix = self._arg_list_cache['call_display'][tbl.meta] display_suffix = self._arg_list_cache["call_display"][tbl.meta]
elif arg_mode == 'signature': elif arg_mode == "signature":
display_suffix = self._arg_list_cache['signature'][tbl.meta] display_suffix = self._arg_list_cache["signature"][tbl.meta]
else: else:
display_suffix = '' display_suffix = ""
item = maybe_schema + cased_tbl + suffix + maybe_alias item = maybe_schema + cased_tbl + suffix + maybe_alias
display = maybe_schema + cased_tbl + display_suffix + maybe_alias display = maybe_schema + cased_tbl + display_suffix + maybe_alias
prio2 = 0 if tbl.schema else 1 prio2 = 0 if tbl.schema else 1
return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
def get_table_matches(self, suggestion, word_before_cursor, alias=False): def get_table_matches(self, suggestion, word_before_cursor, alias=False):
tables = self.populate_schema_objects(suggestion.schema, 'tables') tables = self.populate_schema_objects(suggestion.schema, "tables")
tables.extend( tables.extend(
SchemaObject(tbl.name) for tbl in suggestion.local_tables) SchemaObject(tbl.name) for tbl in suggestion.local_tables)
# Unless we're sure the user really wants them, don't suggest the # Unless we're sure the user really wants them, don't suggest the
# pg_catalog tables that are implicitly on the search path # pg_catalog tables that are implicitly on the search path
if not suggestion.schema and ( if not suggestion.schema and \
not word_before_cursor.startswith('pg_')): (not word_before_cursor.startswith("pg_")):
tables = [t for t in tables if not t.name.startswith('pg_')] tables = [t for t in tables if not t.name.startswith("pg_")]
tables = [self._make_cand(t, alias, suggestion) for t in tables] tables = [self._make_cand(t, alias, suggestion) for t in tables]
return self.find_matches(word_before_cursor, tables, return self.find_matches(word_before_cursor, tables, meta="table")
mode='strict', meta='table')
def get_view_matches(self, suggestion, word_before_cursor, alias=False): def get_view_matches(self, suggestion, word_before_cursor, alias=False):
views = self.populate_schema_objects(suggestion.schema, 'views') views = self.populate_schema_objects(suggestion.schema, "views")
if not suggestion.schema and ( if not suggestion.schema and (
not word_before_cursor.startswith('pg_')): not word_before_cursor.startswith("pg_")):
views = [v for v in views if not v.name.startswith('pg_')] views = [v for v in views if not v.name.startswith("pg_")]
views = [self._make_cand(v, alias, suggestion) for v in views] views = [self._make_cand(v, alias, suggestion) for v in views]
return self.find_matches(word_before_cursor, views, return self.find_matches(word_before_cursor, views, meta="view")
mode='strict', meta='view')
def get_alias_matches(self, suggestion, word_before_cursor): def get_alias_matches(self, suggestion, word_before_cursor):
aliases = suggestion.aliases aliases = suggestion.aliases
return self.find_matches(word_before_cursor, aliases, return self.find_matches(word_before_cursor, aliases,
mode='strict', meta='table alias') meta="table alias")
def get_database_matches(self, _, word_before_cursor): def get_database_matches(self, _, word_before_cursor):
return self.find_matches(word_before_cursor, self.databases, return self.find_matches(word_before_cursor, self.databases,
mode='strict', meta='database') meta="database")
def get_keyword_matches(self, suggestion, word_before_cursor): def get_keyword_matches(self, suggestion, word_before_cursor):
return self.find_matches(word_before_cursor, self.keywords, return self.find_matches(word_before_cursor, self.keywords,
mode='strict', meta='keyword') meta="keyword")
def get_datatype_matches(self, suggestion, word_before_cursor): def get_datatype_matches(self, suggestion, word_before_cursor):
# suggest custom datatypes # suggest custom datatypes
types = self.populate_schema_objects(suggestion.schema, 'datatypes') types = self.populate_schema_objects(suggestion.schema, "datatypes")
types = [self._make_cand(t, False, suggestion) for t in types] types = [self._make_cand(t, False, suggestion) for t in types]
matches = self.find_matches(word_before_cursor, types, matches = self.find_matches(word_before_cursor, types, meta="datatype")
mode='strict', meta='datatype')
return matches return matches
def get_word_before_cursor(self, word=False): def get_word_before_cursor(self, word=False):
@ -1004,52 +987,6 @@ class SQLAutoComplete(object):
Datatype: get_datatype_matches, Datatype: get_datatype_matches,
} }
def addcols(self, schema, rel, alias, reltype, cols, columns):
    """
    Record *cols* under the TableReference built from the given parts.

    :param schema: schema name for the reference (or None).
    :param rel: relation (table/view/function) name.
    :param alias: alias used for the relation in the query, if any.
    :param reltype: relation kind string; 'functions' marks the
        reference as a function.
    :param cols: iterable of column metadata to record.
    :param columns: dict mapping TableReference -> list of columns,
        updated in place.
    :return: None
    """
    ref = TableReference(schema, rel, alias, reltype == 'functions')
    columns.setdefault(ref, []).extend(cols)
def _get_schema_columns(self, schemas, tbl, meta, columns):
"""
Check and add schema table columns as per table.
:param schemas: Schema
:param tbl:
:param meta:
:param columns: column list
:return:
"""
for schema in schemas:
relname = self.escape_name(tbl.name)
schema = self.escape_name(schema)
if tbl.is_function:
# Return column names from a set-returning function
# Get an array of FunctionMetadata objects
functions = meta['functions'].get(schema, {}).get(relname)
for func in (functions or []):
# func is a FunctionMetadata object
cols = func.fields()
self.addcols(schema, relname, tbl.alias, 'functions', cols,
columns)
else:
for reltype in ('tables', 'views'):
cols = meta[reltype].get(schema, {}).get(relname)
if cols:
cols = cols.values()
self.addcols(schema, relname, tbl.alias, reltype, cols,
columns)
break
def populate_scoped_cols(self, scoped_tbls, local_tbls=()): def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
"""Find all columns in a set of scoped_tables. """Find all columns in a set of scoped_tables.
@ -1062,14 +999,37 @@ class SQLAutoComplete(object):
columns = OrderedDict() columns = OrderedDict()
meta = self.dbmetadata meta = self.dbmetadata
def addcols(schema, rel, alias, reltype, cols):
tbl = TableReference(schema, rel, alias, reltype == "functions")
if tbl not in columns:
columns[tbl] = []
columns[tbl].extend(cols)
for tbl in scoped_tbls: for tbl in scoped_tbls:
# Local tables should shadow database tables # Local tables should shadow database tables
if tbl.schema is None and normalize_ref(tbl.name) in ctes: if tbl.schema is None and normalize_ref(tbl.name) in ctes:
cols = ctes[normalize_ref(tbl.name)] cols = ctes[normalize_ref(tbl.name)]
self.addcols(None, tbl.name, 'CTE', tbl.alias, cols, columns) addcols(None, tbl.name, "CTE", tbl.alias, cols)
continue continue
schemas = [tbl.schema] if tbl.schema else self.search_path schemas = [tbl.schema] if tbl.schema else self.search_path
self._get_schema_columns(schemas, tbl, meta, columns) for schema in schemas:
relname = self.escape_name(tbl.name)
schema = self.escape_name(schema)
if tbl.is_function:
# Return column names from a set-returning function
# Get an array of FunctionMetadata objects
functions = meta["functions"].get(schema, {}).get(relname)
for func in functions or []:
# func is a FunctionMetadata object
cols = func.fields()
addcols(schema, relname, tbl.alias, "functions", cols)
else:
for reltype in ("tables", "views"):
cols = meta[reltype].get(schema, {}).get(relname)
if cols:
cols = cols.values()
addcols(schema, relname, tbl.alias, reltype, cols)
break
return columns return columns
@ -1125,10 +1085,10 @@ class SQLAutoComplete(object):
SchemaObject( SchemaObject(
name=func, name=func,
schema=(self._maybe_schema(schema=sch, parent=schema)), schema=(self._maybe_schema(schema=sch, parent=schema)),
meta=meta meta=meta,
) )
for sch in self._get_schemas('functions', schema) for sch in self._get_schemas("functions", schema)
for (func, metas) in self.dbmetadata['functions'][sch].items() for (func, metas) in self.dbmetadata["functions"][sch].items()
for meta in metas for meta in metas
if filter_func(meta) if filter_func(meta)
] ]
@ -1234,6 +1194,7 @@ class SQLAutoComplete(object):
row['is_aggregate'], row['is_aggregate'],
row['is_window'], row['is_window'],
row['is_set_returning'], row['is_set_returning'],
row['is_extension'],
row['arg_defaults'].strip('{}').split(',') row['arg_defaults'].strip('{}').split(',')
if row['arg_defaults'] is not None if row['arg_defaults'] is not None
else row['arg_defaults'] else row['arg_defaults']

View File

@ -1,7 +1,6 @@
""" """
Using Completion class from Using Completion class from
https://github.com/jonathanslenders/python-prompt-toolkit/ https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/completion/base.py
blob/master/prompt_toolkit/completion.py
""" """
__all__ = ( __all__ = (
@ -38,7 +37,7 @@ class Completion(object):
assert self.start_position <= 0 assert self.start_position <= 0
def __repr__(self): def __repr__(self):
return '%s(text=%r, start_position=%r)' % ( return "%s(text=%r, start_position=%r)" % (
self.__class__.__name__, self.text, self.start_position) self.__class__.__name__, self.text, self.start_position)
def __eq__(self, other): def __eq__(self, other):

View File

@ -1,295 +0,0 @@
import re
from collections import namedtuple
import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error
cleanup_regex = {
    # This matches only alphanumerics and underscores.
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # This matches everything except spaces, parens, colon, and comma
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # This matches everything except spaces, parens, colon, comma, and period
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # This matches everything except a space.
    # NOTE: must be a raw string -- '\s' in a plain literal is an invalid
    # escape sequence (SyntaxWarning on Python 3.12+).
    'all_punctuations': re.compile(r'([^\s]+)$'),
}


def last_word(text, include='alphanum_underscore'):
    r"""
    Find the last word in a sentence.

    :param text: the sentence to inspect.
    :param include: key into ``cleanup_regex`` selecting which characters
        count as part of a word.
    :return: the last word, or '' if there is none.

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word('bac \\def', include='most_punctuations')
    '\\def'
    >>> last_word('bac \\def;', include='most_punctuations')
    '\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    >>> last_word('"foo*bar', include='most_punctuations')
    '"foo*bar'
    """
    if not text:  # Empty string
        return ''
    if text[-1].isspace():
        return ''
    matches = cleanup_regex[include].search(text)
    return matches.group(0) if matches else ''
# A parsed reference to a relation in a query: schema-qualified name,
# optional alias, and whether it refers to a function call.
TableReference = namedtuple(
    'TableReference', 'schema name alias is_function'
)

# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
    """Return True if *parsed* is a grouped token whose first-level tokens
    contain a DML keyword (i.e. it opens a sub-statement).

    :param parsed: a sqlparse token / TokenList.
    """
    # sqlparse >= 0.2.3 exposes is_group as a property; calling it as a
    # method raises "TypeError: 'bool' object is not callable".
    if not parsed.is_group:
        return False
    sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
    for item in parsed.tokens:
        if item.ttype is DML and item.value.upper() in sql_type:
            return True
    return False
def _identifier_is_function(identifier):
    """Return True if *identifier* contains a sqlparse Function token."""
    for token in identifier.tokens:
        if isinstance(token, Function):
            return True
    return False
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens that follow a table-introducing keyword
    (FROM/JOIN/INTO/UPDATE/COPY/TABLE) in *parsed*.

    :param parsed: a sqlparse statement (TokenList).
    :param stop_at_punctuation: stop scanning at the first punctuation
        token after the table prefix (needed for INSERT statements, whose
        parenthesized column list would otherwise be yielded too).
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                # PEP 479: raising StopIteration inside a generator becomes
                # a RuntimeError on Python 3.7+, so end with a plain return.
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This
            # causes the second FROM to trigger this elif condition resulting
            # in an early end of the scan. So we need to ignore the keyword
            # if the keyword is FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its
            # variants INNER JOIN, FULL OUTER JOIN, etc.
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and (
                    not item.value.upper().endswith('JOIN')):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
                    item_val.endswith('JOIN')):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
def extract_table_identifiers(token_stream, allow_functions=True):
    """yields tuples of TableReference namedtuples

    :param token_stream: iterable of sqlparse tokens, normally the output
        of extract_from_part().
    :param allow_functions: when False, no reference is ever marked as a
        function (used for INSERT statements -- see extract_tables).
    """
    for item in token_stream:
        if isinstance(item, IdentifierList):
            # Comma-separated list, e.g. 'FROM foo f, bar b'.
            for identifier in item.get_identifiers():
                # Sometimes Keywords (such as FROM ) are classified as
                # identifiers which don't have the get_real_name() method.
                try:
                    schema_name = identifier.get_parent_name()
                    real_name = identifier.get_real_name()
                    is_function = (allow_functions and
                                   _identifier_is_function(identifier))
                except AttributeError:
                    continue
                if real_name:
                    yield TableReference(
                        schema_name, real_name, identifier.get_alias(),
                        is_function
                    )
        elif isinstance(item, Identifier):
            # A single identifier, e.g. 'FROM schema.foo AS f'.
            real_name = item.get_real_name()
            schema_name = item.get_parent_name()
            is_function = allow_functions and _identifier_is_function(item)
            if real_name:
                yield TableReference(
                    schema_name, real_name, item.get_alias(), is_function
                )
            else:
                # No resolvable real name; fall back to the display name.
                name = item.get_name()
                yield TableReference(
                    None, name, item.get_alias() or name, is_function
                )
        elif isinstance(item, Function):
            # A bare function call used as a table source,
            # e.g. 'FROM generate_series(1, 10)'.
            yield TableReference(
                None, item.get_real_name(), item.get_alias(), allow_functions
            )
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    :param sql: the SQL text to analyze.
    :return: tuple of TableReference namedtuples.
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    # INSERT statements must stop looking for tables at the first
    # punctuation sign, e.g. INSERT INTO abc (col1, col2) VALUES (1, 2):
    # abc is the table, but without stopping at the first lparen we would
    # also pick up col1 and col2 as table names.
    insert_stmt = parsed[0].token_first().value.lower() == 'insert'
    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)

    # Kludge: sqlparse mistakenly treats the parenthesized column list of
    # an insert as a function call (interpreting "insert into foo (bar,
    # baz)" as foo(bar, baz)), so identifiers in insert statements must
    # never be flagged as functions.
    return tuple(
        extract_table_identifiers(stream, allow_functions=not insert_stmt)
    )
def find_prev_keyword(sql):
    """Find the last SQL keyword in an SQL statement.

    :param sql: the SQL text to scan.
    :return: (token, text) where token is the last keyword token (or None)
        and text is the query with everything after that keyword stripped.
    """
    if not sql.strip():
        return None, ''

    flattened = list(sqlparse.parse(sql)[0].flatten())
    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')

    for tok in reversed(flattened):
        is_wanted_keyword = tok.is_keyword and \
            tok.value.upper() not in logical_operators
        if not (tok.value == '(' or is_wanted_keyword):
            continue
        # Find the location of token tok in the original parsed statement.
        # parsed.token_index(tok) cannot be used because tok may be a child
        # token inside a TokenList, in which case token_index throws an
        # error. Minimal example:
        #   p = sqlparse.parse('select * from foo where bar')
        #   t = list(p.flatten())[-3]  # The "Where" token
        #   p.token_index(t)  # Throws ValueError: not in list
        idx = flattened.index(tok)
        # Combine the string values of all tokens up to and including the
        # target keyword token, producing a query string with everything
        # after the keyword removed.
        text = ''.join(t.value for t in flattened[:idx + 1])
        return tok, text

    return None, ''
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')


def is_open_quote(sql):
    """Return True if *sql* contains an unclosed quote."""
    # sql can contain one or more semicolon-separated commands; the input
    # counts as open if any one of them has an unterminated quote.
    for statement in sqlparse.parse(sql):
        if _parsed_is_open_quote(statement):
            return True
    return False
def _parsed_is_open_quote(parsed):
    """Return True if a single parsed statement contains an unclosed
    single quote or an unterminated dollar-quoted string."""
    tokens = list(parsed.flatten())
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.match(Token.Error, "'"):
            # An unmatched single quote
            return True
        elif (tok.ttype in Token.Name.Builtin and
                dollar_quote_regex.match(tok.value)):
            # Find the matching closing dollar quote sign
            for (j, tok2) in enumerate(tokens[i + 1:], i + 1):
                if tok2.match(Token.Name.Builtin, tok.value):
                    # Found the matching closing quote - continue our scan
                    # for open quotes thereafter
                    # (i advances to j here, then the i += 1 below moves
                    # past the closing tag itself)
                    i = j
                    break
            else:
                # No matching dollar sign quote
                return True
        i += 1
    return False
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier.

    *word* may include a schema qualification, like
    `schema_name.partial_name` or `schema_name.`. There may also be
    unclosed quotation marks, like `"schema`, or `schema."partial_name`.

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """
    parsed = sqlparse.parse(word)[0]
    tokens = parsed.tokens
    if len(tokens) == 1 and isinstance(tokens[0], Identifier):
        return tokens[0]
    # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar':
    # close the double quote, then reparse.  (token_next_match only exists
    # on old sqlparse releases; the hasattr guard skips this branch on
    # newer versions.)
    if hasattr(parsed, 'token_next_match') and \
            parsed.token_next_match(0, Error, '"'):
        return parse_partial_identifier(word + '"')
    return None
if __name__ == '__main__':
    # Ad-hoc smoke test: parse an incomplete nested select.
    print(extract_tables('select * from (select t. from tabl t'))

View File

@ -4,7 +4,7 @@ import sqlparse
def query_starts_with(query, prefixes): def query_starts_with(query, prefixes):
"""Check if the query starts with any item from *prefixes*.""" """Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes] prefixes = [prefix.lower() for prefix in prefixes]
formatted_sql = sqlparse.format(query.lower(), strip_comments=True) formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
@ -18,5 +18,5 @@ def queries_start_with(queries, prefixes):
def is_destructive(queries): def is_destructive(queries):
"""Returns if any of the queries in *queries* is destructive.""" """Returns if any of the queries in *queries* is destructive."""
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') keywords = ("drop", "shutdown", "delete", "truncate", "alter")
return queries_start_with(queries, keywords) return queries_start_with(queries, keywords)

View File

@ -10,12 +10,11 @@ from .meta import TableMetadata, ColumnMetadata
# columns: list of column names # columns: list of column names
# start: index into the original string of the left parens starting the CTE # start: index into the original string of the left parens starting the CTE
# stop: index into the original string of the right parens ending the CTE # stop: index into the original string of the right parens ending the CTE
TableExpression = namedtuple('TableExpression', 'name columns start stop') TableExpression = namedtuple("TableExpression", "name columns start stop")
def isolate_query_ctes(full_text, text_before_cursor): def isolate_query_ctes(full_text, text_before_cursor):
"""Simplify a query by converting CTEs into table metadata objects """Simplify a query by converting CTEs into table metadata objects"""
"""
if not full_text: if not full_text:
return full_text, text_before_cursor, tuple() return full_text, text_before_cursor, tuple()
@ -30,8 +29,8 @@ def isolate_query_ctes(full_text, text_before_cursor):
for cte in ctes: for cte in ctes:
if cte.start < current_position < cte.stop: if cte.start < current_position < cte.stop:
# Currently editing a cte - treat its body as the current full_text # Currently editing a cte - treat its body as the current full_text
text_before_cursor = full_text[cte.start:current_position] text_before_cursor = full_text[cte.start: current_position]
full_text = full_text[cte.start:cte.stop] full_text = full_text[cte.start: cte.stop]
return full_text, text_before_cursor, meta return full_text, text_before_cursor, meta
# Append this cte to the list of available table metadata # Append this cte to the list of available table metadata
@ -40,19 +39,19 @@ def isolate_query_ctes(full_text, text_before_cursor):
# Editing past the last cte (ie the main body of the query) # Editing past the last cte (ie the main body of the query)
full_text = full_text[ctes[-1].stop:] full_text = full_text[ctes[-1].stop:]
text_before_cursor = text_before_cursor[ctes[-1].stop:current_position] text_before_cursor = text_before_cursor[ctes[-1].stop: current_position]
return full_text, text_before_cursor, tuple(meta) return full_text, text_before_cursor, tuple(meta)
def extract_ctes(sql): def extract_ctes(sql):
""" Extract constant table expresseions from a query """Extract constant table expresseions from a query
Returns tuple (ctes, remainder_sql) Returns tuple (ctes, remainder_sql)
ctes is a list of TableExpression namedtuples ctes is a list of TableExpression namedtuples
remainder_sql is the text from the original query after the CTEs have remainder_sql is the text from the original query after the CTEs have
been stripped. been stripped.
""" """
p = parse(sql)[0] p = parse(sql)[0]
@ -66,7 +65,7 @@ def extract_ctes(sql):
# Get the next (meaningful) token, which should be the first CTE # Get the next (meaningful) token, which should be the first CTE
idx, tok = p.token_next(idx) idx, tok = p.token_next(idx)
if not tok: if not tok:
return ([], '') return ([], "")
start_pos = token_start_pos(p.tokens, idx) start_pos = token_start_pos(p.tokens, idx)
ctes = [] ctes = []
@ -87,7 +86,7 @@ def extract_ctes(sql):
idx = p.token_index(tok) + 1 idx = p.token_index(tok) + 1
# Collapse everything after the ctes into a remainder query # Collapse everything after the ctes into a remainder query
remainder = ''.join(str(tok) for tok in p.tokens[idx:]) remainder = "".join(str(tok) for tok in p.tokens[idx:])
return ctes, remainder return ctes, remainder
@ -112,15 +111,15 @@ def get_cte_from_token(tok, pos0):
def extract_column_names(parsed): def extract_column_names(parsed):
# Find the first DML token to check if it's a SELECT or # Find the first DML token to check if it's a
# INSERT/UPDATE/DELETE # SELECT or INSERT/UPDATE/DELETE
idx, tok = parsed.token_next_by(t=DML) idx, tok = parsed.token_next_by(t=DML)
tok_val = tok and tok.value.lower() tok_val = tok and tok.value.lower()
if tok_val in ('insert', 'update', 'delete'): if tok_val in ("insert", "update", "delete"):
# Jump ahead to the RETURNING clause where the list of column names is # Jump ahead to the RETURNING clause where the list of column names is
idx, tok = parsed.token_next_by(idx, (Keyword, 'returning')) idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
elif not tok_val == 'select': elif not tok_val == "select":
# Must be invalid CTE # Must be invalid CTE
return () return ()

View File

@ -1,23 +1,29 @@
from collections import namedtuple from collections import namedtuple
_ColumnMetadata = namedtuple( _ColumnMetadata = namedtuple(
'ColumnMetadata', "ColumnMetadata", ["name", "datatype", "foreignkeys", "default",
['name', 'datatype', 'foreignkeys', 'default', 'has_default'] "has_default"]
) )
def ColumnMetadata( def ColumnMetadata(name, datatype, foreignkeys=None, default=None,
name, datatype, foreignkeys=None, default=None, has_default=False has_default=False):
): return _ColumnMetadata(name, datatype, foreignkeys or [], default,
return _ColumnMetadata( has_default)
name, datatype, foreignkeys or [], default, has_default
)
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable', ForeignKey = namedtuple(
'parentcolumn', 'childschema', "ForeignKey",
'childtable', 'childcolumn']) [
TableMetadata = namedtuple('TableMetadata', 'name columns') "parentschema",
"parenttable",
"parentcolumn",
"childschema",
"childtable",
"childcolumn",
],
)
TableMetadata = namedtuple("TableMetadata", "name columns")
def parse_defaults(defaults_string): def parse_defaults(defaults_string):
@ -25,34 +31,42 @@ def parse_defaults(defaults_string):
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
if not defaults_string: if not defaults_string:
return return
current = '' current = ""
in_quote = None in_quote = None
for char in defaults_string: for char in defaults_string:
if current == '' and char == ' ': if current == "" and char == " ":
# Skip space after comma separating default expressions # Skip space after comma separating default expressions
continue continue
if char == '"' or char == '\'': if char == '"' or char == "'":
if in_quote and char == in_quote: if in_quote and char == in_quote:
# End quote # End quote
in_quote = None in_quote = None
elif not in_quote: elif not in_quote:
# Begin quote # Begin quote
in_quote = char in_quote = char
elif char == ',' and not in_quote: elif char == "," and not in_quote:
# End of expression # End of expression
yield current yield current
current = '' current = ""
continue continue
current += char current += char
yield current yield current
class FunctionMetadata(object): class FunctionMetadata(object):
def __init__( def __init__(
self, schema_name, func_name, arg_names, arg_types, arg_modes, self,
return_type, is_aggregate, is_window, is_set_returning, schema_name,
arg_defaults func_name,
arg_names,
arg_types,
arg_modes,
return_type,
is_aggregate,
is_window,
is_set_returning,
is_extension,
arg_defaults,
): ):
"""Class for describing a postgresql function""" """Class for describing a postgresql function"""
@ -80,19 +94,29 @@ class FunctionMetadata(object):
self.is_aggregate = is_aggregate self.is_aggregate = is_aggregate
self.is_window = is_window self.is_window = is_window
self.is_set_returning = is_set_returning self.is_set_returning = is_set_returning
self.is_extension = bool(is_extension)
self.is_public = self.schema_name and self.schema_name == "public"
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, self.__class__) and return isinstance(other, self.__class__) and \
self.__dict__ == other.__dict__) self.__dict__ == other.__dict__
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def _signature(self): def _signature(self):
return ( return (
self.schema_name, self.func_name, self.arg_names, self.arg_types, self.schema_name,
self.arg_modes, self.return_type, self.is_aggregate, self.func_name,
self.is_window, self.is_set_returning, self.arg_defaults self.arg_names,
self.arg_types,
self.arg_modes,
self.return_type,
self.is_aggregate,
self.is_window,
self.is_set_returning,
self.is_extension,
self.arg_defaults,
) )
def __hash__(self): def __hash__(self):
@ -100,26 +124,25 @@ class FunctionMetadata(object):
def __repr__(self): def __repr__(self):
return ( return (
( "%s(schema_name=%r, func_name=%r, arg_names=%r, "
'%s(schema_name=%r, func_name=%r, arg_names=%r, ' "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' "is_window=%r, is_set_returning=%r, is_extension=%r, "
'is_window=%r, is_set_returning=%r, arg_defaults=%r)' "arg_defaults=%r)"
) % (self.__class__.__name__,) + self._signature() ) % ((self.__class__.__name__,) + self._signature())
)
def has_variadic(self): def has_variadic(self):
return self.arg_modes and any( return self.arg_modes and \
arg_mode == 'v' for arg_mode in self.arg_modes) any(arg_mode == "v" for arg_mode in self.arg_modes)
def args(self): def args(self):
"""Returns a list of input-parameter ColumnMetadata namedtuples.""" """Returns a list of input-parameter ColumnMetadata namedtuples."""
if not self.arg_names: if not self.arg_names:
return [] return []
modes = self.arg_modes or ['i'] * len(self.arg_names) modes = self.arg_modes or ["i"] * len(self.arg_names)
args = [ args = [
(name, typ) (name, typ)
for name, typ, mode in zip(self.arg_names, self.arg_types, modes) for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
] ]
def arg(name, typ, num): def arg(name, typ, num):
@ -127,7 +150,8 @@ class FunctionMetadata(object):
num_defaults = len(self.arg_defaults) num_defaults = len(self.arg_defaults)
has_default = num + num_defaults >= num_args has_default = num + num_defaults >= num_args
default = ( default = (
self.arg_defaults[num - num_args + num_defaults] if has_default self.arg_defaults[num - num_args + num_defaults]
if has_default
else None else None
) )
return ColumnMetadata(name, typ, [], default, has_default) return ColumnMetadata(name, typ, [], default, has_default)
@ -137,7 +161,7 @@ class FunctionMetadata(object):
def fields(self): def fields(self):
"""Returns a list of output-field ColumnMetadata namedtuples""" """Returns a list of output-field ColumnMetadata namedtuples"""
if self.return_type.lower() == 'void': if self.return_type.lower() == "void":
return [] return []
elif not self.arg_modes: elif not self.arg_modes:
# For functions without output parameters, the function name # For functions without output parameters, the function name
@ -145,7 +169,9 @@ class FunctionMetadata(object):
# E.g. 'SELECT unnest FROM unnest(...);' # E.g. 'SELECT unnest FROM unnest(...);'
return [ColumnMetadata(self.func_name, self.return_type, [])] return [ColumnMetadata(self.func_name, self.return_type, [])]
return [ColumnMetadata(name, typ, []) return [
for name, typ, mode in zip( ColumnMetadata(name, typ, [])
self.arg_names, self.arg_types, self.arg_modes) for name, typ, mode in zip(self.arg_names, self.arg_types,
if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE self.arg_modes)
if mode in ("o", "b", "t")
] # OUT, INOUT, TABLE

View File

@ -3,12 +3,15 @@ from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation from sqlparse.tokens import Keyword, DML, Punctuation
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', TableReference = namedtuple(
'is_function']) "TableReference", ["schema", "name", "alias", "is_function"]
)
TableReference.ref = property( TableReference.ref = property(
lambda self: self.alias or ( lambda self: self.alias or (
self.name if self.name.islower() or self.name[0] == '"' self.name
else '"' + self.name + '"') if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"'
)
) )
@ -18,9 +21,13 @@ def is_subselect(parsed):
if not parsed.is_group: if not parsed.is_group:
return False return False
for item in parsed.tokens: for item in parsed.tokens:
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', if item.ttype is DML and item.value.upper() in (
'UPDATE', 'CREATE', "SELECT",
'DELETE'): "INSERT",
"UPDATE",
"CREATE",
"DELETE",
):
return True return True
return False return False
@ -37,32 +44,42 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for x in extract_from_part(item, stop_at_punctuation): for x in extract_from_part(item, stop_at_punctuation):
yield x yield x
elif stop_at_punctuation and item.ttype is Punctuation: elif stop_at_punctuation and item.ttype is Punctuation:
raise StopIteration return
# An incomplete nested select won't be recognized correctly as a # An incomplete nested select won't be recognized correctly as a
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
# the second FROM to trigger this elif condition resulting in a # the second FROM to trigger this elif condition resulting in a
# StopIteration. So we need to ignore the keyword if the keyword # `return`. So we need to ignore the keyword if the keyword
# FROM. # FROM.
# Also 'SELECT * FROM abc JOIN def' will trigger this elif # Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants # condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc. # INNER JOIN, FULL OUTER JOIN, etc.
elif item.ttype is Keyword and ( elif (
not item.value.upper() == 'FROM') and ( item.ttype is Keyword and
not item.value.upper().endswith('JOIN')): (not item.value.upper() == "FROM") and
(not item.value.upper().endswith("JOIN"))
):
tbl_prefix_seen = False tbl_prefix_seen = False
else: else:
yield item yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML: elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper() item_val = item.value.upper()
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or if (
item_val.endswith('JOIN')): item_val
in (
"COPY",
"FROM",
"INTO",
"UPDATE",
"TABLE",
) or item_val.endswith("JOIN")
):
tbl_prefix_seen = True tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list. # 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary. # So this check here is necessary.
elif isinstance(item, IdentifierList): elif isinstance(item, IdentifierList):
for identifier in item.get_identifiers(): for identifier in item.get_identifiers():
if (identifier.ttype is Keyword and if identifier.ttype is Keyword and \
identifier.value.upper() == 'FROM'): identifier.value.upper() == "FROM":
tbl_prefix_seen = True tbl_prefix_seen = True
break break
@ -94,29 +111,35 @@ def extract_table_identifiers(token_stream, allow_functions=True):
name = name.lower() name = name.lower()
return schema_name, name, alias return schema_name, name, alias
for item in token_stream: try:
if isinstance(item, IdentifierList): for item in token_stream:
for identifier in item.get_identifiers(): if isinstance(item, IdentifierList):
# Sometimes Keywords (such as FROM ) are classified as for identifier in item.get_identifiers():
# identifiers which don't have the get_real_name() method. # Sometimes Keywords (such as FROM ) are classified as
try: # identifiers which don't have the get_real_name() method.
schema_name = identifier.get_parent_name() try:
real_name = identifier.get_real_name() schema_name = identifier.get_parent_name()
is_function = (allow_functions and real_name = identifier.get_real_name()
_identifier_is_function(identifier)) is_function = allow_functions and \
except AttributeError: _identifier_is_function(identifier)
continue except AttributeError:
if real_name: continue
yield TableReference(schema_name, real_name, if real_name:
identifier.get_alias(), is_function) yield TableReference(
elif isinstance(item, Identifier): schema_name, real_name, identifier.get_alias(),
schema_name, real_name, alias = parse_identifier(item) is_function
is_function = allow_functions and _identifier_is_function(item) )
elif isinstance(item, Identifier):
schema_name, real_name, alias = parse_identifier(item)
is_function = allow_functions and _identifier_is_function(item)
yield TableReference(schema_name, real_name, alias, is_function) yield TableReference(schema_name, real_name, alias,
elif isinstance(item, Function): is_function)
schema_name, real_name, alias = parse_identifier(item) elif isinstance(item, Function):
yield TableReference(None, real_name, alias, allow_functions) schema_name, real_name, alias = parse_identifier(item)
yield TableReference(None, real_name, alias, allow_functions)
except StopIteration:
return
# extract_tables is inspired from examples in the sqlparse lib. # extract_tables is inspired from examples in the sqlparse lib.
@ -134,7 +157,7 @@ def extract_tables(sql):
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
# abc is the table name, but if we don't stop at the first lparen, then # abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names. # we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == 'insert' insert_stmt = parsed[0].token_first().value.lower() == "insert"
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
# Kludge: sqlparse mistakenly identifies insert statements as # Kludge: sqlparse mistakenly identifies insert statements as

View File

@ -5,17 +5,17 @@ from sqlparse.tokens import Token, Error
cleanup_regex = { cleanup_regex = {
# This matches only alphanumerics and underscores. # This matches only alphanumerics and underscores.
'alphanum_underscore': re.compile(r'(\w+)$'), "alphanum_underscore": re.compile(r"(\w+)$"),
# This matches everything except spaces, parens, colon, and comma # This matches everything except spaces, parens, colon, and comma
'many_punctuations': re.compile(r'([^():,\s]+)$'), "many_punctuations": re.compile(r"([^():,\s]+)$"),
# This matches everything except spaces, parens, colon, comma, and period # This matches everything except spaces, parens, colon, comma, and period
'most_punctuations': re.compile(r'([^\.():,\s]+)$'), "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
# This matches everything except a space. # This matches everything except a space.
'all_punctuations': re.compile(r'([^\s]+)$'), "all_punctuations": re.compile(r"([^\s]+)$"),
} }
def last_word(text, include='alphanum_underscore'): def last_word(text, include="alphanum_underscore"):
r""" r"""
Find the last word in a sentence. Find the last word in a sentence.
@ -49,41 +49,42 @@ def last_word(text, include='alphanum_underscore'):
'"foo*bar' '"foo*bar'
""" """
if not text: # Empty string if not text: # Empty string
return '' return ""
if text[-1].isspace(): if text[-1].isspace():
return '' return ""
else: else:
regex = cleanup_regex[include] regex = cleanup_regex[include]
matches = regex.search(text) matches = regex.search(text)
if matches: if matches:
return matches.group(0) return matches.group(0)
else: else:
return '' return ""
def find_prev_keyword(sql, n_skip=0): def find_prev_keyword(sql, n_skip=0):
""" Find the last sql keyword in an SQL statement """Find the last sql keyword in an SQL statement
Returns the value of the last keyword, and the text of the query with Returns the value of the last keyword, and the text of the query with
everything after the last keyword stripped everything after the last keyword stripped
""" """
if not sql.strip(): if not sql.strip():
return None, '' return None, ""
parsed = sqlparse.parse(sql)[0] parsed = sqlparse.parse(sql)[0]
flattened = list(parsed.flatten()) flattened = list(parsed.flatten())
flattened = flattened[:len(flattened) - n_skip] flattened = flattened[: len(flattened) - n_skip]
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') logical_operators = ("AND", "OR", "NOT", "BETWEEN")
for t in reversed(flattened): for t in reversed(flattened):
if t.value == '(' or (t.is_keyword and ( if t.value == "(" or (
t.value.upper() not in logical_operators)): t.is_keyword and (t.value.upper() not in logical_operators)
):
# Find the location of token t in the original parsed statement # Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token # We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index thows an error # inside a TokenList, in which case token_index throws an error
# Minimal example: # Minimal example:
# p = sqlparse.parse('select * from foo where bar') # p = sqlparse.parse('select * from foo where bar')
# t = list(p.flatten())[-3] # The "Where" token # t = list(p.flatten())[-3] # The "Where" token
@ -93,14 +94,14 @@ def find_prev_keyword(sql, n_skip=0):
# Combine the string values of all tokens in the original list # Combine the string values of all tokens in the original list
# up to and including the target keyword token t, to produce a # up to and including the target keyword token t, to produce a
# query string with everything after the keyword token removed # query string with everything after the keyword token removed
text = ''.join(tok.value for tok in flattened[:idx + 1]) text = "".join(tok.value for tok in flattened[: idx + 1])
return t, text return t, text
return None, '' return None, ""
# Postgresql dollar quote signs look like `$$` or `$tag$` # Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$') dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
def is_open_quote(sql): def is_open_quote(sql):

View File

@ -4,13 +4,13 @@ from sqlparse.tokens import Name
from collections import defaultdict from collections import defaultdict
white_space_regex = re.compile(r'\\s+', re.MULTILINE) white_space_regex = re.compile("\\s+", re.MULTILINE)
def _compile_regex(keyword): def _compile_regex(keyword):
# Surround the keyword with word boundaries and replace interior whitespace # Surround the keyword with word boundaries and replace interior whitespace
# with whitespace wildcards # with whitespace wildcards
pattern = r'\\b' + white_space_regex.sub(r'\\s+', keyword) + r'\\b' pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
return re.compile(pattern, re.MULTILINE | re.IGNORECASE) return re.compile(pattern, re.MULTILINE | re.IGNORECASE)

View File

@ -3,28 +3,29 @@ import re
import sqlparse import sqlparse
from collections import namedtuple from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where from sqlparse.sql import Comparison, Identifier, Where
from .parseutils.utils import ( from .parseutils.utils import last_word, find_prev_keyword,\
last_word, find_prev_keyword, parse_partial_identifier) parse_partial_identifier
from .parseutils.tables import extract_tables from .parseutils.tables import extract_tables
from .parseutils.ctes import isolate_query_ctes from .parseutils.ctes import isolate_query_ctes
Special = namedtuple('Special', [])
Database = namedtuple('Database', []) Special = namedtuple("Special", [])
Schema = namedtuple('Schema', ['quoted']) Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,) Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause # FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement, # `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique # used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables') FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables']) Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple('TableFormat', []) TableFormat = namedtuple("TableFormat", [])
View = namedtuple('View', ['schema', 'table_refs']) View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid' # JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent']) JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' # Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple('Join', ['table_refs', 'schema']) Join = namedtuple("Join", ["table_refs", "schema"])
Function = namedtuple('Function', ['schema', 'table_refs', 'usage']) Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor # For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None) Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple()) Table.__new__.__defaults__ = (None, tuple(), tuple())
@ -32,31 +33,33 @@ View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
Column = namedtuple( Column = namedtuple(
'Column', "Column",
['table_refs', 'require_last_table', 'local_tables', ["table_refs", "require_last_table", "local_tables", "qualifiable",
'qualifiable', 'context'] "context"],
) )
Column.__new__.__defaults__ = (None, None, tuple(), False, None) Column.__new__.__defaults__ = (None, None, tuple(), False, None)
Keyword = namedtuple('Keyword', ['last_token']) Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,) Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple('NamedQuery', []) NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple('Datatype', ['schema']) Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple('Alias', ['aliases']) Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple('Path', []) Path = namedtuple("Path", [])
class SqlStatement(object): class SqlStatement(object):
def __init__(self, full_text, text_before_cursor): def __init__(self, full_text, text_before_cursor):
self.identifier = None self.identifier = None
self.word_before_cursor = word_before_cursor = last_word( self.word_before_cursor = word_before_cursor = last_word(
text_before_cursor, include='many_punctuations') text_before_cursor, include="many_punctuations"
)
full_text = _strip_named_query(full_text) full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor) text_before_cursor = _strip_named_query(text_before_cursor)
full_text, text_before_cursor, self.local_tables = \ full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
isolate_query_ctes(full_text, text_before_cursor) full_text, text_before_cursor
)
self.text_before_cursor_including_last_word = text_before_cursor self.text_before_cursor_including_last_word = text_before_cursor
@ -67,39 +70,41 @@ class SqlStatement(object):
# completion useless because it will always return the list of # completion useless because it will always return the list of
# keywords as completion. # keywords as completion.
if self.word_before_cursor: if self.word_before_cursor:
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\': if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
parsed = sqlparse.parse(text_before_cursor) parsed = sqlparse.parse(text_before_cursor)
else: else:
text_before_cursor = \ text_before_cursor = \
text_before_cursor[:-len(word_before_cursor)] text_before_cursor[: -len(word_before_cursor)]
parsed = sqlparse.parse(text_before_cursor) parsed = sqlparse.parse(text_before_cursor)
self.identifier = parse_partial_identifier(word_before_cursor) self.identifier = parse_partial_identifier(word_before_cursor)
else: else:
parsed = sqlparse.parse(text_before_cursor) parsed = sqlparse.parse(text_before_cursor)
full_text, text_before_cursor, parsed = \ full_text, text_before_cursor, parsed = _split_multiple_statements(
_split_multiple_statements(full_text, text_before_cursor, parsed) full_text, text_before_cursor, parsed
)
self.full_text = full_text self.full_text = full_text
self.text_before_cursor = text_before_cursor self.text_before_cursor = text_before_cursor
self.parsed = parsed self.parsed = parsed
self.last_token = \ self.last_token = parsed and \
parsed and parsed.token_prev(len(parsed.tokens))[1] or '' parsed.token_prev(len(parsed.tokens))[1] or ""
def is_insert(self): def is_insert(self):
return self.parsed.token_first().value.lower() == 'insert' return self.parsed.token_first().value.lower() == "insert"
def get_tables(self, scope='full'): def get_tables(self, scope="full"):
""" Gets the tables available in the statement. """Gets the tables available in the statement.
param `scope:` possible values: 'full', 'insert', 'before' param `scope:` possible values: 'full', 'insert', 'before'
If 'insert', only the first table is returned. If 'insert', only the first table is returned.
If 'before', only tables before the cursor are returned. If 'before', only tables before the cursor are returned.
If not 'insert' and the stmt is an insert, the first table is skipped. If not 'insert' and the stmt is an insert, the first table is skipped.
""" """
tables = extract_tables( tables = extract_tables(
self.full_text if scope == 'full' else self.text_before_cursor) self.full_text if scope == "full" else self.text_before_cursor
if scope == 'insert': )
if scope == "insert":
tables = tables[:1] tables = tables[:1]
elif self.is_insert(): elif self.is_insert():
tables = tables[1:] tables = tables[1:]
@ -118,8 +123,9 @@ class SqlStatement(object):
return schema return schema
def reduce_to_prev_keyword(self, n_skip=0): def reduce_to_prev_keyword(self, n_skip=0):
prev_keyword, self.text_before_cursor = \ prev_keyword, self.text_before_cursor = find_prev_keyword(
find_prev_keyword(self.text_before_cursor, n_skip=n_skip) self.text_before_cursor, n_skip=n_skip
)
return prev_keyword return prev_keyword
@ -131,7 +137,7 @@ def suggest_type(full_text, text_before_cursor):
A scope for a column category will be a list of tables. A scope for a column category will be a list of tables.
""" """
if full_text.startswith('\\i '): if full_text.startswith("\\i "):
return (Path(),) return (Path(),)
# This is a temporary hack; the exception handling # This is a temporary hack; the exception handling
@ -144,7 +150,7 @@ def suggest_type(full_text, text_before_cursor):
return suggest_based_on_last_token(stmt.last_token, stmt) return suggest_based_on_last_token(stmt.last_token, stmt)
named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+') named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
def _strip_named_query(txt): def _strip_named_query(txt):
@ -155,11 +161,11 @@ def _strip_named_query(txt):
""" """
if named_query_regex.match(txt): if named_query_regex.match(txt):
txt = named_query_regex.sub('', txt) txt = named_query_regex.sub("", txt)
return txt return txt
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M) function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
def _find_function_body(text): def _find_function_body(text):
@ -205,12 +211,12 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
return full_text, text_before_cursor, None return full_text, text_before_cursor, None
token2 = None token2 = None
if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'): if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
token1 = statement.token_first() token1 = statement.token_first()
if token1: if token1:
token1_idx = statement.token_index(token1) token1_idx = statement.token_index(token1)
token2 = statement.token_next(token1_idx)[1] token2 = statement.token_next(token1_idx)[1]
if token2 and token2.value.upper() == 'FUNCTION': if token2 and token2.value.upper() == "FUNCTION":
full_text, text_before_cursor, statement = _statement_from_function( full_text, text_before_cursor, statement = _statement_from_function(
full_text, text_before_cursor, statement full_text, text_before_cursor, statement
) )
@ -246,9 +252,9 @@ def suggest_based_on_last_token(token, stmt):
# SELECT Identifier <CURSOR> # SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR> # SELECT foo FROM Identifier <CURSOR>
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
if prev_keyword and prev_keyword.value == '(': if prev_keyword and prev_keyword.value == "(":
# Suggest datatypes # Suggest datatypes
return suggest_based_on_last_token('type', stmt) return suggest_based_on_last_token("type", stmt)
else: else:
return (Keyword(),) return (Keyword(),)
else: else:
@ -256,7 +262,7 @@ def suggest_based_on_last_token(token, stmt):
if not token: if not token:
return (Keyword(),) return (Keyword(),)
elif token_v.endswith('('): elif token_v.endswith("("):
p = sqlparse.parse(stmt.text_before_cursor)[0] p = sqlparse.parse(stmt.text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where): if p.tokens and isinstance(p.tokens[-1], Where):
@ -268,10 +274,10 @@ def suggest_based_on_last_token(token, stmt):
# 3 - Subquery expression like "WHERE EXISTS (" # 3 - Subquery expression like "WHERE EXISTS ("
# Suggest keywords, in order to do a subquery # Suggest keywords, in order to do a subquery
# 4 - Subquery OR array comparison like "WHERE foo = ANY(" # 4 - Subquery OR array comparison like "WHERE foo = ANY("
# Suggest columns/functions AND keywords. (If we wanted to be # Suggest columns/functions AND keywords. (If we wanted to
# really fancy, we could suggest only array-typed columns) # be really fancy, we could suggest only array-typed columns)
column_suggestions = suggest_based_on_last_token('where', stmt) column_suggestions = suggest_based_on_last_token("where", stmt)
# Check for a subquery expression (cases 3 & 4) # Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1] where = p.tokens[-1]
@ -282,7 +288,7 @@ def suggest_based_on_last_token(token, stmt):
prev_tok = prev_tok.tokens[-1] prev_tok = prev_tok.tokens[-1]
prev_tok = prev_tok.value.lower() prev_tok = prev_tok.value.lower()
if prev_tok == 'exists': if prev_tok == "exists":
return (Keyword(),) return (Keyword(),)
else: else:
return column_suggestions return column_suggestions
@ -292,59 +298,47 @@ def suggest_based_on_last_token(token, stmt):
if ( if (
prev_tok and prev_tok.value and prev_tok and prev_tok.value and
prev_tok.value.lower().split(' ')[-1] == 'using' prev_tok.value.lower().split(" ")[-1] == "using"
): ):
# tbl1 INNER JOIN tbl2 USING (col1, col2) # tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = stmt.get_tables('before') tables = stmt.get_tables("before")
# suggest columns that are present in more than one table # suggest columns that are present in more than one table
return (Column(table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables),)
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
elif p.token_first().value.lower() == 'select' and \
last_word(stmt.text_before_cursor,
'all_punctuations').startswith('('):
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
return ( return (
Column(table_refs=stmt.get_tables('insert'), context='insert'), Column(
table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables,
),
) )
elif p.token_first().value.lower() == "select":
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
if last_word(stmt.text_before_cursor,
"all_punctuations").startswith("("):
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == "INTO":
return (Column(table_refs=stmt.get_tables("insert"),
context="insert"),)
# We're probably in a function argument list # We're probably in a function argument list
return (Column(table_refs=extract_tables(stmt.full_text), return _suggest_expression(token_v, stmt)
local_tables=stmt.local_tables, qualifiable=True),) elif token_v == "set":
elif token_v == 'set':
return (Column(table_refs=stmt.get_tables(), return (Column(table_refs=stmt.get_tables(),
local_tables=stmt.local_tables),) local_tables=stmt.local_tables),)
elif token_v in ('select', 'where', 'having', 'by', 'distinct'): elif token_v in ("select", "where", "having", "order by", "distinct"):
# Check for a table alias or schema qualification return _suggest_expression(token_v, stmt)
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or [] elif token_v == "as":
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
return (Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
else:
return (Column(table_refs=tables, local_tables=stmt.local_tables,
qualifiable=True),
Function(schema=None),
Keyword(token_v.upper()),)
elif token_v == 'as':
# Don't suggest anything for aliases # Don't suggest anything for aliases
return () return ()
elif ( elif (token_v.endswith("join") and token.is_keyword) or (
(token_v.endswith('join') and token.is_keyword) or token_v in ("copy", "from", "update", "into", "describe", "truncate")
(token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'))
): ):
schema = stmt.get_identifier_schema() schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor) tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith('join') and token.is_keyword is_join = token_v.endswith("join") and token.is_keyword
# Suggest tables from either the currently-selected schema or the # Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified # public schema if no schema has been specified
@ -354,60 +348,77 @@ def suggest_based_on_last_token(token, stmt):
# Suggest schemas # Suggest schemas
suggest.insert(0, Schema()) suggest.insert(0, Schema())
if token_v == 'from' or is_join: if token_v == "from" or is_join:
suggest.append(FromClauseItem(schema=schema, suggest.append(
table_refs=tables, FromClauseItem(
local_tables=stmt.local_tables)) schema=schema, table_refs=tables,
elif token_v == 'truncate': local_tables=stmt.local_tables
)
)
elif token_v == "truncate":
suggest.append(Table(schema)) suggest.append(Table(schema))
else: else:
suggest.extend((Table(schema), View(schema))) suggest.extend((Table(schema), View(schema)))
if is_join and _allow_join(stmt.parsed): if is_join and _allow_join(stmt.parsed):
tables = stmt.get_tables('before') tables = stmt.get_tables("before")
suggest.append(Join(table_refs=tables, schema=schema)) suggest.append(Join(table_refs=tables, schema=schema))
return tuple(suggest) return tuple(suggest)
elif token_v == 'function': elif token_v == "function":
schema = stmt.get_identifier_schema() schema = stmt.get_identifier_schema()
# stmt.get_previous_token will fail for e.g. # stmt.get_previous_token will fail for e.g.
# `SELECT 1 FROM functions WHERE function:` # `SELECT 1 FROM functions WHERE function:`
try: try:
prev = stmt.get_previous_token(token).value.lower() prev = stmt.get_previous_token(token).value.lower()
if prev in ('drop', 'alter', 'create', 'create or replace'): if prev in ("drop", "alter", "create", "create or replace"):
return (Function(schema=schema, usage='signature'),)
# Suggest functions from either the currently-selected schema
# or the public schema if no schema has been specified
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
suggest.append(Function(schema=schema, usage="signature"))
return tuple(suggest)
except ValueError: except ValueError:
pass pass
return tuple() return tuple()
elif token_v in ('table', 'view'): elif token_v in ("table", "view"):
# E.g. 'ALTER TABLE <tablname>' # E.g. 'ALTER TABLE <tablname>'
rel_type = \ rel_type = \
{'table': Table, 'view': View, 'function': Function}[token_v] {"table": Table, "view": View, "function": Function}[token_v]
schema = stmt.get_identifier_schema() schema = stmt.get_identifier_schema()
if schema: if schema:
return (rel_type(schema=schema),) return (rel_type(schema=schema),)
else: else:
return (Schema(), rel_type(schema=schema)) return (Schema(), rel_type(schema=schema))
elif token_v == 'column': elif token_v == "column":
# E.g. 'ALTER TABLE foo ALTER COLUMN bar # E.g. 'ALTER TABLE foo ALTER COLUMN bar
return (Column(table_refs=stmt.get_tables()),) return (Column(table_refs=stmt.get_tables()),)
elif token_v == 'on': elif token_v == "on":
tables = stmt.get_tables('before') tables = stmt.get_tables("before")
parent = \ parent = \
(stmt.identifier and stmt.identifier.get_parent_name()) or None (stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent: if parent:
# "ON parent.<suggestion>" # "ON parent.<suggestion>"
# parent can be either a schema name or table alias # parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t)) filteredtables = tuple(t for t in tables if identifies(parent, t))
sugs = [Column(table_refs=filteredtables, sugs = [
local_tables=stmt.local_tables), Column(table_refs=filteredtables,
Table(schema=parent), local_tables=stmt.local_tables),
View(schema=parent), Table(schema=parent),
Function(schema=parent)] View(schema=parent),
Function(schema=parent),
]
if filteredtables and _allow_join_condition(stmt.parsed): if filteredtables and _allow_join_condition(stmt.parsed):
sugs.append(JoinCondition(table_refs=tables, sugs.append(JoinCondition(table_refs=tables,
parent=filteredtables[-1])) parent=filteredtables[-1]))
@ -417,38 +428,39 @@ def suggest_based_on_last_token(token, stmt):
# Use table alias if there is one, otherwise the table name # Use table alias if there is one, otherwise the table name
aliases = tuple(t.ref for t in tables) aliases = tuple(t.ref for t in tables)
if _allow_join_condition(stmt.parsed): if _allow_join_condition(stmt.parsed):
return (Alias(aliases=aliases), JoinCondition( return (
table_refs=tables, parent=None)) Alias(aliases=aliases),
JoinCondition(table_refs=tables, parent=None),
)
else: else:
return (Alias(aliases=aliases),) return (Alias(aliases=aliases),)
elif token_v in ('c', 'use', 'database', 'template'): elif token_v in ("c", "use", "database", "template"):
# "\c <db", "use <db>", "DROP DATABASE <db>", # "\c <db", "use <db>", "DROP DATABASE <db>",
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>" # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
return (Database(),) return (Database(),)
elif token_v == 'schema': elif token_v == "schema":
# DROP SCHEMA schema_name, SET SCHEMA schema name # DROP SCHEMA schema_name, SET SCHEMA schema name
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
quoted = prev_keyword and prev_keyword.value.lower() == 'set' quoted = prev_keyword and prev_keyword.value.lower() == "set"
return (Schema(quoted),) return (Schema(quoted),)
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'): elif token_v.endswith(",") or token_v in ("=", "and", "or"):
prev_keyword = stmt.reduce_to_prev_keyword() prev_keyword = stmt.reduce_to_prev_keyword()
if prev_keyword: if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt) return suggest_based_on_last_token(prev_keyword, stmt)
else: else:
return () return ()
elif token_v in ('type', '::'): elif token_v in ("type", "::"):
# ALTER TABLE foo SET DATA TYPE bar # ALTER TABLE foo SET DATA TYPE bar
# SELECT foo::bar # SELECT foo::bar
# Note that tables are a form of composite type in postgresql, so # Note that tables are a form of composite type in postgresql, so
# they're suggested here as well # they're suggested here as well
schema = stmt.get_identifier_schema() schema = stmt.get_identifier_schema()
suggestions = [Datatype(schema=schema), suggestions = [Datatype(schema=schema), Table(schema=schema)]
Table(schema=schema)]
if not schema: if not schema:
suggestions.append(Schema()) suggestions.append(Schema())
return tuple(suggestions) return tuple(suggestions)
elif token_v in ['alter', 'create', 'drop']: elif token_v in {"alter", "create", "drop"}:
return (Keyword(token_v.upper()),) return (Keyword(token_v.upper()),)
elif token.is_keyword: elif token.is_keyword:
# token is a keyword we haven't implemented any special handling for # token is a keyword we haven't implemented any special handling for
@ -462,11 +474,38 @@ def suggest_based_on_last_token(token, stmt):
return (Keyword(),) return (Keyword(),)
def _suggest_expression(token_v, stmt):
    """
    Return suggestions for an expression, taking account of any partially-typed
    identifier's parent, which may be a table alias or schema name.

    :param token_v: lower-cased value of the last significant token.
    :param stmt: the parsed statement being completed; ``stmt.identifier`` is
        the partially-typed identifier (or a falsy value when absent).
    :return: tuple of suggestion objects (Column/Table/View/Function/Keyword).
    """
    # Fix: use None (not an empty list) as the "no parent" sentinel —
    # parent is a string (schema name or table alias) when present, and is
    # only ever truth-tested below, so this is behavior-identical.
    parent = stmt.identifier.get_parent_name() if stmt.identifier else None
    tables = stmt.get_tables()
    if parent:
        # "parent.<suggestion>": parent may be a table alias (suggest that
        # table's columns) or a schema name (suggest its relations/functions).
        tables = tuple(t for t in tables if identifies(parent, t))
        return (
            Column(table_refs=tables, local_tables=stmt.local_tables),
            Table(schema=parent),
            View(schema=parent),
            Function(schema=parent),
        )
    # Bare expression: columns may still be qualified, so mark qualifiable.
    return (
        Column(table_refs=tables, local_tables=stmt.local_tables,
               qualifiable=True),
        Function(schema=None),
        Keyword(token_v.upper()),
    )
def identifies(id, ref):
    """Returns true if string `id` matches TableReference `ref`.

    A match is the table's alias, its bare name, or its schema-qualified
    name ("schema.name").
    """
    # Alias or bare-name match short-circuits the qualified-name check.
    if id == ref.alias or id == ref.name:
        return True
    return ref.schema and id == ref.schema + "." + ref.name
def _allow_join_condition(statement): def _allow_join_condition(statement):
@ -486,7 +525,7 @@ def _allow_join_condition(statement):
return False return False
last_tok = statement.token_prev(len(statement.tokens))[1] last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower() in ('on', 'and', 'or') return last_tok.value.lower() in ("on", "and", "or")
def _allow_join(statement): def _allow_join(statement):
@ -505,7 +544,5 @@ def _allow_join(statement):
return False return False
last_tok = statement.token_prev(len(statement.tokens))[1] last_tok = statement.token_prev(len(statement.tokens))[1]
return ( return last_tok.value.lower().endswith("join") and \
last_tok.value.lower().endswith('join') and last_tok.value.lower() not in ("cross join", "natural join",)
last_tok.value.lower() not in ('cross join', 'natural join')
)