mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2025-02-25 18:55:31 -06:00
Merged the latest code of 'pgcli' used for the autocomplete feature. Fixes #5497
This commit is contained in:
parent
3f817494f8
commit
300de05a20
@ -16,6 +16,7 @@ Housekeeping
|
||||
************
|
||||
|
||||
| `Issue #5330 <https://redmine.postgresql.org/issues/5330>`_ - Improve code coverage and API test cases for Functions.
|
||||
| `Issue #5497 <https://redmine.postgresql.org/issues/5497>`_ - Merged the latest code of 'pgcli' used for the autocomplete feature.
|
||||
|
||||
Bug fixes
|
||||
*********
|
||||
|
@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
|
||||
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
||||
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
||||
p.proretset is_set_returning,
|
||||
d.deptype = 'e' is_extension,
|
||||
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
|
||||
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||
AND n.nspname IN ({{schema_names}})
|
||||
ORDER BY 1, 2
|
||||
|
@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
|
||||
p.proisagg is_aggregate,
|
||||
p.proiswindow is_window,
|
||||
p.proretset is_set_returning,
|
||||
d.deptype = 'e' is_extension,
|
||||
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
|
||||
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||
AND n.nspname IN ({{schema_names}})
|
||||
ORDER BY 1, 2
|
||||
|
@ -11,8 +11,7 @@
|
||||
|
||||
import re
|
||||
import operator
|
||||
import sys
|
||||
from itertools import count, repeat, chain
|
||||
from itertools import count
|
||||
from .completion import Completion
|
||||
from collections import namedtuple, defaultdict, OrderedDict
|
||||
|
||||
@ -28,9 +27,9 @@ from pgadmin.utils.driver import get_driver
|
||||
from config import PG_DEFAULT_DRIVER
|
||||
from pgadmin.utils.preferences import Preferences
|
||||
|
||||
Match = namedtuple('Match', ['completion', 'priority'])
|
||||
Match = namedtuple("Match", ["completion", "priority"])
|
||||
|
||||
_SchemaObject = namedtuple('SchemaObject', 'name schema meta')
|
||||
_SchemaObject = namedtuple("SchemaObject", "name schema meta")
|
||||
|
||||
|
||||
def SchemaObject(name, schema=None, meta=None):
|
||||
@ -41,14 +40,12 @@ def SchemaObject(name, schema=None, meta=None):
|
||||
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
|
||||
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
|
||||
|
||||
_Candidate = namedtuple(
|
||||
'Candidate', 'completion prio meta synonyms prio2 display'
|
||||
)
|
||||
_Candidate = namedtuple("Candidate",
|
||||
"completion prio meta synonyms prio2 display")
|
||||
|
||||
|
||||
def Candidate(
|
||||
completion, prio=None, meta=None, synonyms=None, prio2=None,
|
||||
display=None
|
||||
completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
|
||||
):
|
||||
return _Candidate(
|
||||
completion, prio, meta, synonyms or [completion], prio2,
|
||||
@ -57,7 +54,7 @@ def Candidate(
|
||||
|
||||
|
||||
# Used to strip trailing '::some_type' from default-value expressions
|
||||
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$')
|
||||
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
|
||||
|
||||
|
||||
def normalize_ref(ref):
|
||||
@ -65,15 +62,15 @@ def normalize_ref(ref):
|
||||
|
||||
|
||||
def generate_alias(tbl):
|
||||
""" Generate a table alias, consisting of all upper-case letters in
|
||||
"""Generate a table alias, consisting of all upper-case letters in
|
||||
the table name, or, if there are no upper-case letters, the first letter +
|
||||
all letters preceded by _
|
||||
param tbl - unescaped name of the table to alias
|
||||
"""
|
||||
return ''.join(
|
||||
return "".join(
|
||||
[letter for letter in tbl if letter.isupper()] or
|
||||
[letter for letter, prev in zip(tbl, '_' + tbl)
|
||||
if prev == '_' and letter != '_']
|
||||
[letter for letter, prev in zip(tbl, "_" + tbl)
|
||||
if prev == "_" and letter != "_"]
|
||||
)
|
||||
|
||||
|
||||
@ -97,13 +94,14 @@ class SQLAutoComplete(object):
|
||||
self.sid = kwargs['sid'] if 'sid' in kwargs else None
|
||||
self.conn = kwargs['conn'] if 'conn' in kwargs else None
|
||||
self.keywords = []
|
||||
self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")
|
||||
|
||||
self.databases = []
|
||||
self.functions = []
|
||||
self.datatypes = []
|
||||
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
|
||||
'datatypes': {}}
|
||||
self.dbmetadata = \
|
||||
{"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
|
||||
self.text_before_cursor = None
|
||||
self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
|
||||
|
||||
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)
|
||||
|
||||
@ -182,7 +180,8 @@ class SQLAutoComplete(object):
|
||||
def escape_name(self, name):
|
||||
if name and (
|
||||
(not self.name_pattern.match(name)) or
|
||||
(name.upper() in self.reserved_words)
|
||||
(name.upper() in self.reserved_words) or
|
||||
(name.upper() in self.functions)
|
||||
):
|
||||
name = '"%s"' % name
|
||||
|
||||
@ -212,7 +211,7 @@ class SQLAutoComplete(object):
|
||||
|
||||
# schemata is a list of schema names
|
||||
schemata = self.escaped_names(schemata)
|
||||
metadata = self.dbmetadata['tables']
|
||||
metadata = self.dbmetadata["tables"]
|
||||
for schema in schemata:
|
||||
metadata[schema] = {}
|
||||
|
||||
@ -224,7 +223,7 @@ class SQLAutoComplete(object):
|
||||
self.all_completions.update(schemata)
|
||||
|
||||
def extend_casing(self, words):
|
||||
""" extend casing data
|
||||
"""extend casing data
|
||||
|
||||
:return:
|
||||
"""
|
||||
@ -274,7 +273,7 @@ class SQLAutoComplete(object):
|
||||
name=colname,
|
||||
datatype=datatype,
|
||||
has_default=has_default,
|
||||
default=default
|
||||
default=default,
|
||||
)
|
||||
metadata[schema][relname][colname] = column
|
||||
self.all_completions.add(colname)
|
||||
@ -285,7 +284,7 @@ class SQLAutoComplete(object):
|
||||
|
||||
# dbmetadata['schema_name']['functions']['function_name'] should return
|
||||
# the function metadata namedtuple for the corresponding function
|
||||
metadata = self.dbmetadata['functions']
|
||||
metadata = self.dbmetadata["functions"]
|
||||
|
||||
for f in func_data:
|
||||
schema, func = self.escaped_names([f.schema_name, f.func_name])
|
||||
@ -309,10 +308,10 @@ class SQLAutoComplete(object):
|
||||
self._arg_list_cache = \
|
||||
dict((usage,
|
||||
dict((meta, self._arg_list(meta, usage))
|
||||
for sch, funcs in self.dbmetadata['functions'].items()
|
||||
for sch, funcs in self.dbmetadata["functions"].items()
|
||||
for func, metas in funcs.items()
|
||||
for meta in metas))
|
||||
for usage in ('call', 'call_display', 'signature'))
|
||||
for usage in ("call", "call_display", "signature"))
|
||||
|
||||
def extend_foreignkeys(self, fk_data):
|
||||
|
||||
@ -322,7 +321,7 @@ class SQLAutoComplete(object):
|
||||
|
||||
# These are added as a list of ForeignKey namedtuples to the
|
||||
# ColumnMetadata namedtuple for both the child and parent
|
||||
meta = self.dbmetadata['tables']
|
||||
meta = self.dbmetadata["tables"]
|
||||
|
||||
for fk in fk_data:
|
||||
e = self.escaped_names
|
||||
@ -350,7 +349,7 @@ class SQLAutoComplete(object):
|
||||
# dbmetadata['datatypes'][schema_name][type_name] should store type
|
||||
# metadata, such as composite type field names. Currently, we're not
|
||||
# storing any metadata beyond typename, so just store None
|
||||
meta = self.dbmetadata['datatypes']
|
||||
meta = self.dbmetadata["datatypes"]
|
||||
|
||||
for t in type_data:
|
||||
schema, type_name = self.escaped_names(t)
|
||||
@ -364,11 +363,11 @@ class SQLAutoComplete(object):
|
||||
self.databases = []
|
||||
self.special_commands = []
|
||||
self.search_path = []
|
||||
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
|
||||
'datatypes': {}}
|
||||
self.dbmetadata = \
|
||||
{"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
|
||||
self.all_completions = set(self.keywords + self.functions)
|
||||
|
||||
def find_matches(self, text, collection, mode='fuzzy', meta=None):
|
||||
def find_matches(self, text, collection, mode="strict", meta=None):
|
||||
"""Find completion matches for the given text.
|
||||
|
||||
Given the user's input text and a collection of available
|
||||
@ -389,17 +388,26 @@ class SQLAutoComplete(object):
|
||||
collection:
|
||||
mode:
|
||||
meta:
|
||||
meta_collection:
|
||||
"""
|
||||
if not collection:
|
||||
return []
|
||||
prio_order = [
|
||||
'keyword', 'function', 'view', 'table', 'datatype', 'database',
|
||||
'schema', 'column', 'table alias', 'join', 'name join', 'fk join',
|
||||
'table format'
|
||||
"keyword",
|
||||
"function",
|
||||
"view",
|
||||
"table",
|
||||
"datatype",
|
||||
"database",
|
||||
"schema",
|
||||
"column",
|
||||
"table alias",
|
||||
"join",
|
||||
"name join",
|
||||
"fk join",
|
||||
"table format",
|
||||
]
|
||||
type_priority = prio_order.index(meta) if meta in prio_order else -1
|
||||
text = last_word(text, include='most_punctuations').lower()
|
||||
text = last_word(text, include="most_punctuations").lower()
|
||||
text_len = len(text)
|
||||
|
||||
if text and text[0] == '"':
|
||||
@ -409,7 +417,7 @@ class SQLAutoComplete(object):
|
||||
# Completion.position value is correct
|
||||
text = text[1:]
|
||||
|
||||
if mode == 'fuzzy':
|
||||
if mode == "fuzzy":
|
||||
fuzzy = True
|
||||
priority_func = self.prioritizer.name_count
|
||||
else:
|
||||
@ -422,19 +430,20 @@ class SQLAutoComplete(object):
|
||||
# Note: higher priority values mean more important, so use negative
|
||||
# signs to flip the direction of the tuple
|
||||
if fuzzy:
|
||||
regex = '.*?'.join(map(re.escape, text))
|
||||
pat = re.compile('(%s)' % regex)
|
||||
regex = ".*?".join(map(re.escape, text))
|
||||
pat = re.compile("(%s)" % regex)
|
||||
|
||||
def _match(item):
|
||||
if item.lower()[:len(text) + 1] in (text, text + ' '):
|
||||
if item.lower()[: len(text) + 1] in (text, text + " "):
|
||||
# Exact match of first word in suggestion
|
||||
# This is to get exact alias matches to the top
|
||||
# E.g. for input `e`, 'Entries E' should be on top
|
||||
# (before e.g. `EndUsers EU`)
|
||||
return float('Infinity'), -1
|
||||
return float("Infinity"), -1
|
||||
r = pat.search(self.unescape_name(item.lower()))
|
||||
if r:
|
||||
return -len(r.group()), -r.start()
|
||||
|
||||
else:
|
||||
match_end_limit = len(text)
|
||||
|
||||
@ -446,7 +455,7 @@ class SQLAutoComplete(object):
|
||||
if match_point >= 0:
|
||||
# Use negative infinity to force keywords to sort after all
|
||||
# fuzzy matches
|
||||
return -float('Infinity'), -match_point
|
||||
return -float("Infinity"), -match_point
|
||||
|
||||
matches = []
|
||||
for cand in collection:
|
||||
@ -466,7 +475,7 @@ class SQLAutoComplete(object):
|
||||
if sort_key:
|
||||
if display_meta and len(display_meta) > 50:
|
||||
# Truncate meta-text to 50 characters, if necessary
|
||||
display_meta = display_meta[:47] + '...'
|
||||
display_meta = display_meta[:47] + "..."
|
||||
|
||||
# Lexical order of items in the collection, used for
|
||||
# tiebreaking items with the same match group length and start
|
||||
@ -478,14 +487,18 @@ class SQLAutoComplete(object):
|
||||
# We also use the unescape_name to make sure quoted names have
|
||||
# the same priority as unquoted names.
|
||||
lexical_priority = (
|
||||
tuple(0 if c in (' _') else -ord(c)
|
||||
tuple(0 if c in (" _") else -ord(c)
|
||||
for c in self.unescape_name(item.lower())) + (1,) +
|
||||
tuple(c for c in item)
|
||||
)
|
||||
|
||||
priority = (
|
||||
sort_key, type_priority, prio, priority_func(item),
|
||||
prio2, lexical_priority
|
||||
sort_key,
|
||||
type_priority,
|
||||
prio,
|
||||
priority_func(item),
|
||||
prio2,
|
||||
lexical_priority,
|
||||
)
|
||||
matches.append(
|
||||
Match(
|
||||
@ -493,9 +506,9 @@ class SQLAutoComplete(object):
|
||||
text=item,
|
||||
start_position=-text_len,
|
||||
display_meta=display_meta,
|
||||
display=display
|
||||
display=display,
|
||||
),
|
||||
priority=priority
|
||||
priority=priority,
|
||||
)
|
||||
)
|
||||
return matches
|
||||
@ -516,8 +529,8 @@ class SQLAutoComplete(object):
|
||||
matches.extend(matcher(self, suggestion, word_before_cursor))
|
||||
|
||||
# Sort matches so highest priorities are first
|
||||
matches = sorted(matches, key=operator.attrgetter('priority'),
|
||||
reverse=True)
|
||||
matches = \
|
||||
sorted(matches, key=operator.attrgetter("priority"), reverse=True)
|
||||
|
||||
result = dict()
|
||||
for m in matches:
|
||||
@ -539,23 +552,28 @@ class SQLAutoComplete(object):
|
||||
|
||||
tables = suggestion.table_refs
|
||||
do_qualify = suggestion.qualifiable and {
|
||||
'always': True, 'never': False,
|
||||
'if_more_than_one_table': len(tables) > 1}[self.qualify_columns]
|
||||
"always": True,
|
||||
"never": False,
|
||||
"if_more_than_one_table": len(tables) > 1,
|
||||
}[self.qualify_columns]
|
||||
|
||||
def qualify(col, tbl):
|
||||
return (tbl + '.' + col) if do_qualify else col
|
||||
|
||||
scoped_cols = self.populate_scoped_cols(
|
||||
tables, suggestion.local_tables
|
||||
)
|
||||
scoped_cols = \
|
||||
self.populate_scoped_cols(tables, suggestion.local_tables)
|
||||
|
||||
def make_cand(name, ref):
|
||||
synonyms = (name, generate_alias(name))
|
||||
return Candidate(qualify(name, ref), 0, 'column', synonyms)
|
||||
return Candidate(qualify(name, ref), 0, "column", synonyms)
|
||||
|
||||
def flat_cols():
|
||||
return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items()
|
||||
for c in cols]
|
||||
return [
|
||||
make_cand(c.name, t.ref)
|
||||
for t, cols in scoped_cols.items()
|
||||
for c in cols
|
||||
]
|
||||
|
||||
if suggestion.require_last_table:
|
||||
# require_last_table is used for 'tbl1 JOIN tbl2 USING
|
||||
# (...' which should
|
||||
@ -569,10 +587,11 @@ class SQLAutoComplete(object):
|
||||
dict((t, [col for col in cols if col.name in other_tbl_cols])
|
||||
for t, cols in scoped_cols.items() if t.ref == ltbl)
|
||||
|
||||
lastword = last_word(word_before_cursor, include='most_punctuations')
|
||||
if lastword == '*':
|
||||
if suggestion.context == 'insert':
|
||||
def is_scoped(col):
|
||||
lastword = last_word(word_before_cursor, include="most_punctuations")
|
||||
if lastword == "*":
|
||||
if suggestion.context == "insert":
|
||||
|
||||
def filter(col):
|
||||
if not col.has_default:
|
||||
return True
|
||||
return not any(
|
||||
@ -580,40 +599,39 @@ class SQLAutoComplete(object):
|
||||
for p in self.insert_col_skip_patterns
|
||||
)
|
||||
scoped_cols = \
|
||||
dict((t, [col for col in cols if is_scoped(col)])
|
||||
dict((t, [col for col in cols if filter(col)])
|
||||
for t, cols in scoped_cols.items())
|
||||
if self.asterisk_column_order == 'alphabetic':
|
||||
if self.asterisk_column_order == "alphabetic":
|
||||
for cols in scoped_cols.values():
|
||||
cols.sort(key=operator.attrgetter('name'))
|
||||
cols.sort(key=operator.attrgetter("name"))
|
||||
if (
|
||||
lastword != word_before_cursor and
|
||||
len(tables) == 1 and
|
||||
word_before_cursor[-len(lastword) - 1] == '.'
|
||||
word_before_cursor[-len(lastword) - 1] == "."
|
||||
):
|
||||
# User typed x.*; replicate "x." for all columns except the
|
||||
# first, which gets the original (as we only replace the "*")
|
||||
sep = ', ' + word_before_cursor[:-1]
|
||||
sep = ", " + word_before_cursor[:-1]
|
||||
collist = sep.join(c.completion for c in flat_cols())
|
||||
else:
|
||||
collist = ', '.join(qualify(c.name, t.ref)
|
||||
collist = ", ".join(qualify(c.name, t.ref)
|
||||
for t, cs in scoped_cols.items()
|
||||
for c in cs)
|
||||
|
||||
return [Match(
|
||||
completion=Completion(
|
||||
collist,
|
||||
-1,
|
||||
display_meta='columns',
|
||||
display='*'
|
||||
),
|
||||
priority=(1, 1, 1)
|
||||
)]
|
||||
return [
|
||||
Match(
|
||||
completion=Completion(
|
||||
collist, -1, display_meta="columns", display="*"
|
||||
),
|
||||
priority=(1, 1, 1),
|
||||
)
|
||||
]
|
||||
|
||||
return self.find_matches(word_before_cursor, flat_cols(),
|
||||
mode='strict', meta='column')
|
||||
meta="column")
|
||||
|
||||
def alias(self, tbl, tbls):
|
||||
""" Generate a unique table alias
|
||||
"""Generate a unique table alias
|
||||
tbl - name of the table to alias, quoted if it needs to be
|
||||
tbls - TableReference iterable of tables already in query
|
||||
"""
|
||||
@ -628,25 +646,6 @@ class SQLAutoComplete(object):
|
||||
aliases = (tbl + str(i) for i in count(2))
|
||||
return next(a for a in aliases if normalize_ref(a) not in tbls)
|
||||
|
||||
def _check_for_aliases(self, left, refs, rtbl, suggestion, right):
|
||||
"""
|
||||
Check for generate aliases and return join value
|
||||
:param left:
|
||||
:param refs:
|
||||
:param rtbl:
|
||||
:param suggestion:
|
||||
:param right:
|
||||
:return: return join string.
|
||||
"""
|
||||
if self.generate_aliases or normalize_ref(left.tbl) in refs:
|
||||
lref = self.alias(left.tbl, suggestion.table_refs)
|
||||
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
|
||||
left.tbl, left.col, rtbl.ref, right.col, lref)
|
||||
else:
|
||||
join = '{0} ON {0}.{1} = {2}.{3}'.format(
|
||||
left.tbl, left.col, rtbl.ref, right.col)
|
||||
return join
|
||||
|
||||
def get_join_matches(self, suggestion, word_before_cursor):
|
||||
tbls = suggestion.table_refs
|
||||
cols = self.populate_scoped_cols(tbls)
|
||||
@ -658,10 +657,12 @@ class SQLAutoComplete(object):
|
||||
joins = []
|
||||
# Iterate over FKs in existing tables to find potential joins
|
||||
fks = (
|
||||
(fk, rtbl, rcol) for rtbl, rcols in cols.items()
|
||||
for rcol in rcols for fk in rcol.foreignkeys
|
||||
(fk, rtbl, rcol)
|
||||
for rtbl, rcols in cols.items()
|
||||
for rcol in rcols
|
||||
for fk in rcol.foreignkeys
|
||||
)
|
||||
col = namedtuple('col', 'schema tbl col')
|
||||
col = namedtuple("col", "schema tbl col")
|
||||
for fk, rtbl, rcol in fks:
|
||||
right = col(rtbl.schema, rtbl.name, rcol.name)
|
||||
child = col(fk.childschema, fk.childtable, fk.childcolumn)
|
||||
@ -670,54 +671,38 @@ class SQLAutoComplete(object):
|
||||
if suggestion.schema and left.schema != suggestion.schema:
|
||||
continue
|
||||
|
||||
join = self._check_for_aliases(left, refs, rtbl, suggestion, right)
|
||||
if self.generate_aliases or normalize_ref(left.tbl) in refs:
|
||||
lref = self.alias(left.tbl, suggestion.table_refs)
|
||||
join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
|
||||
left.tbl, left.col, rtbl.ref, right.col, lref
|
||||
)
|
||||
else:
|
||||
join = "{0} ON {0}.{1} = {2}.{3}".format(
|
||||
left.tbl, left.col, rtbl.ref, right.col
|
||||
)
|
||||
alias = generate_alias(left.tbl)
|
||||
synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format(
|
||||
alias, left.col, rtbl.ref, right.col)]
|
||||
synonyms = [
|
||||
join,
|
||||
"{0} ON {0}.{1} = {2}.{3}".format(
|
||||
alias, left.col, rtbl.ref, right.col
|
||||
),
|
||||
]
|
||||
# Schema-qualify if (1) new table in same schema as old, and old
|
||||
# is schema-qualified, or (2) new in other schema, except public
|
||||
if not suggestion.schema and \
|
||||
(qualified[normalize_ref(rtbl.ref)] and
|
||||
left.schema == right.schema or
|
||||
left.schema not in (right.schema, 'public')):
|
||||
join = left.schema + '.' + join
|
||||
left.schema not in (right.schema, "public")):
|
||||
join = left.schema + "." + join
|
||||
prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
|
||||
0 if (left.schema, left.tbl) in other_tbls else 1)
|
||||
joins.append(Candidate(join, prio, 'join', synonyms=synonyms))
|
||||
0 if (left.schema, left.tbl) in other_tbls else 1
|
||||
)
|
||||
joins.append(Candidate(join, prio, "join", synonyms=synonyms))
|
||||
|
||||
return self.find_matches(word_before_cursor, joins,
|
||||
mode='strict', meta='join')
|
||||
|
||||
def list_dict(self, pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
|
||||
d = defaultdict(list)
|
||||
for pair in pairs:
|
||||
d[pair[0]].append(pair[1])
|
||||
return d
|
||||
|
||||
def add_cond(self, lcol, rcol, rref, prio, meta, **kwargs):
|
||||
"""
|
||||
Add Condition in join
|
||||
:param lcol:
|
||||
:param rcol:
|
||||
:param rref:
|
||||
:param prio:
|
||||
:param meta:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
suggestion = kwargs['suggestion']
|
||||
found_conds = kwargs['found_conds']
|
||||
ltbl = kwargs['ltbl']
|
||||
ref_prio = kwargs['ref_prio']
|
||||
conds = kwargs['conds']
|
||||
prefix = '' if suggestion.parent else ltbl.ref + '.'
|
||||
cond = prefix + lcol + ' = ' + rref + '.' + rcol
|
||||
if cond not in found_conds:
|
||||
found_conds.add(cond)
|
||||
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
|
||||
return self.find_matches(word_before_cursor, joins, meta="join")
|
||||
|
||||
def get_join_condition_matches(self, suggestion, word_before_cursor):
|
||||
col = namedtuple('col', 'schema tbl col')
|
||||
col = namedtuple("col", "schema tbl col")
|
||||
tbls = self.populate_scoped_cols(suggestion.table_refs).items
|
||||
cols = [(t, c) for t, cs in tbls() for c in cs]
|
||||
try:
|
||||
@ -727,11 +712,24 @@ class SQLAutoComplete(object):
|
||||
return []
|
||||
conds, found_conds = [], set()
|
||||
|
||||
def add_cond(lcol, rcol, rref, prio, meta):
|
||||
prefix = "" if suggestion.parent else ltbl.ref + "."
|
||||
cond = prefix + lcol + " = " + rref + "." + rcol
|
||||
if cond not in found_conds:
|
||||
found_conds.add(cond)
|
||||
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
|
||||
|
||||
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
|
||||
d = defaultdict(list)
|
||||
for pair in pairs:
|
||||
d[pair[0]].append(pair[1])
|
||||
return d
|
||||
|
||||
# Tables that are closer to the cursor get higher prio
|
||||
ref_prio = dict((tbl.ref, num)
|
||||
for num, tbl in enumerate(suggestion.table_refs))
|
||||
# Map (schema, table, col) to tables
|
||||
coldict = self.list_dict(
|
||||
coldict = list_dict(
|
||||
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
|
||||
)
|
||||
# For each fk from the left table, generate a join condition if
|
||||
@ -742,89 +740,76 @@ class SQLAutoComplete(object):
|
||||
child = col(fk.childschema, fk.childtable, fk.childcolumn)
|
||||
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
|
||||
left, right = (child, par) if left == child else (par, child)
|
||||
|
||||
for rtbl in coldict[right]:
|
||||
kwargs = {
|
||||
"suggestion": suggestion,
|
||||
"found_conds": found_conds,
|
||||
"ltbl": ltbl,
|
||||
"conds": conds,
|
||||
"ref_prio": ref_prio
|
||||
}
|
||||
self.add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join',
|
||||
**kwargs)
|
||||
add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
|
||||
# For name matching, use a {(colname, coltype): TableReference} dict
|
||||
coltyp = namedtuple('coltyp', 'name datatype')
|
||||
col_table = self.list_dict(
|
||||
(coltyp(c.name, c.datatype), t) for t, c in cols)
|
||||
coltyp = namedtuple("coltyp", "name datatype")
|
||||
col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
|
||||
# Find all name-match join conditions
|
||||
for c in (coltyp(c.name, c.datatype) for c in lcols):
|
||||
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
|
||||
kwargs = {
|
||||
"suggestion": suggestion,
|
||||
"found_conds": found_conds,
|
||||
"ltbl": ltbl,
|
||||
"conds": conds,
|
||||
"ref_prio": ref_prio
|
||||
}
|
||||
prio = 1000 if c.datatype in (
|
||||
'integer', 'bigint', 'smallint') else 0
|
||||
self.add_cond(c.name, c.name, rtbl.ref, prio, 'name join',
|
||||
**kwargs)
|
||||
"integer", "bigint", "smallint") else 0
|
||||
add_cond(c.name, c.name, rtbl.ref, prio, "name join")
|
||||
|
||||
return self.find_matches(word_before_cursor, conds,
|
||||
mode='strict', meta='join')
|
||||
return self.find_matches(word_before_cursor, conds, meta="join")
|
||||
|
||||
def get_function_matches(self, suggestion, word_before_cursor,
|
||||
alias=False):
|
||||
if suggestion.usage == 'from':
|
||||
if suggestion.usage == "from":
|
||||
# Only suggest functions allowed in FROM clause
|
||||
|
||||
def filt(f):
|
||||
return not f.is_aggregate and not f.is_window
|
||||
return (
|
||||
not f.is_aggregate and not f.is_window and
|
||||
not f.is_extension and
|
||||
(f.is_public or f.schema_name == suggestion.schema)
|
||||
)
|
||||
|
||||
else:
|
||||
alias = False
|
||||
|
||||
def filt(f):
|
||||
return True
|
||||
return not f.is_extension and (
|
||||
f.is_public or f.schema_name == suggestion.schema
|
||||
)
|
||||
|
||||
arg_mode = {
|
||||
'signature': 'signature',
|
||||
'special': None,
|
||||
}.get(suggestion.usage, 'call')
|
||||
# Function overloading means we may have multiple functions of the same
|
||||
# name at this point, so keep unique names only
|
||||
funcs = set(
|
||||
self._make_cand(f, alias, suggestion, arg_mode)
|
||||
for f in self.populate_functions(suggestion.schema, filt)
|
||||
arg_mode = {"signature": "signature", "special": None}.get(
|
||||
suggestion.usage, "call"
|
||||
)
|
||||
|
||||
matches = self.find_matches(word_before_cursor, funcs,
|
||||
mode='strict', meta='function')
|
||||
# Function overloading means we may have multiple functions of the same
|
||||
# name at this point, so keep unique names only
|
||||
all_functions = self.populate_functions(suggestion.schema, filt)
|
||||
funcs = set(
|
||||
self._make_cand(f, alias, suggestion, arg_mode)
|
||||
for f in all_functions
|
||||
)
|
||||
|
||||
matches = self.find_matches(word_before_cursor, funcs, meta="function")
|
||||
|
||||
return matches
|
||||
|
||||
def get_schema_matches(self, suggestion, word_before_cursor):
|
||||
schema_names = self.dbmetadata['tables'].keys()
|
||||
schema_names = self.dbmetadata["tables"].keys()
|
||||
|
||||
# Unless we're sure the user really wants them, hide schema names
|
||||
# starting with pg_, which are mostly temporary schemas
|
||||
if not word_before_cursor.startswith('pg_'):
|
||||
schema_names = [s
|
||||
for s in schema_names
|
||||
if not s.startswith('pg_')]
|
||||
if not word_before_cursor.startswith("pg_"):
|
||||
schema_names = [s for s in schema_names if not s.startswith("pg_")]
|
||||
|
||||
if suggestion.quoted:
|
||||
schema_names = [self.escape_schema(s) for s in schema_names]
|
||||
|
||||
return self.find_matches(word_before_cursor, schema_names,
|
||||
mode='strict', meta='schema')
|
||||
meta="schema")
|
||||
|
||||
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
|
||||
alias = self.generate_aliases
|
||||
s = suggestion
|
||||
t_sug = Table(s.schema, s.table_refs, s.local_tables)
|
||||
v_sug = View(s.schema, s.table_refs)
|
||||
f_sug = Function(s.schema, s.table_refs, usage='from')
|
||||
f_sug = Function(s.schema, s.table_refs, usage="from")
|
||||
return (
|
||||
self.get_table_matches(t_sug, word_before_cursor, alias) +
|
||||
self.get_view_matches(v_sug, word_before_cursor, alias) +
|
||||
@ -839,42 +824,43 @@ class SQLAutoComplete(object):
|
||||
|
||||
"""
|
||||
template = {
|
||||
'call': self.call_arg_style,
|
||||
'call_display': self.call_arg_display_style,
|
||||
'signature': self.signature_arg_style
|
||||
"call": self.call_arg_style,
|
||||
"call_display": self.call_arg_display_style,
|
||||
"signature": self.signature_arg_style,
|
||||
}[usage]
|
||||
args = func.args()
|
||||
if not template or (
|
||||
usage == 'call' and (
|
||||
len(args) < 2 or func.has_variadic())):
|
||||
return '()'
|
||||
|
||||
multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max
|
||||
if not template:
|
||||
return "()"
|
||||
elif usage == "call" and len(args) < 2:
|
||||
return "()"
|
||||
elif usage == "call" and func.has_variadic():
|
||||
return "()"
|
||||
multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
|
||||
max_arg_len = max(len(a.name) for a in args) if multiline else 0
|
||||
args = (
|
||||
self._format_arg(template, arg, arg_num + 1, max_arg_len)
|
||||
for arg_num, arg in enumerate(args)
|
||||
)
|
||||
if multiline:
|
||||
return '(' + ','.join('\n ' + a for a in args if a) + '\n)'
|
||||
return "(" + ",".join("\n " + a for a in args if a) + "\n)"
|
||||
else:
|
||||
return '(' + ', '.join(a for a in args if a) + ')'
|
||||
return "(" + ", ".join(a for a in args if a) + ")"
|
||||
|
||||
def _format_arg(self, template, arg, arg_num, max_arg_len):
|
||||
if not template:
|
||||
return None
|
||||
if arg.has_default:
|
||||
arg_default = 'NULL' if arg.default is None else arg.default
|
||||
arg_default = "NULL" if arg.default is None else arg.default
|
||||
# Remove trailing ::(schema.)type
|
||||
arg_default = arg_default_type_strip_regex.sub('', arg_default)
|
||||
arg_default = arg_default_type_strip_regex.sub("", arg_default)
|
||||
else:
|
||||
arg_default = ''
|
||||
arg_default = ""
|
||||
return template.format(
|
||||
max_arg_len=max_arg_len,
|
||||
arg_name=arg.name,
|
||||
arg_num=arg_num,
|
||||
arg_type=arg.datatype,
|
||||
arg_default=arg_default
|
||||
arg_default=arg_default,
|
||||
)
|
||||
|
||||
def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
|
||||
@ -890,63 +876,60 @@ class SQLAutoComplete(object):
|
||||
if do_alias:
|
||||
alias = self.alias(cased_tbl, suggestion.table_refs)
|
||||
synonyms = (cased_tbl, generate_alias(cased_tbl))
|
||||
maybe_alias = (' ' + alias) if do_alias else ''
|
||||
maybe_schema = (tbl.schema + '.') if tbl.schema else ''
|
||||
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ''
|
||||
if arg_mode == 'call':
|
||||
display_suffix = self._arg_list_cache['call_display'][tbl.meta]
|
||||
elif arg_mode == 'signature':
|
||||
display_suffix = self._arg_list_cache['signature'][tbl.meta]
|
||||
maybe_alias = (" " + alias) if do_alias else ""
|
||||
maybe_schema = (tbl.schema + ".") if tbl.schema else ""
|
||||
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
|
||||
if arg_mode == "call":
|
||||
display_suffix = self._arg_list_cache["call_display"][tbl.meta]
|
||||
elif arg_mode == "signature":
|
||||
display_suffix = self._arg_list_cache["signature"][tbl.meta]
|
||||
else:
|
||||
display_suffix = ''
|
||||
display_suffix = ""
|
||||
item = maybe_schema + cased_tbl + suffix + maybe_alias
|
||||
display = maybe_schema + cased_tbl + display_suffix + maybe_alias
|
||||
prio2 = 0 if tbl.schema else 1
|
||||
return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
|
||||
|
||||
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
|
||||
tables = self.populate_schema_objects(suggestion.schema, 'tables')
|
||||
tables = self.populate_schema_objects(suggestion.schema, "tables")
|
||||
tables.extend(
|
||||
SchemaObject(tbl.name) for tbl in suggestion.local_tables)
|
||||
|
||||
# Unless we're sure the user really wants them, don't suggest the
|
||||
# pg_catalog tables that are implicitly on the search path
|
||||
if not suggestion.schema and (
|
||||
not word_before_cursor.startswith('pg_')):
|
||||
tables = [t for t in tables if not t.name.startswith('pg_')]
|
||||
if not suggestion.schema and \
|
||||
(not word_before_cursor.startswith("pg_")):
|
||||
tables = [t for t in tables if not t.name.startswith("pg_")]
|
||||
tables = [self._make_cand(t, alias, suggestion) for t in tables]
|
||||
return self.find_matches(word_before_cursor, tables,
|
||||
mode='strict', meta='table')
|
||||
return self.find_matches(word_before_cursor, tables, meta="table")
|
||||
|
||||
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
|
||||
views = self.populate_schema_objects(suggestion.schema, 'views')
|
||||
views = self.populate_schema_objects(suggestion.schema, "views")
|
||||
|
||||
if not suggestion.schema and (
|
||||
not word_before_cursor.startswith('pg_')):
|
||||
views = [v for v in views if not v.name.startswith('pg_')]
|
||||
not word_before_cursor.startswith("pg_")):
|
||||
views = [v for v in views if not v.name.startswith("pg_")]
|
||||
views = [self._make_cand(v, alias, suggestion) for v in views]
|
||||
return self.find_matches(word_before_cursor, views,
|
||||
mode='strict', meta='view')
|
||||
return self.find_matches(word_before_cursor, views, meta="view")
|
||||
|
||||
def get_alias_matches(self, suggestion, word_before_cursor):
|
||||
aliases = suggestion.aliases
|
||||
return self.find_matches(word_before_cursor, aliases,
|
||||
mode='strict', meta='table alias')
|
||||
meta="table alias")
|
||||
|
||||
def get_database_matches(self, _, word_before_cursor):
|
||||
return self.find_matches(word_before_cursor, self.databases,
|
||||
mode='strict', meta='database')
|
||||
meta="database")
|
||||
|
||||
def get_keyword_matches(self, suggestion, word_before_cursor):
|
||||
return self.find_matches(word_before_cursor, self.keywords,
|
||||
mode='strict', meta='keyword')
|
||||
meta="keyword")
|
||||
|
||||
def get_datatype_matches(self, suggestion, word_before_cursor):
|
||||
# suggest custom datatypes
|
||||
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
|
||||
types = self.populate_schema_objects(suggestion.schema, "datatypes")
|
||||
types = [self._make_cand(t, False, suggestion) for t in types]
|
||||
matches = self.find_matches(word_before_cursor, types,
|
||||
mode='strict', meta='datatype')
|
||||
matches = self.find_matches(word_before_cursor, types, meta="datatype")
|
||||
return matches
|
||||
|
||||
def get_word_before_cursor(self, word=False):
|
||||
@ -1004,52 +987,6 @@ class SQLAutoComplete(object):
|
||||
Datatype: get_datatype_matches,
|
||||
}
|
||||
|
||||
def addcols(self, schema, rel, alias, reltype, cols, columns):
|
||||
"""
|
||||
Add columns in schema column list.
|
||||
:param schema: Schema for reference.
|
||||
:param rel:
|
||||
:param alias:
|
||||
:param reltype:
|
||||
:param cols:
|
||||
:param columns:
|
||||
:return:
|
||||
"""
|
||||
tbl = TableReference(schema, rel, alias, reltype == 'functions')
|
||||
if tbl not in columns:
|
||||
columns[tbl] = []
|
||||
columns[tbl].extend(cols)
|
||||
|
||||
def _get_schema_columns(self, schemas, tbl, meta, columns):
|
||||
"""
|
||||
Check and add schema table columns as per table.
|
||||
:param schemas: Schema
|
||||
:param tbl:
|
||||
:param meta:
|
||||
:param columns: column list
|
||||
:return:
|
||||
"""
|
||||
for schema in schemas:
|
||||
relname = self.escape_name(tbl.name)
|
||||
schema = self.escape_name(schema)
|
||||
if tbl.is_function:
|
||||
# Return column names from a set-returning function
|
||||
# Get an array of FunctionMetadata objects
|
||||
functions = meta['functions'].get(schema, {}).get(relname)
|
||||
for func in (functions or []):
|
||||
# func is a FunctionMetadata object
|
||||
cols = func.fields()
|
||||
self.addcols(schema, relname, tbl.alias, 'functions', cols,
|
||||
columns)
|
||||
else:
|
||||
for reltype in ('tables', 'views'):
|
||||
cols = meta[reltype].get(schema, {}).get(relname)
|
||||
if cols:
|
||||
cols = cols.values()
|
||||
self.addcols(schema, relname, tbl.alias, reltype, cols,
|
||||
columns)
|
||||
break
|
||||
|
||||
def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
|
||||
"""Find all columns in a set of scoped_tables.
|
||||
|
||||
@ -1062,14 +999,37 @@ class SQLAutoComplete(object):
|
||||
columns = OrderedDict()
|
||||
meta = self.dbmetadata
|
||||
|
||||
def addcols(schema, rel, alias, reltype, cols):
|
||||
tbl = TableReference(schema, rel, alias, reltype == "functions")
|
||||
if tbl not in columns:
|
||||
columns[tbl] = []
|
||||
columns[tbl].extend(cols)
|
||||
|
||||
for tbl in scoped_tbls:
|
||||
# Local tables should shadow database tables
|
||||
if tbl.schema is None and normalize_ref(tbl.name) in ctes:
|
||||
cols = ctes[normalize_ref(tbl.name)]
|
||||
self.addcols(None, tbl.name, 'CTE', tbl.alias, cols, columns)
|
||||
addcols(None, tbl.name, "CTE", tbl.alias, cols)
|
||||
continue
|
||||
schemas = [tbl.schema] if tbl.schema else self.search_path
|
||||
self._get_schema_columns(schemas, tbl, meta, columns)
|
||||
for schema in schemas:
|
||||
relname = self.escape_name(tbl.name)
|
||||
schema = self.escape_name(schema)
|
||||
if tbl.is_function:
|
||||
# Return column names from a set-returning function
|
||||
# Get an array of FunctionMetadata objects
|
||||
functions = meta["functions"].get(schema, {}).get(relname)
|
||||
for func in functions or []:
|
||||
# func is a FunctionMetadata object
|
||||
cols = func.fields()
|
||||
addcols(schema, relname, tbl.alias, "functions", cols)
|
||||
else:
|
||||
for reltype in ("tables", "views"):
|
||||
cols = meta[reltype].get(schema, {}).get(relname)
|
||||
if cols:
|
||||
cols = cols.values()
|
||||
addcols(schema, relname, tbl.alias, reltype, cols)
|
||||
break
|
||||
|
||||
return columns
|
||||
|
||||
@ -1125,10 +1085,10 @@ class SQLAutoComplete(object):
|
||||
SchemaObject(
|
||||
name=func,
|
||||
schema=(self._maybe_schema(schema=sch, parent=schema)),
|
||||
meta=meta
|
||||
meta=meta,
|
||||
)
|
||||
for sch in self._get_schemas('functions', schema)
|
||||
for (func, metas) in self.dbmetadata['functions'][sch].items()
|
||||
for sch in self._get_schemas("functions", schema)
|
||||
for (func, metas) in self.dbmetadata["functions"][sch].items()
|
||||
for meta in metas
|
||||
if filter_func(meta)
|
||||
]
|
||||
@ -1234,6 +1194,7 @@ class SQLAutoComplete(object):
|
||||
row['is_aggregate'],
|
||||
row['is_window'],
|
||||
row['is_set_returning'],
|
||||
row['is_extension'],
|
||||
row['arg_defaults'].strip('{}').split(',')
|
||||
if row['arg_defaults'] is not None
|
||||
else row['arg_defaults']
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Using Completion class from
|
||||
https://github.com/jonathanslenders/python-prompt-toolkit/
|
||||
blob/master/prompt_toolkit/completion.py
|
||||
https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/completion/base.py
|
||||
"""
|
||||
|
||||
__all__ = (
|
||||
@ -38,7 +37,7 @@ class Completion(object):
|
||||
assert self.start_position <= 0
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(text=%r, start_position=%r)' % (
|
||||
return "%s(text=%r, start_position=%r)" % (
|
||||
self.__class__.__name__, self.text, self.start_position)
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -1,295 +0,0 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||
from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error
|
||||
|
||||
# Pre-compiled patterns used by last_word(); each one is anchored at the end
# of the text and captures the trailing run of "word" characters.
cleanup_regex = {
    # This matches only alphanumerics and underscores.
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # This matches everything except spaces, parens, colon, and comma
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # This matches everything except spaces, parens, colon, comma, and period
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # This matches everything except a space.
    # Raw string on purpose: '\s' in a plain literal is an invalid escape
    # sequence and triggers a warning on modern Python.
    'all_punctuations': re.compile(r'([^\s]+)$'),
}


def last_word(text, include='alphanum_underscore'):
    r"""
    Find the last word in a sentence.

    :param text: the text to inspect.
    :param include: key into ``cleanup_regex`` selecting which characters
        count as part of a word. An unknown key raises ``KeyError``.
    :return: the trailing word, or ``''`` when the text is empty, ends with
        whitespace, or does not end with a word character.

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word(r'bac \def', include='most_punctuations')
    '\\def'
    >>> last_word(r'bac \def;', include='most_punctuations')
    '\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    >>> last_word('"foo*bar', include='most_punctuations')
    '"foo*bar'
    """
    # Empty text, or a cursor sitting right after whitespace: no word yet.
    if not text or text[-1].isspace():
        return ''

    matches = cleanup_regex[include].search(text)
    return matches.group(0) if matches else ''
|
||||
|
||||
|
||||
# Lightweight record describing one table-like reference found in a query:
# its (optional) schema, its name, its alias and whether it is a function.
TableReference = namedtuple(
    'TableReference', 'schema name alias is_function'
)
|
||||
|
||||
|
||||
# This code is borrowed from sqlparse example script.
|
||||
# <url>
|
||||
def is_subselect(parsed):
    """Return True when ``parsed`` is a token group containing a DML keyword.

    :param parsed: a sqlparse token; it must expose ``is_group``,
        and ``tokens`` when it is a group.
    :return: True if one of the direct child tokens is a DML token whose
        value is SELECT, INSERT, UPDATE, CREATE or DELETE; False otherwise.
    """
    # Depending on the sqlparse release, ``is_group`` is either a method or a
    # plain attribute; accept both so the helper works with either API.
    is_group = parsed.is_group
    if callable(is_group):
        is_group = is_group()
    if not is_group:
        return False

    sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
    for item in parsed.tokens:
        if item.ttype is DML and item.value.upper() in sql_type:
            return True
    return False
|
||||
|
||||
|
||||
def _identifier_is_function(identifier):
|
||||
return any(isinstance(t, Function) for t in identifier.tokens)
|
||||
|
||||
|
||||
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens that follow a table-introducing keyword.

    Walks the direct children of ``parsed``. Once a keyword such as COPY,
    FROM, INTO, UPDATE, TABLE or any ``... JOIN`` has been seen, subsequent
    tokens are yielded until another keyword ends the table list.
    Sub-selects are descended into recursively.

    :param parsed: a sqlparse token list exposing ``tokens``.
    :param stop_at_punctuation: when True, stop the whole generator at the
        first punctuation token seen after a table-introducing keyword.
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                # End the generator with a plain return: since PEP 479 an
                # explicit ``raise StopIteration`` inside a generator is
                # converted into a RuntimeError.
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This
            # causes the second FROM to trigger this elif condition, ending
            # the table list. So we need to ignore the keyword if it is
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its
            # variants INNER JOIN, FULL OUTER JOIN, etc.
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and (
                    not item.value.upper().endswith('JOIN')):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
                    item_val.endswith('JOIN')):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
|
||||
|
||||
|
||||
def extract_table_identifiers(token_stream, allow_functions=True):
    """Yield a TableReference for each table-like token in ``token_stream``.

    :param token_stream: iterable of sqlparse tokens.
    :param allow_functions: when False, identifiers are never flagged as
        functions (``is_function`` is always False for them).
    """
    for token in token_stream:
        if isinstance(token, IdentifierList):
            for member in token.get_identifiers():
                # Keywords (such as FROM) are sometimes classified as
                # identifiers although they lack get_real_name(); skip them.
                try:
                    schema_name = member.get_parent_name()
                    real_name = member.get_real_name()
                    is_function = (allow_functions and
                                   _identifier_is_function(member))
                except AttributeError:
                    continue
                if not real_name:
                    continue
                yield TableReference(
                    schema_name, real_name, member.get_alias(), is_function
                )
            continue

        if isinstance(token, Identifier):
            real_name = token.get_real_name()
            schema_name = token.get_parent_name()
            is_function = allow_functions and _identifier_is_function(token)

            if real_name:
                yield TableReference(
                    schema_name, real_name, token.get_alias(), is_function
                )
            else:
                # No real name available: fall back to get_name(), which
                # then also serves as the alias when none was given.
                name = token.get_name()
                yield TableReference(
                    None, name, token.get_alias() or name, is_function
                )
            continue

        if isinstance(token, Function):
            yield TableReference(
                None, token.get_real_name(), token.get_alias(),
                allow_functions
            )
|
||||
|
||||
|
||||
# extract_tables is inspired from examples in the sqlparse lib.
|
||||
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    :param sql: the SQL text to inspect; only its first statement is used.
    :return: a tuple of TableReference namedtuples, empty when nothing could
        be parsed.
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    # NOTE(review): token_first() presumably returns None when the statement
    # holds nothing but whitespace; guard so that case yields no tables
    # instead of failing on ``None.value``. Confirm against the sqlparse
    # version in use.
    first_token = parsed[0].token_first()
    if first_token is None:
        return ()

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    insert_stmt = first_token.value.lower() == 'insert'
    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    identifiers = extract_table_identifiers(
        stream, allow_functions=not insert_stmt
    )
    return tuple(identifiers)
|
||||
|
||||
|
||||
def find_prev_keyword(sql):
    """Locate the last keyword (or opening paren) of an SQL statement.

    :param sql: the SQL text typed so far.
    :return: a ``(token, text)`` pair where ``token`` is the last keyword
        token that is not a logical operator (or a ``(`` token) and ``text``
        is the query truncated right after it; ``(None, '')`` when the input
        is blank or no such token exists.
    """
    if not sql.strip():
        return None, ''

    statement = sqlparse.parse(sql)[0]
    flattened = list(statement.flatten())
    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')

    for t in reversed(flattened):
        if t.value != '(':
            # Not a paren: only non-logical keywords qualify.
            if not t.is_keyword:
                continue
            if t.value.upper() in logical_operators:
                continue

        # Find the location of token t in the flattened list. We can't use
        # statement.token_index(t) because t may be a child token inside a
        # TokenList, in which case token_index throws an error.
        # Minimal example:
        #   p = sqlparse.parse('select * from foo where bar')
        #   t = list(p.flatten())[-3]  # The "Where" token
        #   p.token_index(t)  # Throws ValueError: not in list
        idx = flattened.index(t)

        # Join the values of every token up to and including t, producing
        # the query string with everything after the keyword removed.
        kept = flattened[:idx + 1]
        text = ''.join(tok.value for tok in kept)
        return t, text

    return None, ''
|
||||
|
||||
|
||||
# Postgresql dollar quote signs look like `$$` or `$tag$`
|
||||
# Anchored at both ends: matches only a value that is, in its entirety, a
# `$`, any number of non-`$` characters, and a closing `$`.
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
|
||||
|
||||
|
||||
def is_open_quote(sql):
    """Return True if any statement in ``sql`` contains an unclosed quote."""
    # sqlparse.parse() yields one entry per semi-colon separated command.
    for statement in sqlparse.parse(sql):
        if _parsed_is_open_quote(statement):
            return True
    return False
|
||||
|
||||
|
||||
def _parsed_is_open_quote(parsed):
|
||||
tokens = list(parsed.flatten())
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
tok = tokens[i]
|
||||
if tok.match(Token.Error, "'"):
|
||||
# An unmatched single quote
|
||||
return True
|
||||
elif (tok.ttype in Token.Name.Builtin and
|
||||
dollar_quote_regex.match(tok.value)):
|
||||
# Find the matching closing dollar quote sign
|
||||
for (j, tok2) in enumerate(tokens[i + 1:], i + 1):
|
||||
if tok2.match(Token.Name.Builtin, tok.value):
|
||||
# Found the matching closing quote - continue our scan
|
||||
# for open quotes thereafter
|
||||
i = j
|
||||
break
|
||||
else:
|
||||
# No matching dollar sign quote
|
||||
return True
|
||||
|
||||
i += 1
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def parse_partial_identifier(word):
    """Try to parse a (partially typed) word as an identifier.

    The word may carry a schema qualification, like
    `schema_name.partial_name` or `schema_name.`, and may contain an unclosed
    quotation mark, like `"schema` or `schema."partial_name`.

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """
    statement = sqlparse.parse(word)[0]
    children = statement.tokens

    if len(children) == 1 and isinstance(children[0], Identifier):
        return children[0]

    if hasattr(statement, 'token_next_match') and \
            statement.token_next_match(0, Error, '"'):
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
        # Close the double quote, then reparse
        return parse_partial_identifier(word + '"')

    return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Ad-hoc manual check: print the tables found in an incomplete query.
    print(extract_tables('select * from (select t. from tabl t'))
|
@ -4,7 +4,7 @@ import sqlparse
|
||||
def query_starts_with(query, prefixes):
|
||||
"""Check if the query starts with any item from *prefixes*."""
|
||||
prefixes = [prefix.lower() for prefix in prefixes]
|
||||
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
|
||||
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
|
||||
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
||||
|
||||
|
||||
@ -18,5 +18,5 @@ def queries_start_with(queries, prefixes):
|
||||
|
||||
def is_destructive(queries):
|
||||
"""Returns if any of the queries in *queries* is destructive."""
|
||||
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
|
||||
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
|
||||
return queries_start_with(queries, keywords)
|
||||
|
@ -10,12 +10,11 @@ from .meta import TableMetadata, ColumnMetadata
|
||||
# columns: list of column names
|
||||
# start: index into the original string of the left parens starting the CTE
|
||||
# stop: index into the original string of the right parens ending the CTE
|
||||
TableExpression = namedtuple('TableExpression', 'name columns start stop')
|
||||
TableExpression = namedtuple("TableExpression", "name columns start stop")
|
||||
|
||||
|
||||
def isolate_query_ctes(full_text, text_before_cursor):
|
||||
"""Simplify a query by converting CTEs into table metadata objects
|
||||
"""
|
||||
"""Simplify a query by converting CTEs into table metadata objects"""
|
||||
|
||||
if not full_text:
|
||||
return full_text, text_before_cursor, tuple()
|
||||
@ -30,8 +29,8 @@ def isolate_query_ctes(full_text, text_before_cursor):
|
||||
for cte in ctes:
|
||||
if cte.start < current_position < cte.stop:
|
||||
# Currently editing a cte - treat its body as the current full_text
|
||||
text_before_cursor = full_text[cte.start:current_position]
|
||||
full_text = full_text[cte.start:cte.stop]
|
||||
text_before_cursor = full_text[cte.start: current_position]
|
||||
full_text = full_text[cte.start: cte.stop]
|
||||
return full_text, text_before_cursor, meta
|
||||
|
||||
# Append this cte to the list of available table metadata
|
||||
@ -40,19 +39,19 @@ def isolate_query_ctes(full_text, text_before_cursor):
|
||||
|
||||
# Editing past the last cte (ie the main body of the query)
|
||||
full_text = full_text[ctes[-1].stop:]
|
||||
text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
|
||||
text_before_cursor = text_before_cursor[ctes[-1].stop: current_position]
|
||||
|
||||
return full_text, text_before_cursor, tuple(meta)
|
||||
|
||||
|
||||
def extract_ctes(sql):
|
||||
""" Extract constant table expresseions from a query
|
||||
"""Extract constant table expresseions from a query
|
||||
|
||||
Returns tuple (ctes, remainder_sql)
|
||||
Returns tuple (ctes, remainder_sql)
|
||||
|
||||
ctes is a list of TableExpression namedtuples
|
||||
remainder_sql is the text from the original query after the CTEs have
|
||||
been stripped.
|
||||
ctes is a list of TableExpression namedtuples
|
||||
remainder_sql is the text from the original query after the CTEs have
|
||||
been stripped.
|
||||
"""
|
||||
|
||||
p = parse(sql)[0]
|
||||
@ -66,7 +65,7 @@ def extract_ctes(sql):
|
||||
# Get the next (meaningful) token, which should be the first CTE
|
||||
idx, tok = p.token_next(idx)
|
||||
if not tok:
|
||||
return ([], '')
|
||||
return ([], "")
|
||||
start_pos = token_start_pos(p.tokens, idx)
|
||||
ctes = []
|
||||
|
||||
@ -87,7 +86,7 @@ def extract_ctes(sql):
|
||||
idx = p.token_index(tok) + 1
|
||||
|
||||
# Collapse everything after the ctes into a remainder query
|
||||
remainder = ''.join(str(tok) for tok in p.tokens[idx:])
|
||||
remainder = "".join(str(tok) for tok in p.tokens[idx:])
|
||||
|
||||
return ctes, remainder
|
||||
|
||||
@ -112,15 +111,15 @@ def get_cte_from_token(tok, pos0):
|
||||
|
||||
|
||||
def extract_column_names(parsed):
|
||||
# Find the first DML token to check if it's a SELECT or
|
||||
# INSERT/UPDATE/DELETE
|
||||
# Find the first DML token to check if it's a
|
||||
# SELECT or INSERT/UPDATE/DELETE
|
||||
idx, tok = parsed.token_next_by(t=DML)
|
||||
tok_val = tok and tok.value.lower()
|
||||
|
||||
if tok_val in ('insert', 'update', 'delete'):
|
||||
if tok_val in ("insert", "update", "delete"):
|
||||
# Jump ahead to the RETURNING clause where the list of column names is
|
||||
idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
|
||||
elif not tok_val == 'select':
|
||||
idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
|
||||
elif not tok_val == "select":
|
||||
# Must be invalid CTE
|
||||
return ()
|
||||
|
||||
|
@ -1,23 +1,29 @@
|
||||
from collections import namedtuple
|
||||
|
||||
_ColumnMetadata = namedtuple(
|
||||
'ColumnMetadata',
|
||||
['name', 'datatype', 'foreignkeys', 'default', 'has_default']
|
||||
"ColumnMetadata", ["name", "datatype", "foreignkeys", "default",
|
||||
"has_default"]
|
||||
)
|
||||
|
||||
|
||||
def ColumnMetadata(
|
||||
name, datatype, foreignkeys=None, default=None, has_default=False
|
||||
):
|
||||
return _ColumnMetadata(
|
||||
name, datatype, foreignkeys or [], default, has_default
|
||||
)
|
||||
def ColumnMetadata(name, datatype, foreignkeys=None, default=None,
|
||||
has_default=False):
|
||||
return _ColumnMetadata(name, datatype, foreignkeys or [], default,
|
||||
has_default)
|
||||
|
||||
|
||||
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
|
||||
'parentcolumn', 'childschema',
|
||||
'childtable', 'childcolumn'])
|
||||
TableMetadata = namedtuple('TableMetadata', 'name columns')
|
||||
ForeignKey = namedtuple(
|
||||
"ForeignKey",
|
||||
[
|
||||
"parentschema",
|
||||
"parenttable",
|
||||
"parentcolumn",
|
||||
"childschema",
|
||||
"childtable",
|
||||
"childcolumn",
|
||||
],
|
||||
)
|
||||
TableMetadata = namedtuple("TableMetadata", "name columns")
|
||||
|
||||
|
||||
def parse_defaults(defaults_string):
|
||||
@ -25,34 +31,42 @@ def parse_defaults(defaults_string):
|
||||
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
|
||||
if not defaults_string:
|
||||
return
|
||||
current = ''
|
||||
current = ""
|
||||
in_quote = None
|
||||
for char in defaults_string:
|
||||
if current == '' and char == ' ':
|
||||
if current == "" and char == " ":
|
||||
# Skip space after comma separating default expressions
|
||||
continue
|
||||
if char == '"' or char == '\'':
|
||||
if char == '"' or char == "'":
|
||||
if in_quote and char == in_quote:
|
||||
# End quote
|
||||
in_quote = None
|
||||
elif not in_quote:
|
||||
# Begin quote
|
||||
in_quote = char
|
||||
elif char == ',' and not in_quote:
|
||||
elif char == "," and not in_quote:
|
||||
# End of expression
|
||||
yield current
|
||||
current = ''
|
||||
current = ""
|
||||
continue
|
||||
current += char
|
||||
yield current
|
||||
|
||||
|
||||
class FunctionMetadata(object):
|
||||
|
||||
def __init__(
|
||||
self, schema_name, func_name, arg_names, arg_types, arg_modes,
|
||||
return_type, is_aggregate, is_window, is_set_returning,
|
||||
arg_defaults
|
||||
self,
|
||||
schema_name,
|
||||
func_name,
|
||||
arg_names,
|
||||
arg_types,
|
||||
arg_modes,
|
||||
return_type,
|
||||
is_aggregate,
|
||||
is_window,
|
||||
is_set_returning,
|
||||
is_extension,
|
||||
arg_defaults,
|
||||
):
|
||||
"""Class for describing a postgresql function"""
|
||||
|
||||
@ -80,19 +94,29 @@ class FunctionMetadata(object):
|
||||
self.is_aggregate = is_aggregate
|
||||
self.is_window = is_window
|
||||
self.is_set_returning = is_set_returning
|
||||
self.is_extension = bool(is_extension)
|
||||
self.is_public = self.schema_name and self.schema_name == "public"
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, self.__class__) and
|
||||
self.__dict__ == other.__dict__)
|
||||
return isinstance(other, self.__class__) and \
|
||||
self.__dict__ == other.__dict__
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def _signature(self):
|
||||
return (
|
||||
self.schema_name, self.func_name, self.arg_names, self.arg_types,
|
||||
self.arg_modes, self.return_type, self.is_aggregate,
|
||||
self.is_window, self.is_set_returning, self.arg_defaults
|
||||
self.schema_name,
|
||||
self.func_name,
|
||||
self.arg_names,
|
||||
self.arg_types,
|
||||
self.arg_modes,
|
||||
self.return_type,
|
||||
self.is_aggregate,
|
||||
self.is_window,
|
||||
self.is_set_returning,
|
||||
self.is_extension,
|
||||
self.arg_defaults,
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
@ -100,26 +124,25 @@ class FunctionMetadata(object):
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
(
|
||||
'%s(schema_name=%r, func_name=%r, arg_names=%r, '
|
||||
'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
|
||||
'is_window=%r, is_set_returning=%r, arg_defaults=%r)'
|
||||
) % (self.__class__.__name__,) + self._signature()
|
||||
)
|
||||
"%s(schema_name=%r, func_name=%r, arg_names=%r, "
|
||||
"arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
|
||||
"is_window=%r, is_set_returning=%r, is_extension=%r, "
|
||||
"arg_defaults=%r)"
|
||||
) % ((self.__class__.__name__,) + self._signature())
|
||||
|
||||
def has_variadic(self):
|
||||
return self.arg_modes and any(
|
||||
arg_mode == 'v' for arg_mode in self.arg_modes)
|
||||
return self.arg_modes and \
|
||||
any(arg_mode == "v" for arg_mode in self.arg_modes)
|
||||
|
||||
def args(self):
|
||||
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
|
||||
if not self.arg_names:
|
||||
return []
|
||||
modes = self.arg_modes or ['i'] * len(self.arg_names)
|
||||
modes = self.arg_modes or ["i"] * len(self.arg_names)
|
||||
args = [
|
||||
(name, typ)
|
||||
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
|
||||
if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC
|
||||
if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
|
||||
]
|
||||
|
||||
def arg(name, typ, num):
|
||||
@ -127,7 +150,8 @@ class FunctionMetadata(object):
|
||||
num_defaults = len(self.arg_defaults)
|
||||
has_default = num + num_defaults >= num_args
|
||||
default = (
|
||||
self.arg_defaults[num - num_args + num_defaults] if has_default
|
||||
self.arg_defaults[num - num_args + num_defaults]
|
||||
if has_default
|
||||
else None
|
||||
)
|
||||
return ColumnMetadata(name, typ, [], default, has_default)
|
||||
@ -137,7 +161,7 @@ class FunctionMetadata(object):
|
||||
def fields(self):
|
||||
"""Returns a list of output-field ColumnMetadata namedtuples"""
|
||||
|
||||
if self.return_type.lower() == 'void':
|
||||
if self.return_type.lower() == "void":
|
||||
return []
|
||||
elif not self.arg_modes:
|
||||
# For functions without output parameters, the function name
|
||||
@ -145,7 +169,9 @@ class FunctionMetadata(object):
|
||||
# E.g. 'SELECT unnest FROM unnest(...);'
|
||||
return [ColumnMetadata(self.func_name, self.return_type, [])]
|
||||
|
||||
return [ColumnMetadata(name, typ, [])
|
||||
for name, typ, mode in zip(
|
||||
self.arg_names, self.arg_types, self.arg_modes)
|
||||
if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE
|
||||
return [
|
||||
ColumnMetadata(name, typ, [])
|
||||
for name, typ, mode in zip(self.arg_names, self.arg_types,
|
||||
self.arg_modes)
|
||||
if mode in ("o", "b", "t")
|
||||
] # OUT, INOUT, TABLE
|
||||
|
@ -3,12 +3,15 @@ from collections import namedtuple
|
||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||
from sqlparse.tokens import Keyword, DML, Punctuation
|
||||
|
||||
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
|
||||
'is_function'])
|
||||
TableReference = namedtuple(
|
||||
"TableReference", ["schema", "name", "alias", "is_function"]
|
||||
)
|
||||
TableReference.ref = property(
|
||||
lambda self: self.alias or (
|
||||
self.name if self.name.islower() or self.name[0] == '"'
|
||||
else '"' + self.name + '"')
|
||||
self.name
|
||||
if self.name.islower() or self.name[0] == '"'
|
||||
else '"' + self.name + '"'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -18,9 +21,13 @@ def is_subselect(parsed):
|
||||
if not parsed.is_group:
|
||||
return False
|
||||
for item in parsed.tokens:
|
||||
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
|
||||
'UPDATE', 'CREATE',
|
||||
'DELETE'):
|
||||
if item.ttype is DML and item.value.upper() in (
|
||||
"SELECT",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"CREATE",
|
||||
"DELETE",
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -37,32 +44,42 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
||||
for x in extract_from_part(item, stop_at_punctuation):
|
||||
yield x
|
||||
elif stop_at_punctuation and item.ttype is Punctuation:
|
||||
raise StopIteration
|
||||
return
|
||||
# An incomplete nested select won't be recognized correctly as a
|
||||
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
|
||||
# the second FROM to trigger this elif condition resulting in a
|
||||
# StopIteration. So we need to ignore the keyword if the keyword
|
||||
# `return`. So we need to ignore the keyword if the keyword
|
||||
# FROM.
|
||||
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
|
||||
# condition. So we need to ignore the keyword JOIN and its variants
|
||||
# INNER JOIN, FULL OUTER JOIN, etc.
|
||||
elif item.ttype is Keyword and (
|
||||
not item.value.upper() == 'FROM') and (
|
||||
not item.value.upper().endswith('JOIN')):
|
||||
elif (
|
||||
item.ttype is Keyword and
|
||||
(not item.value.upper() == "FROM") and
|
||||
(not item.value.upper().endswith("JOIN"))
|
||||
):
|
||||
tbl_prefix_seen = False
|
||||
else:
|
||||
yield item
|
||||
elif item.ttype is Keyword or item.ttype is Keyword.DML:
|
||||
item_val = item.value.upper()
|
||||
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
|
||||
item_val.endswith('JOIN')):
|
||||
if (
|
||||
item_val
|
||||
in (
|
||||
"COPY",
|
||||
"FROM",
|
||||
"INTO",
|
||||
"UPDATE",
|
||||
"TABLE",
|
||||
) or item_val.endswith("JOIN")
|
||||
):
|
||||
tbl_prefix_seen = True
|
||||
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
||||
# So this check here is necessary.
|
||||
elif isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
if (identifier.ttype is Keyword and
|
||||
identifier.value.upper() == 'FROM'):
|
||||
if identifier.ttype is Keyword and \
|
||||
identifier.value.upper() == "FROM":
|
||||
tbl_prefix_seen = True
|
||||
break
|
||||
|
||||
@ -94,29 +111,35 @@ def extract_table_identifiers(token_stream, allow_functions=True):
|
||||
name = name.lower()
|
||||
return schema_name, name, alias
|
||||
|
||||
for item in token_stream:
|
||||
if isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
# Sometimes Keywords (such as FROM ) are classified as
|
||||
# identifiers which don't have the get_real_name() method.
|
||||
try:
|
||||
schema_name = identifier.get_parent_name()
|
||||
real_name = identifier.get_real_name()
|
||||
is_function = (allow_functions and
|
||||
_identifier_is_function(identifier))
|
||||
except AttributeError:
|
||||
continue
|
||||
if real_name:
|
||||
yield TableReference(schema_name, real_name,
|
||||
identifier.get_alias(), is_function)
|
||||
elif isinstance(item, Identifier):
|
||||
schema_name, real_name, alias = parse_identifier(item)
|
||||
is_function = allow_functions and _identifier_is_function(item)
|
||||
try:
|
||||
for item in token_stream:
|
||||
if isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
# Sometimes Keywords (such as FROM ) are classified as
|
||||
# identifiers which don't have the get_real_name() method.
|
||||
try:
|
||||
schema_name = identifier.get_parent_name()
|
||||
real_name = identifier.get_real_name()
|
||||
is_function = allow_functions and \
|
||||
_identifier_is_function(identifier)
|
||||
except AttributeError:
|
||||
continue
|
||||
if real_name:
|
||||
yield TableReference(
|
||||
schema_name, real_name, identifier.get_alias(),
|
||||
is_function
|
||||
)
|
||||
elif isinstance(item, Identifier):
|
||||
schema_name, real_name, alias = parse_identifier(item)
|
||||
is_function = allow_functions and _identifier_is_function(item)
|
||||
|
||||
yield TableReference(schema_name, real_name, alias, is_function)
|
||||
elif isinstance(item, Function):
|
||||
schema_name, real_name, alias = parse_identifier(item)
|
||||
yield TableReference(None, real_name, alias, allow_functions)
|
||||
yield TableReference(schema_name, real_name, alias,
|
||||
is_function)
|
||||
elif isinstance(item, Function):
|
||||
schema_name, real_name, alias = parse_identifier(item)
|
||||
yield TableReference(None, real_name, alias, allow_functions)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
|
||||
# extract_tables is inspired from examples in the sqlparse lib.
|
||||
@ -134,7 +157,7 @@ def extract_tables(sql):
|
||||
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
|
||||
# abc is the table name, but if we don't stop at the first lparen, then
|
||||
# we'll identify abc, col1 and col2 as table names.
|
||||
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
|
||||
insert_stmt = parsed[0].token_first().value.lower() == "insert"
|
||||
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
|
||||
|
||||
# Kludge: sqlparse mistakenly identifies insert statements as
|
||||
|
@ -5,17 +5,17 @@ from sqlparse.tokens import Token, Error
|
||||
|
||||
cleanup_regex = {
|
||||
# This matches only alphanumerics and underscores.
|
||||
'alphanum_underscore': re.compile(r'(\w+)$'),
|
||||
"alphanum_underscore": re.compile(r"(\w+)$"),
|
||||
# This matches everything except spaces, parens, colon, and comma
|
||||
'many_punctuations': re.compile(r'([^():,\s]+)$'),
|
||||
"many_punctuations": re.compile(r"([^():,\s]+)$"),
|
||||
# This matches everything except spaces, parens, colon, comma, and period
|
||||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
||||
"most_punctuations": re.compile(r"([^\.():,\s]+)$"),
|
||||
# This matches everything except a space.
|
||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
||||
"all_punctuations": re.compile(r"([^\s]+)$"),
|
||||
}
|
||||
|
||||
|
||||
def last_word(text, include='alphanum_underscore'):
|
||||
def last_word(text, include="alphanum_underscore"):
|
||||
r"""
|
||||
Find the last word in a sentence.
|
||||
|
||||
@ -49,41 +49,42 @@ def last_word(text, include='alphanum_underscore'):
|
||||
'"foo*bar'
|
||||
"""
|
||||
|
||||
if not text: # Empty string
|
||||
return ''
|
||||
if not text: # Empty string
|
||||
return ""
|
||||
|
||||
if text[-1].isspace():
|
||||
return ''
|
||||
return ""
|
||||
else:
|
||||
regex = cleanup_regex[include]
|
||||
matches = regex.search(text)
|
||||
if matches:
|
||||
return matches.group(0)
|
||||
else:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
|
||||
def find_prev_keyword(sql, n_skip=0):
|
||||
""" Find the last sql keyword in an SQL statement
|
||||
"""Find the last sql keyword in an SQL statement
|
||||
|
||||
Returns the value of the last keyword, and the text of the query with
|
||||
everything after the last keyword stripped
|
||||
"""
|
||||
if not sql.strip():
|
||||
return None, ''
|
||||
return None, ""
|
||||
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
flattened = list(parsed.flatten())
|
||||
flattened = flattened[:len(flattened) - n_skip]
|
||||
flattened = flattened[: len(flattened) - n_skip]
|
||||
|
||||
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
|
||||
logical_operators = ("AND", "OR", "NOT", "BETWEEN")
|
||||
|
||||
for t in reversed(flattened):
|
||||
if t.value == '(' or (t.is_keyword and (
|
||||
t.value.upper() not in logical_operators)):
|
||||
if t.value == "(" or (
|
||||
t.is_keyword and (t.value.upper() not in logical_operators)
|
||||
):
|
||||
# Find the location of token t in the original parsed statement
|
||||
# We can't use parsed.token_index(t) because t may be a child token
|
||||
# inside a TokenList, in which case token_index thows an error
|
||||
# inside a TokenList, in which case token_index throws an error
|
||||
# Minimal example:
|
||||
# p = sqlparse.parse('select * from foo where bar')
|
||||
# t = list(p.flatten())[-3] # The "Where" token
|
||||
@ -93,14 +94,14 @@ def find_prev_keyword(sql, n_skip=0):
|
||||
# Combine the string values of all tokens in the original list
|
||||
# up to and including the target keyword token t, to produce a
|
||||
# query string with everything after the keyword token removed
|
||||
text = ''.join(tok.value for tok in flattened[:idx + 1])
|
||||
text = "".join(tok.value for tok in flattened[: idx + 1])
|
||||
return t, text
|
||||
|
||||
return None, ''
|
||||
return None, ""
|
||||
|
||||
|
||||
# Postgresql dollar quote signs look like `$$` or `$tag$`
|
||||
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
|
||||
dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
|
||||
|
||||
|
||||
def is_open_quote(sql):
|
||||
|
@ -4,13 +4,13 @@ from sqlparse.tokens import Name
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
white_space_regex = re.compile(r'\\s+', re.MULTILINE)
|
||||
white_space_regex = re.compile("\\s+", re.MULTILINE)
|
||||
|
||||
|
||||
def _compile_regex(keyword):
|
||||
# Surround the keyword with word boundaries and replace interior whitespace
|
||||
# with whitespace wildcards
|
||||
pattern = r'\\b' + white_space_regex.sub(r'\\s+', keyword) + r'\\b'
|
||||
pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
|
||||
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
|
||||
|
||||
|
||||
|
@ -3,28 +3,29 @@ import re
|
||||
import sqlparse
|
||||
from collections import namedtuple
|
||||
from sqlparse.sql import Comparison, Identifier, Where
|
||||
from .parseutils.utils import (
|
||||
last_word, find_prev_keyword, parse_partial_identifier)
|
||||
from .parseutils.utils import last_word, find_prev_keyword,\
|
||||
parse_partial_identifier
|
||||
from .parseutils.tables import extract_tables
|
||||
from .parseutils.ctes import isolate_query_ctes
|
||||
|
||||
Special = namedtuple('Special', [])
|
||||
Database = namedtuple('Database', [])
|
||||
Schema = namedtuple('Schema', ['quoted'])
|
||||
|
||||
Special = namedtuple("Special", [])
|
||||
Database = namedtuple("Database", [])
|
||||
Schema = namedtuple("Schema", ["quoted"])
|
||||
Schema.__new__.__defaults__ = (False,)
|
||||
# FromClauseItem is a table/view/function used in the FROM clause
|
||||
# `table_refs` contains the list of tables/... already in the statement,
|
||||
# used to ensure that the alias we suggest is unique
|
||||
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
|
||||
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
|
||||
TableFormat = namedtuple('TableFormat', [])
|
||||
View = namedtuple('View', ['schema', 'table_refs'])
|
||||
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
|
||||
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
|
||||
TableFormat = namedtuple("TableFormat", [])
|
||||
View = namedtuple("View", ["schema", "table_refs"])
|
||||
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
|
||||
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
|
||||
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
|
||||
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
|
||||
Join = namedtuple('Join', ['table_refs', 'schema'])
|
||||
Join = namedtuple("Join", ["table_refs", "schema"])
|
||||
|
||||
Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
|
||||
Function = namedtuple("Function", ["schema", "table_refs", "usage"])
|
||||
# For convenience, don't require the `usage` argument in Function constructor
|
||||
Function.__new__.__defaults__ = (None, tuple(), None)
|
||||
Table.__new__.__defaults__ = (None, tuple(), tuple())
|
||||
@ -32,31 +33,33 @@ View.__new__.__defaults__ = (None, tuple())
|
||||
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
|
||||
|
||||
Column = namedtuple(
|
||||
'Column',
|
||||
['table_refs', 'require_last_table', 'local_tables',
|
||||
'qualifiable', 'context']
|
||||
"Column",
|
||||
["table_refs", "require_last_table", "local_tables", "qualifiable",
|
||||
"context"],
|
||||
)
|
||||
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
|
||||
|
||||
Keyword = namedtuple('Keyword', ['last_token'])
|
||||
Keyword = namedtuple("Keyword", ["last_token"])
|
||||
Keyword.__new__.__defaults__ = (None,)
|
||||
NamedQuery = namedtuple('NamedQuery', [])
|
||||
Datatype = namedtuple('Datatype', ['schema'])
|
||||
Alias = namedtuple('Alias', ['aliases'])
|
||||
NamedQuery = namedtuple("NamedQuery", [])
|
||||
Datatype = namedtuple("Datatype", ["schema"])
|
||||
Alias = namedtuple("Alias", ["aliases"])
|
||||
|
||||
Path = namedtuple('Path', [])
|
||||
Path = namedtuple("Path", [])
|
||||
|
||||
|
||||
class SqlStatement(object):
|
||||
def __init__(self, full_text, text_before_cursor):
|
||||
self.identifier = None
|
||||
self.word_before_cursor = word_before_cursor = last_word(
|
||||
text_before_cursor, include='many_punctuations')
|
||||
text_before_cursor, include="many_punctuations"
|
||||
)
|
||||
full_text = _strip_named_query(full_text)
|
||||
text_before_cursor = _strip_named_query(text_before_cursor)
|
||||
|
||||
full_text, text_before_cursor, self.local_tables = \
|
||||
isolate_query_ctes(full_text, text_before_cursor)
|
||||
full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
|
||||
full_text, text_before_cursor
|
||||
)
|
||||
|
||||
self.text_before_cursor_including_last_word = text_before_cursor
|
||||
|
||||
@ -67,39 +70,41 @@ class SqlStatement(object):
|
||||
# completion useless because it will always return the list of
|
||||
# keywords as completion.
|
||||
if self.word_before_cursor:
|
||||
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
|
||||
if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
else:
|
||||
text_before_cursor = \
|
||||
text_before_cursor[:-len(word_before_cursor)]
|
||||
text_before_cursor[: -len(word_before_cursor)]
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
self.identifier = parse_partial_identifier(word_before_cursor)
|
||||
else:
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
|
||||
full_text, text_before_cursor, parsed = \
|
||||
_split_multiple_statements(full_text, text_before_cursor, parsed)
|
||||
full_text, text_before_cursor, parsed = _split_multiple_statements(
|
||||
full_text, text_before_cursor, parsed
|
||||
)
|
||||
|
||||
self.full_text = full_text
|
||||
self.text_before_cursor = text_before_cursor
|
||||
self.parsed = parsed
|
||||
|
||||
self.last_token = \
|
||||
parsed and parsed.token_prev(len(parsed.tokens))[1] or ''
|
||||
self.last_token = parsed and \
|
||||
parsed.token_prev(len(parsed.tokens))[1] or ""
|
||||
|
||||
def is_insert(self):
|
||||
return self.parsed.token_first().value.lower() == 'insert'
|
||||
return self.parsed.token_first().value.lower() == "insert"
|
||||
|
||||
def get_tables(self, scope='full'):
|
||||
""" Gets the tables available in the statement.
|
||||
def get_tables(self, scope="full"):
|
||||
"""Gets the tables available in the statement.
|
||||
param `scope:` possible values: 'full', 'insert', 'before'
|
||||
If 'insert', only the first table is returned.
|
||||
If 'before', only tables before the cursor are returned.
|
||||
If not 'insert' and the stmt is an insert, the first table is skipped.
|
||||
"""
|
||||
tables = extract_tables(
|
||||
self.full_text if scope == 'full' else self.text_before_cursor)
|
||||
if scope == 'insert':
|
||||
self.full_text if scope == "full" else self.text_before_cursor
|
||||
)
|
||||
if scope == "insert":
|
||||
tables = tables[:1]
|
||||
elif self.is_insert():
|
||||
tables = tables[1:]
|
||||
@ -118,8 +123,9 @@ class SqlStatement(object):
|
||||
return schema
|
||||
|
||||
def reduce_to_prev_keyword(self, n_skip=0):
|
||||
prev_keyword, self.text_before_cursor = \
|
||||
find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
|
||||
prev_keyword, self.text_before_cursor = find_prev_keyword(
|
||||
self.text_before_cursor, n_skip=n_skip
|
||||
)
|
||||
return prev_keyword
|
||||
|
||||
|
||||
@ -131,7 +137,7 @@ def suggest_type(full_text, text_before_cursor):
|
||||
A scope for a column category will be a list of tables.
|
||||
"""
|
||||
|
||||
if full_text.startswith('\\i '):
|
||||
if full_text.startswith("\\i "):
|
||||
return (Path(),)
|
||||
|
||||
# This is a temporary hack; the exception handling
|
||||
@ -144,7 +150,7 @@ def suggest_type(full_text, text_before_cursor):
|
||||
return suggest_based_on_last_token(stmt.last_token, stmt)
|
||||
|
||||
|
||||
named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
|
||||
named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
|
||||
|
||||
|
||||
def _strip_named_query(txt):
|
||||
@ -155,11 +161,11 @@ def _strip_named_query(txt):
|
||||
"""
|
||||
|
||||
if named_query_regex.match(txt):
|
||||
txt = named_query_regex.sub('', txt)
|
||||
txt = named_query_regex.sub("", txt)
|
||||
return txt
|
||||
|
||||
|
||||
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
|
||||
function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
|
||||
|
||||
|
||||
def _find_function_body(text):
|
||||
@ -205,12 +211,12 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
|
||||
return full_text, text_before_cursor, None
|
||||
|
||||
token2 = None
|
||||
if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
|
||||
if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
|
||||
token1 = statement.token_first()
|
||||
if token1:
|
||||
token1_idx = statement.token_index(token1)
|
||||
token2 = statement.token_next(token1_idx)[1]
|
||||
if token2 and token2.value.upper() == 'FUNCTION':
|
||||
if token2 and token2.value.upper() == "FUNCTION":
|
||||
full_text, text_before_cursor, statement = _statement_from_function(
|
||||
full_text, text_before_cursor, statement
|
||||
)
|
||||
@ -246,9 +252,9 @@ def suggest_based_on_last_token(token, stmt):
|
||||
# SELECT Identifier <CURSOR>
|
||||
# SELECT foo FROM Identifier <CURSOR>
|
||||
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
|
||||
if prev_keyword and prev_keyword.value == '(':
|
||||
if prev_keyword and prev_keyword.value == "(":
|
||||
# Suggest datatypes
|
||||
return suggest_based_on_last_token('type', stmt)
|
||||
return suggest_based_on_last_token("type", stmt)
|
||||
else:
|
||||
return (Keyword(),)
|
||||
else:
|
||||
@ -256,7 +262,7 @@ def suggest_based_on_last_token(token, stmt):
|
||||
|
||||
if not token:
|
||||
return (Keyword(),)
|
||||
elif token_v.endswith('('):
|
||||
elif token_v.endswith("("):
|
||||
p = sqlparse.parse(stmt.text_before_cursor)[0]
|
||||
|
||||
if p.tokens and isinstance(p.tokens[-1], Where):
|
||||
@ -268,10 +274,10 @@ def suggest_based_on_last_token(token, stmt):
|
||||
# 3 - Subquery expression like "WHERE EXISTS ("
|
||||
# Suggest keywords, in order to do a subquery
|
||||
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
|
||||
# Suggest columns/functions AND keywords. (If we wanted to be
|
||||
# really fancy, we could suggest only array-typed columns)
|
||||
# Suggest columns/functions AND keywords. (If we wanted to
|
||||
# be really fancy, we could suggest only array-typed columns)
|
||||
|
||||
column_suggestions = suggest_based_on_last_token('where', stmt)
|
||||
column_suggestions = suggest_based_on_last_token("where", stmt)
|
||||
|
||||
# Check for a subquery expression (cases 3 & 4)
|
||||
where = p.tokens[-1]
|
||||
@ -282,7 +288,7 @@ def suggest_based_on_last_token(token, stmt):
|
||||
prev_tok = prev_tok.tokens[-1]
|
||||
|
||||
prev_tok = prev_tok.value.lower()
|
||||
if prev_tok == 'exists':
|
||||
if prev_tok == "exists":
|
||||
return (Keyword(),)
|
||||
else:
|
||||
return column_suggestions
|
||||
@ -292,59 +298,47 @@ def suggest_based_on_last_token(token, stmt):
|
||||
|
||||
if (
|
||||
prev_tok and prev_tok.value and
|
||||
prev_tok.value.lower().split(' ')[-1] == 'using'
|
||||
prev_tok.value.lower().split(" ")[-1] == "using"
|
||||
):
|
||||
# tbl1 INNER JOIN tbl2 USING (col1, col2)
|
||||
tables = stmt.get_tables('before')
|
||||
tables = stmt.get_tables("before")
|
||||
|
||||
# suggest columns that are present in more than one table
|
||||
return (Column(table_refs=tables,
|
||||
require_last_table=True,
|
||||
local_tables=stmt.local_tables),)
|
||||
|
||||
# If the lparen is preceeded by a space chances are we're about to
|
||||
# do a sub-select.
|
||||
elif p.token_first().value.lower() == 'select' and \
|
||||
last_word(stmt.text_before_cursor,
|
||||
'all_punctuations').startswith('('):
|
||||
return (Keyword(),)
|
||||
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
|
||||
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
|
||||
return (
|
||||
Column(table_refs=stmt.get_tables('insert'), context='insert'),
|
||||
Column(
|
||||
table_refs=tables,
|
||||
require_last_table=True,
|
||||
local_tables=stmt.local_tables,
|
||||
),
|
||||
)
|
||||
|
||||
elif p.token_first().value.lower() == "select":
|
||||
# If the lparen is preceded by a space, chances are we're about to
|
||||
# do a sub-select.
|
||||
if last_word(stmt.text_before_cursor,
|
||||
"all_punctuations").startswith("("):
|
||||
return (Keyword(),)
|
||||
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
|
||||
if prev_prev_tok and prev_prev_tok.normalized == "INTO":
|
||||
return (Column(table_refs=stmt.get_tables("insert"),
|
||||
context="insert"),)
|
||||
# We're probably in a function argument list
|
||||
return (Column(table_refs=extract_tables(stmt.full_text),
|
||||
local_tables=stmt.local_tables, qualifiable=True),)
|
||||
elif token_v == 'set':
|
||||
return _suggest_expression(token_v, stmt)
|
||||
elif token_v == "set":
|
||||
return (Column(table_refs=stmt.get_tables(),
|
||||
local_tables=stmt.local_tables),)
|
||||
elif token_v in ('select', 'where', 'having', 'by', 'distinct'):
|
||||
# Check for a table alias or schema qualification
|
||||
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
|
||||
tables = stmt.get_tables()
|
||||
if parent:
|
||||
tables = tuple(t for t in tables if identifies(parent, t))
|
||||
return (Column(table_refs=tables, local_tables=stmt.local_tables),
|
||||
Table(schema=parent),
|
||||
View(schema=parent),
|
||||
Function(schema=parent),)
|
||||
else:
|
||||
return (Column(table_refs=tables, local_tables=stmt.local_tables,
|
||||
qualifiable=True),
|
||||
Function(schema=None),
|
||||
Keyword(token_v.upper()),)
|
||||
elif token_v == 'as':
|
||||
elif token_v in ("select", "where", "having", "order by", "distinct"):
|
||||
return _suggest_expression(token_v, stmt)
|
||||
elif token_v == "as":
|
||||
# Don't suggest anything for aliases
|
||||
return ()
|
||||
elif (
|
||||
(token_v.endswith('join') and token.is_keyword) or
|
||||
(token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'))
|
||||
elif (token_v.endswith("join") and token.is_keyword) or (
|
||||
token_v in ("copy", "from", "update", "into", "describe", "truncate")
|
||||
):
|
||||
|
||||
schema = stmt.get_identifier_schema()
|
||||
tables = extract_tables(stmt.text_before_cursor)
|
||||
is_join = token_v.endswith('join') and token.is_keyword
|
||||
is_join = token_v.endswith("join") and token.is_keyword
|
||||
|
||||
# Suggest tables from either the currently-selected schema or the
|
||||
# public schema if no schema has been specified
|
||||
@ -354,60 +348,77 @@ def suggest_based_on_last_token(token, stmt):
|
||||
# Suggest schemas
|
||||
suggest.insert(0, Schema())
|
||||
|
||||
if token_v == 'from' or is_join:
|
||||
suggest.append(FromClauseItem(schema=schema,
|
||||
table_refs=tables,
|
||||
local_tables=stmt.local_tables))
|
||||
elif token_v == 'truncate':
|
||||
if token_v == "from" or is_join:
|
||||
suggest.append(
|
||||
FromClauseItem(
|
||||
schema=schema, table_refs=tables,
|
||||
local_tables=stmt.local_tables
|
||||
)
|
||||
)
|
||||
elif token_v == "truncate":
|
||||
suggest.append(Table(schema))
|
||||
else:
|
||||
suggest.extend((Table(schema), View(schema)))
|
||||
|
||||
if is_join and _allow_join(stmt.parsed):
|
||||
tables = stmt.get_tables('before')
|
||||
tables = stmt.get_tables("before")
|
||||
suggest.append(Join(table_refs=tables, schema=schema))
|
||||
|
||||
return tuple(suggest)
|
||||
|
||||
elif token_v == 'function':
|
||||
elif token_v == "function":
|
||||
schema = stmt.get_identifier_schema()
|
||||
|
||||
# stmt.get_previous_token will fail for e.g.
|
||||
# `SELECT 1 FROM functions WHERE function:`
|
||||
try:
|
||||
prev = stmt.get_previous_token(token).value.lower()
|
||||
if prev in ('drop', 'alter', 'create', 'create or replace'):
|
||||
return (Function(schema=schema, usage='signature'),)
|
||||
if prev in ("drop", "alter", "create", "create or replace"):
|
||||
|
||||
# Suggest functions from either the currently-selected schema
|
||||
# or the public schema if no schema has been specified
|
||||
suggest = []
|
||||
|
||||
if not schema:
|
||||
# Suggest schemas
|
||||
suggest.insert(0, Schema())
|
||||
|
||||
suggest.append(Function(schema=schema, usage="signature"))
|
||||
return tuple(suggest)
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
return tuple()
|
||||
|
||||
elif token_v in ('table', 'view'):
|
||||
elif token_v in ("table", "view"):
|
||||
# E.g. 'ALTER TABLE <tablename>'
|
||||
rel_type = \
|
||||
{'table': Table, 'view': View, 'function': Function}[token_v]
|
||||
{"table": Table, "view": View, "function": Function}[token_v]
|
||||
schema = stmt.get_identifier_schema()
|
||||
if schema:
|
||||
return (rel_type(schema=schema),)
|
||||
else:
|
||||
return (Schema(), rel_type(schema=schema))
|
||||
|
||||
elif token_v == 'column':
|
||||
elif token_v == "column":
|
||||
# E.g. 'ALTER TABLE foo ALTER COLUMN bar'
|
||||
return (Column(table_refs=stmt.get_tables()),)
|
||||
|
||||
elif token_v == 'on':
|
||||
tables = stmt.get_tables('before')
|
||||
elif token_v == "on":
|
||||
tables = stmt.get_tables("before")
|
||||
parent = \
|
||||
(stmt.identifier and stmt.identifier.get_parent_name()) or None
|
||||
if parent:
|
||||
# "ON parent.<suggestion>"
|
||||
# parent can be either a schema name or table alias
|
||||
filteredtables = tuple(t for t in tables if identifies(parent, t))
|
||||
sugs = [Column(table_refs=filteredtables,
|
||||
local_tables=stmt.local_tables),
|
||||
Table(schema=parent),
|
||||
View(schema=parent),
|
||||
Function(schema=parent)]
|
||||
sugs = [
|
||||
Column(table_refs=filteredtables,
|
||||
local_tables=stmt.local_tables),
|
||||
Table(schema=parent),
|
||||
View(schema=parent),
|
||||
Function(schema=parent),
|
||||
]
|
||||
if filteredtables and _allow_join_condition(stmt.parsed):
|
||||
sugs.append(JoinCondition(table_refs=tables,
|
||||
parent=filteredtables[-1]))
|
||||
@ -417,38 +428,39 @@ def suggest_based_on_last_token(token, stmt):
|
||||
# Use table alias if there is one, otherwise the table name
|
||||
aliases = tuple(t.ref for t in tables)
|
||||
if _allow_join_condition(stmt.parsed):
|
||||
return (Alias(aliases=aliases), JoinCondition(
|
||||
table_refs=tables, parent=None))
|
||||
return (
|
||||
Alias(aliases=aliases),
|
||||
JoinCondition(table_refs=tables, parent=None),
|
||||
)
|
||||
else:
|
||||
return (Alias(aliases=aliases),)
|
||||
|
||||
elif token_v in ('c', 'use', 'database', 'template'):
|
||||
elif token_v in ("c", "use", "database", "template"):
|
||||
# "\c <db>", "use <db>", "DROP DATABASE <db>",
|
||||
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
|
||||
return (Database(),)
|
||||
elif token_v == 'schema':
|
||||
elif token_v == "schema":
|
||||
# DROP SCHEMA schema_name, SET SCHEMA schema_name
|
||||
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
|
||||
quoted = prev_keyword and prev_keyword.value.lower() == 'set'
|
||||
quoted = prev_keyword and prev_keyword.value.lower() == "set"
|
||||
return (Schema(quoted),)
|
||||
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
|
||||
elif token_v.endswith(",") or token_v in ("=", "and", "or"):
|
||||
prev_keyword = stmt.reduce_to_prev_keyword()
|
||||
if prev_keyword:
|
||||
return suggest_based_on_last_token(prev_keyword, stmt)
|
||||
else:
|
||||
return ()
|
||||
elif token_v in ('type', '::'):
|
||||
elif token_v in ("type", "::"):
|
||||
# ALTER TABLE foo SET DATA TYPE bar
|
||||
# SELECT foo::bar
|
||||
# Note that tables are a form of composite type in postgresql, so
|
||||
# they're suggested here as well
|
||||
schema = stmt.get_identifier_schema()
|
||||
suggestions = [Datatype(schema=schema),
|
||||
Table(schema=schema)]
|
||||
suggestions = [Datatype(schema=schema), Table(schema=schema)]
|
||||
if not schema:
|
||||
suggestions.append(Schema())
|
||||
return tuple(suggestions)
|
||||
elif token_v in ['alter', 'create', 'drop']:
|
||||
elif token_v in {"alter", "create", "drop"}:
|
||||
return (Keyword(token_v.upper()),)
|
||||
elif token.is_keyword:
|
||||
# token is a keyword we haven't implemented any special handling for
|
||||
@ -462,11 +474,38 @@ def suggest_based_on_last_token(token, stmt):
|
||||
return (Keyword(),)
|
||||
|
||||
|
||||
def _suggest_expression(token_v, stmt):
|
||||
"""
|
||||
Return suggestions for an expression, taking account of any partially-typed
|
||||
identifier's parent, which may be a table alias or schema name.
|
||||
"""
|
||||
parent = stmt.identifier.get_parent_name() if stmt.identifier else []
|
||||
tables = stmt.get_tables()
|
||||
|
||||
if parent:
|
||||
tables = tuple(t for t in tables if identifies(parent, t))
|
||||
return (
|
||||
Column(table_refs=tables, local_tables=stmt.local_tables),
|
||||
Table(schema=parent),
|
||||
View(schema=parent),
|
||||
Function(schema=parent),
|
||||
)
|
||||
|
||||
return (
|
||||
Column(table_refs=tables, local_tables=stmt.local_tables,
|
||||
qualifiable=True),
|
||||
Function(schema=None),
|
||||
Keyword(token_v.upper()),
|
||||
)
|
||||
|
||||
|
||||
def identifies(id, ref):
|
||||
"""Returns true if string `id` matches TableReference `ref`"""
|
||||
|
||||
return id == ref.alias or id == ref.name or (
|
||||
ref.schema and (id == ref.schema + '.' + ref.name))
|
||||
return (
|
||||
id == ref.alias or id == ref.name or
|
||||
(ref.schema and (id == ref.schema + "." + ref.name))
|
||||
)
|
||||
|
||||
|
||||
def _allow_join_condition(statement):
|
||||
@ -486,7 +525,7 @@ def _allow_join_condition(statement):
|
||||
return False
|
||||
|
||||
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||
return last_tok.value.lower() in ('on', 'and', 'or')
|
||||
return last_tok.value.lower() in ("on", "and", "or")
|
||||
|
||||
|
||||
def _allow_join(statement):
|
||||
@ -505,7 +544,5 @@ def _allow_join(statement):
|
||||
return False
|
||||
|
||||
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||
return (
|
||||
last_tok.value.lower().endswith('join') and
|
||||
last_tok.value.lower() not in ('cross join', 'natural join')
|
||||
)
|
||||
return last_tok.value.lower().endswith("join") and \
|
||||
last_tok.value.lower() not in ("cross join", "natural join",)
|
||||
|
Loading…
Reference in New Issue
Block a user