mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2025-02-25 18:55:31 -06:00
Merged the latest code of 'pgcli' used for the autocomplete feature. Fixes #5497
This commit is contained in:
parent
3f817494f8
commit
300de05a20
@ -16,6 +16,7 @@ Housekeeping
|
|||||||
************
|
************
|
||||||
|
|
||||||
| `Issue #5330 <https://redmine.postgresql.org/issues/5330>`_ - Improve code coverage and API test cases for Functions.
|
| `Issue #5330 <https://redmine.postgresql.org/issues/5330>`_ - Improve code coverage and API test cases for Functions.
|
||||||
|
| `Issue #5497 <https://redmine.postgresql.org/issues/5497>`_ - Merged the latest code of 'pgcli' used for the autocomplete feature.
|
||||||
|
|
||||||
Bug fixes
|
Bug fixes
|
||||||
*********
|
*********
|
||||||
|
@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
|
|||||||
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
||||||
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
||||||
p.proretset is_set_returning,
|
p.proretset is_set_returning,
|
||||||
|
d.deptype = 'e' is_extension,
|
||||||
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||||
FROM pg_catalog.pg_proc p
|
FROM pg_catalog.pg_proc p
|
||||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||||
|
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
|
||||||
WHERE p.prorettype::regtype != 'trigger'::regtype
|
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||||
AND n.nspname IN ({{schema_names}})
|
AND n.nspname IN ({{schema_names}})
|
||||||
ORDER BY 1, 2
|
ORDER BY 1, 2
|
||||||
|
@ -8,9 +8,11 @@ SELECT n.nspname schema_name,
|
|||||||
p.proisagg is_aggregate,
|
p.proisagg is_aggregate,
|
||||||
p.proiswindow is_window,
|
p.proiswindow is_window,
|
||||||
p.proretset is_set_returning,
|
p.proretset is_set_returning,
|
||||||
|
d.deptype = 'e' is_extension,
|
||||||
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||||
FROM pg_catalog.pg_proc p
|
FROM pg_catalog.pg_proc p
|
||||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||||
|
LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
|
||||||
WHERE p.prorettype::regtype != 'trigger'::regtype
|
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||||
AND n.nspname IN ({{schema_names}})
|
AND n.nspname IN ({{schema_names}})
|
||||||
ORDER BY 1, 2
|
ORDER BY 1, 2
|
||||||
|
@ -11,8 +11,7 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
import operator
|
import operator
|
||||||
import sys
|
from itertools import count
|
||||||
from itertools import count, repeat, chain
|
|
||||||
from .completion import Completion
|
from .completion import Completion
|
||||||
from collections import namedtuple, defaultdict, OrderedDict
|
from collections import namedtuple, defaultdict, OrderedDict
|
||||||
|
|
||||||
@ -28,9 +27,9 @@ from pgadmin.utils.driver import get_driver
|
|||||||
from config import PG_DEFAULT_DRIVER
|
from config import PG_DEFAULT_DRIVER
|
||||||
from pgadmin.utils.preferences import Preferences
|
from pgadmin.utils.preferences import Preferences
|
||||||
|
|
||||||
Match = namedtuple('Match', ['completion', 'priority'])
|
Match = namedtuple("Match", ["completion", "priority"])
|
||||||
|
|
||||||
_SchemaObject = namedtuple('SchemaObject', 'name schema meta')
|
_SchemaObject = namedtuple("SchemaObject", "name schema meta")
|
||||||
|
|
||||||
|
|
||||||
def SchemaObject(name, schema=None, meta=None):
|
def SchemaObject(name, schema=None, meta=None):
|
||||||
@ -41,14 +40,12 @@ def SchemaObject(name, schema=None, meta=None):
|
|||||||
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
|
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
|
||||||
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
|
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
|
||||||
|
|
||||||
_Candidate = namedtuple(
|
_Candidate = namedtuple("Candidate",
|
||||||
'Candidate', 'completion prio meta synonyms prio2 display'
|
"completion prio meta synonyms prio2 display")
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def Candidate(
|
def Candidate(
|
||||||
completion, prio=None, meta=None, synonyms=None, prio2=None,
|
completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
|
||||||
display=None
|
|
||||||
):
|
):
|
||||||
return _Candidate(
|
return _Candidate(
|
||||||
completion, prio, meta, synonyms or [completion], prio2,
|
completion, prio, meta, synonyms or [completion], prio2,
|
||||||
@ -57,7 +54,7 @@ def Candidate(
|
|||||||
|
|
||||||
|
|
||||||
# Used to strip trailing '::some_type' from default-value expressions
|
# Used to strip trailing '::some_type' from default-value expressions
|
||||||
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$')
|
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
|
||||||
|
|
||||||
|
|
||||||
def normalize_ref(ref):
|
def normalize_ref(ref):
|
||||||
@ -65,15 +62,15 @@ def normalize_ref(ref):
|
|||||||
|
|
||||||
|
|
||||||
def generate_alias(tbl):
|
def generate_alias(tbl):
|
||||||
""" Generate a table alias, consisting of all upper-case letters in
|
"""Generate a table alias, consisting of all upper-case letters in
|
||||||
the table name, or, if there are no upper-case letters, the first letter +
|
the table name, or, if there are no upper-case letters, the first letter +
|
||||||
all letters preceded by _
|
all letters preceded by _
|
||||||
param tbl - unescaped name of the table to alias
|
param tbl - unescaped name of the table to alias
|
||||||
"""
|
"""
|
||||||
return ''.join(
|
return "".join(
|
||||||
[letter for letter in tbl if letter.isupper()] or
|
[letter for letter in tbl if letter.isupper()] or
|
||||||
[letter for letter, prev in zip(tbl, '_' + tbl)
|
[letter for letter, prev in zip(tbl, "_" + tbl)
|
||||||
if prev == '_' and letter != '_']
|
if prev == "_" and letter != "_"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -97,13 +94,14 @@ class SQLAutoComplete(object):
|
|||||||
self.sid = kwargs['sid'] if 'sid' in kwargs else None
|
self.sid = kwargs['sid'] if 'sid' in kwargs else None
|
||||||
self.conn = kwargs['conn'] if 'conn' in kwargs else None
|
self.conn = kwargs['conn'] if 'conn' in kwargs else None
|
||||||
self.keywords = []
|
self.keywords = []
|
||||||
|
self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")
|
||||||
|
|
||||||
self.databases = []
|
self.databases = []
|
||||||
self.functions = []
|
self.functions = []
|
||||||
self.datatypes = []
|
self.datatypes = []
|
||||||
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
|
self.dbmetadata = \
|
||||||
'datatypes': {}}
|
{"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
|
||||||
self.text_before_cursor = None
|
self.text_before_cursor = None
|
||||||
self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
|
|
||||||
|
|
||||||
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)
|
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)
|
||||||
|
|
||||||
@ -182,7 +180,8 @@ class SQLAutoComplete(object):
|
|||||||
def escape_name(self, name):
|
def escape_name(self, name):
|
||||||
if name and (
|
if name and (
|
||||||
(not self.name_pattern.match(name)) or
|
(not self.name_pattern.match(name)) or
|
||||||
(name.upper() in self.reserved_words)
|
(name.upper() in self.reserved_words) or
|
||||||
|
(name.upper() in self.functions)
|
||||||
):
|
):
|
||||||
name = '"%s"' % name
|
name = '"%s"' % name
|
||||||
|
|
||||||
@ -212,7 +211,7 @@ class SQLAutoComplete(object):
|
|||||||
|
|
||||||
# schemata is a list of schema names
|
# schemata is a list of schema names
|
||||||
schemata = self.escaped_names(schemata)
|
schemata = self.escaped_names(schemata)
|
||||||
metadata = self.dbmetadata['tables']
|
metadata = self.dbmetadata["tables"]
|
||||||
for schema in schemata:
|
for schema in schemata:
|
||||||
metadata[schema] = {}
|
metadata[schema] = {}
|
||||||
|
|
||||||
@ -224,7 +223,7 @@ class SQLAutoComplete(object):
|
|||||||
self.all_completions.update(schemata)
|
self.all_completions.update(schemata)
|
||||||
|
|
||||||
def extend_casing(self, words):
|
def extend_casing(self, words):
|
||||||
""" extend casing data
|
"""extend casing data
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
@ -274,7 +273,7 @@ class SQLAutoComplete(object):
|
|||||||
name=colname,
|
name=colname,
|
||||||
datatype=datatype,
|
datatype=datatype,
|
||||||
has_default=has_default,
|
has_default=has_default,
|
||||||
default=default
|
default=default,
|
||||||
)
|
)
|
||||||
metadata[schema][relname][colname] = column
|
metadata[schema][relname][colname] = column
|
||||||
self.all_completions.add(colname)
|
self.all_completions.add(colname)
|
||||||
@ -285,7 +284,7 @@ class SQLAutoComplete(object):
|
|||||||
|
|
||||||
# dbmetadata['schema_name']['functions']['function_name'] should return
|
# dbmetadata['schema_name']['functions']['function_name'] should return
|
||||||
# the function metadata namedtuple for the corresponding function
|
# the function metadata namedtuple for the corresponding function
|
||||||
metadata = self.dbmetadata['functions']
|
metadata = self.dbmetadata["functions"]
|
||||||
|
|
||||||
for f in func_data:
|
for f in func_data:
|
||||||
schema, func = self.escaped_names([f.schema_name, f.func_name])
|
schema, func = self.escaped_names([f.schema_name, f.func_name])
|
||||||
@ -309,10 +308,10 @@ class SQLAutoComplete(object):
|
|||||||
self._arg_list_cache = \
|
self._arg_list_cache = \
|
||||||
dict((usage,
|
dict((usage,
|
||||||
dict((meta, self._arg_list(meta, usage))
|
dict((meta, self._arg_list(meta, usage))
|
||||||
for sch, funcs in self.dbmetadata['functions'].items()
|
for sch, funcs in self.dbmetadata["functions"].items()
|
||||||
for func, metas in funcs.items()
|
for func, metas in funcs.items()
|
||||||
for meta in metas))
|
for meta in metas))
|
||||||
for usage in ('call', 'call_display', 'signature'))
|
for usage in ("call", "call_display", "signature"))
|
||||||
|
|
||||||
def extend_foreignkeys(self, fk_data):
|
def extend_foreignkeys(self, fk_data):
|
||||||
|
|
||||||
@ -322,7 +321,7 @@ class SQLAutoComplete(object):
|
|||||||
|
|
||||||
# These are added as a list of ForeignKey namedtuples to the
|
# These are added as a list of ForeignKey namedtuples to the
|
||||||
# ColumnMetadata namedtuple for both the child and parent
|
# ColumnMetadata namedtuple for both the child and parent
|
||||||
meta = self.dbmetadata['tables']
|
meta = self.dbmetadata["tables"]
|
||||||
|
|
||||||
for fk in fk_data:
|
for fk in fk_data:
|
||||||
e = self.escaped_names
|
e = self.escaped_names
|
||||||
@ -350,7 +349,7 @@ class SQLAutoComplete(object):
|
|||||||
# dbmetadata['datatypes'][schema_name][type_name] should store type
|
# dbmetadata['datatypes'][schema_name][type_name] should store type
|
||||||
# metadata, such as composite type field names. Currently, we're not
|
# metadata, such as composite type field names. Currently, we're not
|
||||||
# storing any metadata beyond typename, so just store None
|
# storing any metadata beyond typename, so just store None
|
||||||
meta = self.dbmetadata['datatypes']
|
meta = self.dbmetadata["datatypes"]
|
||||||
|
|
||||||
for t in type_data:
|
for t in type_data:
|
||||||
schema, type_name = self.escaped_names(t)
|
schema, type_name = self.escaped_names(t)
|
||||||
@ -364,11 +363,11 @@ class SQLAutoComplete(object):
|
|||||||
self.databases = []
|
self.databases = []
|
||||||
self.special_commands = []
|
self.special_commands = []
|
||||||
self.search_path = []
|
self.search_path = []
|
||||||
self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
|
self.dbmetadata = \
|
||||||
'datatypes': {}}
|
{"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
|
||||||
self.all_completions = set(self.keywords + self.functions)
|
self.all_completions = set(self.keywords + self.functions)
|
||||||
|
|
||||||
def find_matches(self, text, collection, mode='fuzzy', meta=None):
|
def find_matches(self, text, collection, mode="strict", meta=None):
|
||||||
"""Find completion matches for the given text.
|
"""Find completion matches for the given text.
|
||||||
|
|
||||||
Given the user's input text and a collection of available
|
Given the user's input text and a collection of available
|
||||||
@ -389,17 +388,26 @@ class SQLAutoComplete(object):
|
|||||||
collection:
|
collection:
|
||||||
mode:
|
mode:
|
||||||
meta:
|
meta:
|
||||||
meta_collection:
|
|
||||||
"""
|
"""
|
||||||
if not collection:
|
if not collection:
|
||||||
return []
|
return []
|
||||||
prio_order = [
|
prio_order = [
|
||||||
'keyword', 'function', 'view', 'table', 'datatype', 'database',
|
"keyword",
|
||||||
'schema', 'column', 'table alias', 'join', 'name join', 'fk join',
|
"function",
|
||||||
'table format'
|
"view",
|
||||||
|
"table",
|
||||||
|
"datatype",
|
||||||
|
"database",
|
||||||
|
"schema",
|
||||||
|
"column",
|
||||||
|
"table alias",
|
||||||
|
"join",
|
||||||
|
"name join",
|
||||||
|
"fk join",
|
||||||
|
"table format",
|
||||||
]
|
]
|
||||||
type_priority = prio_order.index(meta) if meta in prio_order else -1
|
type_priority = prio_order.index(meta) if meta in prio_order else -1
|
||||||
text = last_word(text, include='most_punctuations').lower()
|
text = last_word(text, include="most_punctuations").lower()
|
||||||
text_len = len(text)
|
text_len = len(text)
|
||||||
|
|
||||||
if text and text[0] == '"':
|
if text and text[0] == '"':
|
||||||
@ -409,7 +417,7 @@ class SQLAutoComplete(object):
|
|||||||
# Completion.position value is correct
|
# Completion.position value is correct
|
||||||
text = text[1:]
|
text = text[1:]
|
||||||
|
|
||||||
if mode == 'fuzzy':
|
if mode == "fuzzy":
|
||||||
fuzzy = True
|
fuzzy = True
|
||||||
priority_func = self.prioritizer.name_count
|
priority_func = self.prioritizer.name_count
|
||||||
else:
|
else:
|
||||||
@ -422,19 +430,20 @@ class SQLAutoComplete(object):
|
|||||||
# Note: higher priority values mean more important, so use negative
|
# Note: higher priority values mean more important, so use negative
|
||||||
# signs to flip the direction of the tuple
|
# signs to flip the direction of the tuple
|
||||||
if fuzzy:
|
if fuzzy:
|
||||||
regex = '.*?'.join(map(re.escape, text))
|
regex = ".*?".join(map(re.escape, text))
|
||||||
pat = re.compile('(%s)' % regex)
|
pat = re.compile("(%s)" % regex)
|
||||||
|
|
||||||
def _match(item):
|
def _match(item):
|
||||||
if item.lower()[:len(text) + 1] in (text, text + ' '):
|
if item.lower()[: len(text) + 1] in (text, text + " "):
|
||||||
# Exact match of first word in suggestion
|
# Exact match of first word in suggestion
|
||||||
# This is to get exact alias matches to the top
|
# This is to get exact alias matches to the top
|
||||||
# E.g. for input `e`, 'Entries E' should be on top
|
# E.g. for input `e`, 'Entries E' should be on top
|
||||||
# (before e.g. `EndUsers EU`)
|
# (before e.g. `EndUsers EU`)
|
||||||
return float('Infinity'), -1
|
return float("Infinity"), -1
|
||||||
r = pat.search(self.unescape_name(item.lower()))
|
r = pat.search(self.unescape_name(item.lower()))
|
||||||
if r:
|
if r:
|
||||||
return -len(r.group()), -r.start()
|
return -len(r.group()), -r.start()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
match_end_limit = len(text)
|
match_end_limit = len(text)
|
||||||
|
|
||||||
@ -446,7 +455,7 @@ class SQLAutoComplete(object):
|
|||||||
if match_point >= 0:
|
if match_point >= 0:
|
||||||
# Use negative infinity to force keywords to sort after all
|
# Use negative infinity to force keywords to sort after all
|
||||||
# fuzzy matches
|
# fuzzy matches
|
||||||
return -float('Infinity'), -match_point
|
return -float("Infinity"), -match_point
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
for cand in collection:
|
for cand in collection:
|
||||||
@ -466,7 +475,7 @@ class SQLAutoComplete(object):
|
|||||||
if sort_key:
|
if sort_key:
|
||||||
if display_meta and len(display_meta) > 50:
|
if display_meta and len(display_meta) > 50:
|
||||||
# Truncate meta-text to 50 characters, if necessary
|
# Truncate meta-text to 50 characters, if necessary
|
||||||
display_meta = display_meta[:47] + '...'
|
display_meta = display_meta[:47] + "..."
|
||||||
|
|
||||||
# Lexical order of items in the collection, used for
|
# Lexical order of items in the collection, used for
|
||||||
# tiebreaking items with the same match group length and start
|
# tiebreaking items with the same match group length and start
|
||||||
@ -478,14 +487,18 @@ class SQLAutoComplete(object):
|
|||||||
# We also use the unescape_name to make sure quoted names have
|
# We also use the unescape_name to make sure quoted names have
|
||||||
# the same priority as unquoted names.
|
# the same priority as unquoted names.
|
||||||
lexical_priority = (
|
lexical_priority = (
|
||||||
tuple(0 if c in (' _') else -ord(c)
|
tuple(0 if c in (" _") else -ord(c)
|
||||||
for c in self.unescape_name(item.lower())) + (1,) +
|
for c in self.unescape_name(item.lower())) + (1,) +
|
||||||
tuple(c for c in item)
|
tuple(c for c in item)
|
||||||
)
|
)
|
||||||
|
|
||||||
priority = (
|
priority = (
|
||||||
sort_key, type_priority, prio, priority_func(item),
|
sort_key,
|
||||||
prio2, lexical_priority
|
type_priority,
|
||||||
|
prio,
|
||||||
|
priority_func(item),
|
||||||
|
prio2,
|
||||||
|
lexical_priority,
|
||||||
)
|
)
|
||||||
matches.append(
|
matches.append(
|
||||||
Match(
|
Match(
|
||||||
@ -493,9 +506,9 @@ class SQLAutoComplete(object):
|
|||||||
text=item,
|
text=item,
|
||||||
start_position=-text_len,
|
start_position=-text_len,
|
||||||
display_meta=display_meta,
|
display_meta=display_meta,
|
||||||
display=display
|
display=display,
|
||||||
),
|
),
|
||||||
priority=priority
|
priority=priority,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return matches
|
return matches
|
||||||
@ -516,8 +529,8 @@ class SQLAutoComplete(object):
|
|||||||
matches.extend(matcher(self, suggestion, word_before_cursor))
|
matches.extend(matcher(self, suggestion, word_before_cursor))
|
||||||
|
|
||||||
# Sort matches so highest priorities are first
|
# Sort matches so highest priorities are first
|
||||||
matches = sorted(matches, key=operator.attrgetter('priority'),
|
matches = \
|
||||||
reverse=True)
|
sorted(matches, key=operator.attrgetter("priority"), reverse=True)
|
||||||
|
|
||||||
result = dict()
|
result = dict()
|
||||||
for m in matches:
|
for m in matches:
|
||||||
@ -539,23 +552,28 @@ class SQLAutoComplete(object):
|
|||||||
|
|
||||||
tables = suggestion.table_refs
|
tables = suggestion.table_refs
|
||||||
do_qualify = suggestion.qualifiable and {
|
do_qualify = suggestion.qualifiable and {
|
||||||
'always': True, 'never': False,
|
"always": True,
|
||||||
'if_more_than_one_table': len(tables) > 1}[self.qualify_columns]
|
"never": False,
|
||||||
|
"if_more_than_one_table": len(tables) > 1,
|
||||||
|
}[self.qualify_columns]
|
||||||
|
|
||||||
def qualify(col, tbl):
|
def qualify(col, tbl):
|
||||||
return (tbl + '.' + col) if do_qualify else col
|
return (tbl + '.' + col) if do_qualify else col
|
||||||
|
|
||||||
scoped_cols = self.populate_scoped_cols(
|
scoped_cols = \
|
||||||
tables, suggestion.local_tables
|
self.populate_scoped_cols(tables, suggestion.local_tables)
|
||||||
)
|
|
||||||
|
|
||||||
def make_cand(name, ref):
|
def make_cand(name, ref):
|
||||||
synonyms = (name, generate_alias(name))
|
synonyms = (name, generate_alias(name))
|
||||||
return Candidate(qualify(name, ref), 0, 'column', synonyms)
|
return Candidate(qualify(name, ref), 0, "column", synonyms)
|
||||||
|
|
||||||
def flat_cols():
|
def flat_cols():
|
||||||
return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items()
|
return [
|
||||||
for c in cols]
|
make_cand(c.name, t.ref)
|
||||||
|
for t, cols in scoped_cols.items()
|
||||||
|
for c in cols
|
||||||
|
]
|
||||||
|
|
||||||
if suggestion.require_last_table:
|
if suggestion.require_last_table:
|
||||||
# require_last_table is used for 'tb11 JOIN tbl2 USING
|
# require_last_table is used for 'tb11 JOIN tbl2 USING
|
||||||
# (...' which should
|
# (...' which should
|
||||||
@ -569,10 +587,11 @@ class SQLAutoComplete(object):
|
|||||||
dict((t, [col for col in cols if col.name in other_tbl_cols])
|
dict((t, [col for col in cols if col.name in other_tbl_cols])
|
||||||
for t, cols in scoped_cols.items() if t.ref == ltbl)
|
for t, cols in scoped_cols.items() if t.ref == ltbl)
|
||||||
|
|
||||||
lastword = last_word(word_before_cursor, include='most_punctuations')
|
lastword = last_word(word_before_cursor, include="most_punctuations")
|
||||||
if lastword == '*':
|
if lastword == "*":
|
||||||
if suggestion.context == 'insert':
|
if suggestion.context == "insert":
|
||||||
def is_scoped(col):
|
|
||||||
|
def filter(col):
|
||||||
if not col.has_default:
|
if not col.has_default:
|
||||||
return True
|
return True
|
||||||
return not any(
|
return not any(
|
||||||
@ -580,40 +599,39 @@ class SQLAutoComplete(object):
|
|||||||
for p in self.insert_col_skip_patterns
|
for p in self.insert_col_skip_patterns
|
||||||
)
|
)
|
||||||
scoped_cols = \
|
scoped_cols = \
|
||||||
dict((t, [col for col in cols if is_scoped(col)])
|
dict((t, [col for col in cols if filter(col)])
|
||||||
for t, cols in scoped_cols.items())
|
for t, cols in scoped_cols.items())
|
||||||
if self.asterisk_column_order == 'alphabetic':
|
if self.asterisk_column_order == "alphabetic":
|
||||||
for cols in scoped_cols.values():
|
for cols in scoped_cols.values():
|
||||||
cols.sort(key=operator.attrgetter('name'))
|
cols.sort(key=operator.attrgetter("name"))
|
||||||
if (
|
if (
|
||||||
lastword != word_before_cursor and
|
lastword != word_before_cursor and
|
||||||
len(tables) == 1 and
|
len(tables) == 1 and
|
||||||
word_before_cursor[-len(lastword) - 1] == '.'
|
word_before_cursor[-len(lastword) - 1] == "."
|
||||||
):
|
):
|
||||||
# User typed x.*; replicate "x." for all columns except the
|
# User typed x.*; replicate "x." for all columns except the
|
||||||
# first, which gets the original (as we only replace the "*"")
|
# first, which gets the original (as we only replace the "*"")
|
||||||
sep = ', ' + word_before_cursor[:-1]
|
sep = ", " + word_before_cursor[:-1]
|
||||||
collist = sep.join(c.completion for c in flat_cols())
|
collist = sep.join(c.completion for c in flat_cols())
|
||||||
else:
|
else:
|
||||||
collist = ', '.join(qualify(c.name, t.ref)
|
collist = ", ".join(qualify(c.name, t.ref)
|
||||||
for t, cs in scoped_cols.items()
|
for t, cs in scoped_cols.items()
|
||||||
for c in cs)
|
for c in cs)
|
||||||
|
|
||||||
return [Match(
|
return [
|
||||||
completion=Completion(
|
Match(
|
||||||
collist,
|
completion=Completion(
|
||||||
-1,
|
collist, -1, display_meta="columns", display="*"
|
||||||
display_meta='columns',
|
),
|
||||||
display='*'
|
priority=(1, 1, 1),
|
||||||
),
|
)
|
||||||
priority=(1, 1, 1)
|
]
|
||||||
)]
|
|
||||||
|
|
||||||
return self.find_matches(word_before_cursor, flat_cols(),
|
return self.find_matches(word_before_cursor, flat_cols(),
|
||||||
mode='strict', meta='column')
|
meta="column")
|
||||||
|
|
||||||
def alias(self, tbl, tbls):
|
def alias(self, tbl, tbls):
|
||||||
""" Generate a unique table alias
|
"""Generate a unique table alias
|
||||||
tbl - name of the table to alias, quoted if it needs to be
|
tbl - name of the table to alias, quoted if it needs to be
|
||||||
tbls - TableReference iterable of tables already in query
|
tbls - TableReference iterable of tables already in query
|
||||||
"""
|
"""
|
||||||
@ -628,25 +646,6 @@ class SQLAutoComplete(object):
|
|||||||
aliases = (tbl + str(i) for i in count(2))
|
aliases = (tbl + str(i) for i in count(2))
|
||||||
return next(a for a in aliases if normalize_ref(a) not in tbls)
|
return next(a for a in aliases if normalize_ref(a) not in tbls)
|
||||||
|
|
||||||
def _check_for_aliases(self, left, refs, rtbl, suggestion, right):
|
|
||||||
"""
|
|
||||||
Check for generate aliases and return join value
|
|
||||||
:param left:
|
|
||||||
:param refs:
|
|
||||||
:param rtbl:
|
|
||||||
:param suggestion:
|
|
||||||
:param right:
|
|
||||||
:return: return join string.
|
|
||||||
"""
|
|
||||||
if self.generate_aliases or normalize_ref(left.tbl) in refs:
|
|
||||||
lref = self.alias(left.tbl, suggestion.table_refs)
|
|
||||||
join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
|
|
||||||
left.tbl, left.col, rtbl.ref, right.col, lref)
|
|
||||||
else:
|
|
||||||
join = '{0} ON {0}.{1} = {2}.{3}'.format(
|
|
||||||
left.tbl, left.col, rtbl.ref, right.col)
|
|
||||||
return join
|
|
||||||
|
|
||||||
def get_join_matches(self, suggestion, word_before_cursor):
|
def get_join_matches(self, suggestion, word_before_cursor):
|
||||||
tbls = suggestion.table_refs
|
tbls = suggestion.table_refs
|
||||||
cols = self.populate_scoped_cols(tbls)
|
cols = self.populate_scoped_cols(tbls)
|
||||||
@ -658,10 +657,12 @@ class SQLAutoComplete(object):
|
|||||||
joins = []
|
joins = []
|
||||||
# Iterate over FKs in existing tables to find potential joins
|
# Iterate over FKs in existing tables to find potential joins
|
||||||
fks = (
|
fks = (
|
||||||
(fk, rtbl, rcol) for rtbl, rcols in cols.items()
|
(fk, rtbl, rcol)
|
||||||
for rcol in rcols for fk in rcol.foreignkeys
|
for rtbl, rcols in cols.items()
|
||||||
|
for rcol in rcols
|
||||||
|
for fk in rcol.foreignkeys
|
||||||
)
|
)
|
||||||
col = namedtuple('col', 'schema tbl col')
|
col = namedtuple("col", "schema tbl col")
|
||||||
for fk, rtbl, rcol in fks:
|
for fk, rtbl, rcol in fks:
|
||||||
right = col(rtbl.schema, rtbl.name, rcol.name)
|
right = col(rtbl.schema, rtbl.name, rcol.name)
|
||||||
child = col(fk.childschema, fk.childtable, fk.childcolumn)
|
child = col(fk.childschema, fk.childtable, fk.childcolumn)
|
||||||
@ -670,54 +671,38 @@ class SQLAutoComplete(object):
|
|||||||
if suggestion.schema and left.schema != suggestion.schema:
|
if suggestion.schema and left.schema != suggestion.schema:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
join = self._check_for_aliases(left, refs, rtbl, suggestion, right)
|
if self.generate_aliases or normalize_ref(left.tbl) in refs:
|
||||||
|
lref = self.alias(left.tbl, suggestion.table_refs)
|
||||||
|
join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
|
||||||
|
left.tbl, left.col, rtbl.ref, right.col, lref
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
join = "{0} ON {0}.{1} = {2}.{3}".format(
|
||||||
|
left.tbl, left.col, rtbl.ref, right.col
|
||||||
|
)
|
||||||
alias = generate_alias(left.tbl)
|
alias = generate_alias(left.tbl)
|
||||||
synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format(
|
synonyms = [
|
||||||
alias, left.col, rtbl.ref, right.col)]
|
join,
|
||||||
|
"{0} ON {0}.{1} = {2}.{3}".format(
|
||||||
|
alias, left.col, rtbl.ref, right.col
|
||||||
|
),
|
||||||
|
]
|
||||||
# Schema-qualify if (1) new table in same schema as old, and old
|
# Schema-qualify if (1) new table in same schema as old, and old
|
||||||
# is schema-qualified, or (2) new in other schema, except public
|
# is schema-qualified, or (2) new in other schema, except public
|
||||||
if not suggestion.schema and \
|
if not suggestion.schema and \
|
||||||
(qualified[normalize_ref(rtbl.ref)] and
|
(qualified[normalize_ref(rtbl.ref)] and
|
||||||
left.schema == right.schema or
|
left.schema == right.schema or
|
||||||
left.schema not in (right.schema, 'public')):
|
left.schema not in (right.schema, "public")):
|
||||||
join = left.schema + '.' + join
|
join = left.schema + "." + join
|
||||||
prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
|
prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
|
||||||
0 if (left.schema, left.tbl) in other_tbls else 1)
|
0 if (left.schema, left.tbl) in other_tbls else 1
|
||||||
joins.append(Candidate(join, prio, 'join', synonyms=synonyms))
|
)
|
||||||
|
joins.append(Candidate(join, prio, "join", synonyms=synonyms))
|
||||||
|
|
||||||
return self.find_matches(word_before_cursor, joins,
|
return self.find_matches(word_before_cursor, joins, meta="join")
|
||||||
mode='strict', meta='join')
|
|
||||||
|
|
||||||
def list_dict(self, pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
|
|
||||||
d = defaultdict(list)
|
|
||||||
for pair in pairs:
|
|
||||||
d[pair[0]].append(pair[1])
|
|
||||||
return d
|
|
||||||
|
|
||||||
def add_cond(self, lcol, rcol, rref, prio, meta, **kwargs):
|
|
||||||
"""
|
|
||||||
Add Condition in join
|
|
||||||
:param lcol:
|
|
||||||
:param rcol:
|
|
||||||
:param rref:
|
|
||||||
:param prio:
|
|
||||||
:param meta:
|
|
||||||
:param kwargs:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
suggestion = kwargs['suggestion']
|
|
||||||
found_conds = kwargs['found_conds']
|
|
||||||
ltbl = kwargs['ltbl']
|
|
||||||
ref_prio = kwargs['ref_prio']
|
|
||||||
conds = kwargs['conds']
|
|
||||||
prefix = '' if suggestion.parent else ltbl.ref + '.'
|
|
||||||
cond = prefix + lcol + ' = ' + rref + '.' + rcol
|
|
||||||
if cond not in found_conds:
|
|
||||||
found_conds.add(cond)
|
|
||||||
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
|
|
||||||
|
|
||||||
def get_join_condition_matches(self, suggestion, word_before_cursor):
|
def get_join_condition_matches(self, suggestion, word_before_cursor):
|
||||||
col = namedtuple('col', 'schema tbl col')
|
col = namedtuple("col", "schema tbl col")
|
||||||
tbls = self.populate_scoped_cols(suggestion.table_refs).items
|
tbls = self.populate_scoped_cols(suggestion.table_refs).items
|
||||||
cols = [(t, c) for t, cs in tbls() for c in cs]
|
cols = [(t, c) for t, cs in tbls() for c in cs]
|
||||||
try:
|
try:
|
||||||
@ -727,11 +712,24 @@ class SQLAutoComplete(object):
|
|||||||
return []
|
return []
|
||||||
conds, found_conds = [], set()
|
conds, found_conds = [], set()
|
||||||
|
|
||||||
|
def add_cond(lcol, rcol, rref, prio, meta):
|
||||||
|
prefix = "" if suggestion.parent else ltbl.ref + "."
|
||||||
|
cond = prefix + lcol + " = " + rref + "." + rcol
|
||||||
|
if cond not in found_conds:
|
||||||
|
found_conds.add(cond)
|
||||||
|
conds.append(Candidate(cond, prio + ref_prio[rref], meta))
|
||||||
|
|
||||||
|
def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
|
||||||
|
d = defaultdict(list)
|
||||||
|
for pair in pairs:
|
||||||
|
d[pair[0]].append(pair[1])
|
||||||
|
return d
|
||||||
|
|
||||||
# Tables that are closer to the cursor get higher prio
|
# Tables that are closer to the cursor get higher prio
|
||||||
ref_prio = dict((tbl.ref, num)
|
ref_prio = dict((tbl.ref, num)
|
||||||
for num, tbl in enumerate(suggestion.table_refs))
|
for num, tbl in enumerate(suggestion.table_refs))
|
||||||
# Map (schema, table, col) to tables
|
# Map (schema, table, col) to tables
|
||||||
coldict = self.list_dict(
|
coldict = list_dict(
|
||||||
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
|
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
|
||||||
)
|
)
|
||||||
# For each fk from the left table, generate a join condition if
|
# For each fk from the left table, generate a join condition if
|
||||||
@ -742,89 +740,76 @@ class SQLAutoComplete(object):
|
|||||||
child = col(fk.childschema, fk.childtable, fk.childcolumn)
|
child = col(fk.childschema, fk.childtable, fk.childcolumn)
|
||||||
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
|
par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
|
||||||
left, right = (child, par) if left == child else (par, child)
|
left, right = (child, par) if left == child else (par, child)
|
||||||
|
|
||||||
for rtbl in coldict[right]:
|
for rtbl in coldict[right]:
|
||||||
kwargs = {
|
add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
|
||||||
"suggestion": suggestion,
|
|
||||||
"found_conds": found_conds,
|
|
||||||
"ltbl": ltbl,
|
|
||||||
"conds": conds,
|
|
||||||
"ref_prio": ref_prio
|
|
||||||
}
|
|
||||||
self.add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join',
|
|
||||||
**kwargs)
|
|
||||||
# For name matching, use a {(colname, coltype): TableReference} dict
|
# For name matching, use a {(colname, coltype): TableReference} dict
|
||||||
coltyp = namedtuple('coltyp', 'name datatype')
|
coltyp = namedtuple("coltyp", "name datatype")
|
||||||
col_table = self.list_dict(
|
col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
|
||||||
(coltyp(c.name, c.datatype), t) for t, c in cols)
|
|
||||||
# Find all name-match join conditions
|
# Find all name-match join conditions
|
||||||
for c in (coltyp(c.name, c.datatype) for c in lcols):
|
for c in (coltyp(c.name, c.datatype) for c in lcols):
|
||||||
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
|
for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
|
||||||
kwargs = {
|
|
||||||
"suggestion": suggestion,
|
|
||||||
"found_conds": found_conds,
|
|
||||||
"ltbl": ltbl,
|
|
||||||
"conds": conds,
|
|
||||||
"ref_prio": ref_prio
|
|
||||||
}
|
|
||||||
prio = 1000 if c.datatype in (
|
prio = 1000 if c.datatype in (
|
||||||
'integer', 'bigint', 'smallint') else 0
|
"integer", "bigint", "smallint") else 0
|
||||||
self.add_cond(c.name, c.name, rtbl.ref, prio, 'name join',
|
add_cond(c.name, c.name, rtbl.ref, prio, "name join")
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
return self.find_matches(word_before_cursor, conds,
|
return self.find_matches(word_before_cursor, conds, meta="join")
|
||||||
mode='strict', meta='join')
|
|
||||||
|
|
||||||
def get_function_matches(self, suggestion, word_before_cursor,
|
def get_function_matches(self, suggestion, word_before_cursor,
|
||||||
alias=False):
|
alias=False):
|
||||||
if suggestion.usage == 'from':
|
if suggestion.usage == "from":
|
||||||
# Only suggest functions allowed in FROM clause
|
# Only suggest functions allowed in FROM clause
|
||||||
|
|
||||||
def filt(f):
|
def filt(f):
|
||||||
return not f.is_aggregate and not f.is_window
|
return (
|
||||||
|
not f.is_aggregate and not f.is_window and
|
||||||
|
not f.is_extension and
|
||||||
|
(f.is_public or f.schema_name == suggestion.schema)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
alias = False
|
alias = False
|
||||||
|
|
||||||
def filt(f):
|
def filt(f):
|
||||||
return True
|
return not f.is_extension and (
|
||||||
|
f.is_public or f.schema_name == suggestion.schema
|
||||||
|
)
|
||||||
|
|
||||||
arg_mode = {
|
arg_mode = {"signature": "signature", "special": None}.get(
|
||||||
'signature': 'signature',
|
suggestion.usage, "call"
|
||||||
'special': None,
|
|
||||||
}.get(suggestion.usage, 'call')
|
|
||||||
# Function overloading means we way have multiple functions of the same
|
|
||||||
# name at this point, so keep unique names only
|
|
||||||
funcs = set(
|
|
||||||
self._make_cand(f, alias, suggestion, arg_mode)
|
|
||||||
for f in self.populate_functions(suggestion.schema, filt)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
matches = self.find_matches(word_before_cursor, funcs,
|
# Function overloading means we way have multiple functions of the same
|
||||||
mode='strict', meta='function')
|
# name at this point, so keep unique names only
|
||||||
|
all_functions = self.populate_functions(suggestion.schema, filt)
|
||||||
|
funcs = set(
|
||||||
|
self._make_cand(f, alias, suggestion, arg_mode)
|
||||||
|
for f in all_functions
|
||||||
|
)
|
||||||
|
|
||||||
|
matches = self.find_matches(word_before_cursor, funcs, meta="function")
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def get_schema_matches(self, suggestion, word_before_cursor):
|
def get_schema_matches(self, suggestion, word_before_cursor):
|
||||||
schema_names = self.dbmetadata['tables'].keys()
|
schema_names = self.dbmetadata["tables"].keys()
|
||||||
|
|
||||||
# Unless we're sure the user really wants them, hide schema names
|
# Unless we're sure the user really wants them, hide schema names
|
||||||
# starting with pg_, which are mostly temporary schemas
|
# starting with pg_, which are mostly temporary schemas
|
||||||
if not word_before_cursor.startswith('pg_'):
|
if not word_before_cursor.startswith("pg_"):
|
||||||
schema_names = [s
|
schema_names = [s for s in schema_names if not s.startswith("pg_")]
|
||||||
for s in schema_names
|
|
||||||
if not s.startswith('pg_')]
|
|
||||||
|
|
||||||
if suggestion.quoted:
|
if suggestion.quoted:
|
||||||
schema_names = [self.escape_schema(s) for s in schema_names]
|
schema_names = [self.escape_schema(s) for s in schema_names]
|
||||||
|
|
||||||
return self.find_matches(word_before_cursor, schema_names,
|
return self.find_matches(word_before_cursor, schema_names,
|
||||||
mode='strict', meta='schema')
|
meta="schema")
|
||||||
|
|
||||||
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
|
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
|
||||||
alias = self.generate_aliases
|
alias = self.generate_aliases
|
||||||
s = suggestion
|
s = suggestion
|
||||||
t_sug = Table(s.schema, s.table_refs, s.local_tables)
|
t_sug = Table(s.schema, s.table_refs, s.local_tables)
|
||||||
v_sug = View(s.schema, s.table_refs)
|
v_sug = View(s.schema, s.table_refs)
|
||||||
f_sug = Function(s.schema, s.table_refs, usage='from')
|
f_sug = Function(s.schema, s.table_refs, usage="from")
|
||||||
return (
|
return (
|
||||||
self.get_table_matches(t_sug, word_before_cursor, alias) +
|
self.get_table_matches(t_sug, word_before_cursor, alias) +
|
||||||
self.get_view_matches(v_sug, word_before_cursor, alias) +
|
self.get_view_matches(v_sug, word_before_cursor, alias) +
|
||||||
@ -839,42 +824,43 @@ class SQLAutoComplete(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
template = {
|
template = {
|
||||||
'call': self.call_arg_style,
|
"call": self.call_arg_style,
|
||||||
'call_display': self.call_arg_display_style,
|
"call_display": self.call_arg_display_style,
|
||||||
'signature': self.signature_arg_style
|
"signature": self.signature_arg_style,
|
||||||
}[usage]
|
}[usage]
|
||||||
args = func.args()
|
args = func.args()
|
||||||
if not template or (
|
if not template:
|
||||||
usage == 'call' and (
|
return "()"
|
||||||
len(args) < 2 or func.has_variadic())):
|
elif usage == "call" and len(args) < 2:
|
||||||
return '()'
|
return "()"
|
||||||
|
elif usage == "call" and func.has_variadic():
|
||||||
multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max
|
return "()"
|
||||||
|
multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
|
||||||
max_arg_len = max(len(a.name) for a in args) if multiline else 0
|
max_arg_len = max(len(a.name) for a in args) if multiline else 0
|
||||||
args = (
|
args = (
|
||||||
self._format_arg(template, arg, arg_num + 1, max_arg_len)
|
self._format_arg(template, arg, arg_num + 1, max_arg_len)
|
||||||
for arg_num, arg in enumerate(args)
|
for arg_num, arg in enumerate(args)
|
||||||
)
|
)
|
||||||
if multiline:
|
if multiline:
|
||||||
return '(' + ','.join('\n ' + a for a in args if a) + '\n)'
|
return "(" + ",".join("\n " + a for a in args if a) + "\n)"
|
||||||
else:
|
else:
|
||||||
return '(' + ', '.join(a for a in args if a) + ')'
|
return "(" + ", ".join(a for a in args if a) + ")"
|
||||||
|
|
||||||
def _format_arg(self, template, arg, arg_num, max_arg_len):
|
def _format_arg(self, template, arg, arg_num, max_arg_len):
|
||||||
if not template:
|
if not template:
|
||||||
return None
|
return None
|
||||||
if arg.has_default:
|
if arg.has_default:
|
||||||
arg_default = 'NULL' if arg.default is None else arg.default
|
arg_default = "NULL" if arg.default is None else arg.default
|
||||||
# Remove trailing ::(schema.)type
|
# Remove trailing ::(schema.)type
|
||||||
arg_default = arg_default_type_strip_regex.sub('', arg_default)
|
arg_default = arg_default_type_strip_regex.sub("", arg_default)
|
||||||
else:
|
else:
|
||||||
arg_default = ''
|
arg_default = ""
|
||||||
return template.format(
|
return template.format(
|
||||||
max_arg_len=max_arg_len,
|
max_arg_len=max_arg_len,
|
||||||
arg_name=arg.name,
|
arg_name=arg.name,
|
||||||
arg_num=arg_num,
|
arg_num=arg_num,
|
||||||
arg_type=arg.datatype,
|
arg_type=arg.datatype,
|
||||||
arg_default=arg_default
|
arg_default=arg_default,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
|
def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
|
||||||
@ -890,63 +876,60 @@ class SQLAutoComplete(object):
|
|||||||
if do_alias:
|
if do_alias:
|
||||||
alias = self.alias(cased_tbl, suggestion.table_refs)
|
alias = self.alias(cased_tbl, suggestion.table_refs)
|
||||||
synonyms = (cased_tbl, generate_alias(cased_tbl))
|
synonyms = (cased_tbl, generate_alias(cased_tbl))
|
||||||
maybe_alias = (' ' + alias) if do_alias else ''
|
maybe_alias = (" " + alias) if do_alias else ""
|
||||||
maybe_schema = (tbl.schema + '.') if tbl.schema else ''
|
maybe_schema = (tbl.schema + ".") if tbl.schema else ""
|
||||||
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ''
|
suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
|
||||||
if arg_mode == 'call':
|
if arg_mode == "call":
|
||||||
display_suffix = self._arg_list_cache['call_display'][tbl.meta]
|
display_suffix = self._arg_list_cache["call_display"][tbl.meta]
|
||||||
elif arg_mode == 'signature':
|
elif arg_mode == "signature":
|
||||||
display_suffix = self._arg_list_cache['signature'][tbl.meta]
|
display_suffix = self._arg_list_cache["signature"][tbl.meta]
|
||||||
else:
|
else:
|
||||||
display_suffix = ''
|
display_suffix = ""
|
||||||
item = maybe_schema + cased_tbl + suffix + maybe_alias
|
item = maybe_schema + cased_tbl + suffix + maybe_alias
|
||||||
display = maybe_schema + cased_tbl + display_suffix + maybe_alias
|
display = maybe_schema + cased_tbl + display_suffix + maybe_alias
|
||||||
prio2 = 0 if tbl.schema else 1
|
prio2 = 0 if tbl.schema else 1
|
||||||
return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
|
return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
|
||||||
|
|
||||||
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
|
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
|
||||||
tables = self.populate_schema_objects(suggestion.schema, 'tables')
|
tables = self.populate_schema_objects(suggestion.schema, "tables")
|
||||||
tables.extend(
|
tables.extend(
|
||||||
SchemaObject(tbl.name) for tbl in suggestion.local_tables)
|
SchemaObject(tbl.name) for tbl in suggestion.local_tables)
|
||||||
|
|
||||||
# Unless we're sure the user really wants them, don't suggest the
|
# Unless we're sure the user really wants them, don't suggest the
|
||||||
# pg_catalog tables that are implicitly on the search path
|
# pg_catalog tables that are implicitly on the search path
|
||||||
if not suggestion.schema and (
|
if not suggestion.schema and \
|
||||||
not word_before_cursor.startswith('pg_')):
|
(not word_before_cursor.startswith("pg_")):
|
||||||
tables = [t for t in tables if not t.name.startswith('pg_')]
|
tables = [t for t in tables if not t.name.startswith("pg_")]
|
||||||
tables = [self._make_cand(t, alias, suggestion) for t in tables]
|
tables = [self._make_cand(t, alias, suggestion) for t in tables]
|
||||||
return self.find_matches(word_before_cursor, tables,
|
return self.find_matches(word_before_cursor, tables, meta="table")
|
||||||
mode='strict', meta='table')
|
|
||||||
|
|
||||||
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
|
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
|
||||||
views = self.populate_schema_objects(suggestion.schema, 'views')
|
views = self.populate_schema_objects(suggestion.schema, "views")
|
||||||
|
|
||||||
if not suggestion.schema and (
|
if not suggestion.schema and (
|
||||||
not word_before_cursor.startswith('pg_')):
|
not word_before_cursor.startswith("pg_")):
|
||||||
views = [v for v in views if not v.name.startswith('pg_')]
|
views = [v for v in views if not v.name.startswith("pg_")]
|
||||||
views = [self._make_cand(v, alias, suggestion) for v in views]
|
views = [self._make_cand(v, alias, suggestion) for v in views]
|
||||||
return self.find_matches(word_before_cursor, views,
|
return self.find_matches(word_before_cursor, views, meta="view")
|
||||||
mode='strict', meta='view')
|
|
||||||
|
|
||||||
def get_alias_matches(self, suggestion, word_before_cursor):
|
def get_alias_matches(self, suggestion, word_before_cursor):
|
||||||
aliases = suggestion.aliases
|
aliases = suggestion.aliases
|
||||||
return self.find_matches(word_before_cursor, aliases,
|
return self.find_matches(word_before_cursor, aliases,
|
||||||
mode='strict', meta='table alias')
|
meta="table alias")
|
||||||
|
|
||||||
def get_database_matches(self, _, word_before_cursor):
|
def get_database_matches(self, _, word_before_cursor):
|
||||||
return self.find_matches(word_before_cursor, self.databases,
|
return self.find_matches(word_before_cursor, self.databases,
|
||||||
mode='strict', meta='database')
|
meta="database")
|
||||||
|
|
||||||
def get_keyword_matches(self, suggestion, word_before_cursor):
|
def get_keyword_matches(self, suggestion, word_before_cursor):
|
||||||
return self.find_matches(word_before_cursor, self.keywords,
|
return self.find_matches(word_before_cursor, self.keywords,
|
||||||
mode='strict', meta='keyword')
|
meta="keyword")
|
||||||
|
|
||||||
def get_datatype_matches(self, suggestion, word_before_cursor):
|
def get_datatype_matches(self, suggestion, word_before_cursor):
|
||||||
# suggest custom datatypes
|
# suggest custom datatypes
|
||||||
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
|
types = self.populate_schema_objects(suggestion.schema, "datatypes")
|
||||||
types = [self._make_cand(t, False, suggestion) for t in types]
|
types = [self._make_cand(t, False, suggestion) for t in types]
|
||||||
matches = self.find_matches(word_before_cursor, types,
|
matches = self.find_matches(word_before_cursor, types, meta="datatype")
|
||||||
mode='strict', meta='datatype')
|
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def get_word_before_cursor(self, word=False):
|
def get_word_before_cursor(self, word=False):
|
||||||
@ -1004,52 +987,6 @@ class SQLAutoComplete(object):
|
|||||||
Datatype: get_datatype_matches,
|
Datatype: get_datatype_matches,
|
||||||
}
|
}
|
||||||
|
|
||||||
def addcols(self, schema, rel, alias, reltype, cols, columns):
|
|
||||||
"""
|
|
||||||
Add columns in schema column list.
|
|
||||||
:param schema: Schema for reference.
|
|
||||||
:param rel:
|
|
||||||
:param alias:
|
|
||||||
:param reltype:
|
|
||||||
:param cols:
|
|
||||||
:param columns:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
tbl = TableReference(schema, rel, alias, reltype == 'functions')
|
|
||||||
if tbl not in columns:
|
|
||||||
columns[tbl] = []
|
|
||||||
columns[tbl].extend(cols)
|
|
||||||
|
|
||||||
def _get_schema_columns(self, schemas, tbl, meta, columns):
|
|
||||||
"""
|
|
||||||
Check and add schema table columns as per table.
|
|
||||||
:param schemas: Schema
|
|
||||||
:param tbl:
|
|
||||||
:param meta:
|
|
||||||
:param columns: column list
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for schema in schemas:
|
|
||||||
relname = self.escape_name(tbl.name)
|
|
||||||
schema = self.escape_name(schema)
|
|
||||||
if tbl.is_function:
|
|
||||||
# Return column names from a set-returning function
|
|
||||||
# Get an array of FunctionMetadata objects
|
|
||||||
functions = meta['functions'].get(schema, {}).get(relname)
|
|
||||||
for func in (functions or []):
|
|
||||||
# func is a FunctionMetadata object
|
|
||||||
cols = func.fields()
|
|
||||||
self.addcols(schema, relname, tbl.alias, 'functions', cols,
|
|
||||||
columns)
|
|
||||||
else:
|
|
||||||
for reltype in ('tables', 'views'):
|
|
||||||
cols = meta[reltype].get(schema, {}).get(relname)
|
|
||||||
if cols:
|
|
||||||
cols = cols.values()
|
|
||||||
self.addcols(schema, relname, tbl.alias, reltype, cols,
|
|
||||||
columns)
|
|
||||||
break
|
|
||||||
|
|
||||||
def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
|
def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
|
||||||
"""Find all columns in a set of scoped_tables.
|
"""Find all columns in a set of scoped_tables.
|
||||||
|
|
||||||
@ -1062,14 +999,37 @@ class SQLAutoComplete(object):
|
|||||||
columns = OrderedDict()
|
columns = OrderedDict()
|
||||||
meta = self.dbmetadata
|
meta = self.dbmetadata
|
||||||
|
|
||||||
|
def addcols(schema, rel, alias, reltype, cols):
|
||||||
|
tbl = TableReference(schema, rel, alias, reltype == "functions")
|
||||||
|
if tbl not in columns:
|
||||||
|
columns[tbl] = []
|
||||||
|
columns[tbl].extend(cols)
|
||||||
|
|
||||||
for tbl in scoped_tbls:
|
for tbl in scoped_tbls:
|
||||||
# Local tables should shadow database tables
|
# Local tables should shadow database tables
|
||||||
if tbl.schema is None and normalize_ref(tbl.name) in ctes:
|
if tbl.schema is None and normalize_ref(tbl.name) in ctes:
|
||||||
cols = ctes[normalize_ref(tbl.name)]
|
cols = ctes[normalize_ref(tbl.name)]
|
||||||
self.addcols(None, tbl.name, 'CTE', tbl.alias, cols, columns)
|
addcols(None, tbl.name, "CTE", tbl.alias, cols)
|
||||||
continue
|
continue
|
||||||
schemas = [tbl.schema] if tbl.schema else self.search_path
|
schemas = [tbl.schema] if tbl.schema else self.search_path
|
||||||
self._get_schema_columns(schemas, tbl, meta, columns)
|
for schema in schemas:
|
||||||
|
relname = self.escape_name(tbl.name)
|
||||||
|
schema = self.escape_name(schema)
|
||||||
|
if tbl.is_function:
|
||||||
|
# Return column names from a set-returning function
|
||||||
|
# Get an array of FunctionMetadata objects
|
||||||
|
functions = meta["functions"].get(schema, {}).get(relname)
|
||||||
|
for func in functions or []:
|
||||||
|
# func is a FunctionMetadata object
|
||||||
|
cols = func.fields()
|
||||||
|
addcols(schema, relname, tbl.alias, "functions", cols)
|
||||||
|
else:
|
||||||
|
for reltype in ("tables", "views"):
|
||||||
|
cols = meta[reltype].get(schema, {}).get(relname)
|
||||||
|
if cols:
|
||||||
|
cols = cols.values()
|
||||||
|
addcols(schema, relname, tbl.alias, reltype, cols)
|
||||||
|
break
|
||||||
|
|
||||||
return columns
|
return columns
|
||||||
|
|
||||||
@ -1125,10 +1085,10 @@ class SQLAutoComplete(object):
|
|||||||
SchemaObject(
|
SchemaObject(
|
||||||
name=func,
|
name=func,
|
||||||
schema=(self._maybe_schema(schema=sch, parent=schema)),
|
schema=(self._maybe_schema(schema=sch, parent=schema)),
|
||||||
meta=meta
|
meta=meta,
|
||||||
)
|
)
|
||||||
for sch in self._get_schemas('functions', schema)
|
for sch in self._get_schemas("functions", schema)
|
||||||
for (func, metas) in self.dbmetadata['functions'][sch].items()
|
for (func, metas) in self.dbmetadata["functions"][sch].items()
|
||||||
for meta in metas
|
for meta in metas
|
||||||
if filter_func(meta)
|
if filter_func(meta)
|
||||||
]
|
]
|
||||||
@ -1234,6 +1194,7 @@ class SQLAutoComplete(object):
|
|||||||
row['is_aggregate'],
|
row['is_aggregate'],
|
||||||
row['is_window'],
|
row['is_window'],
|
||||||
row['is_set_returning'],
|
row['is_set_returning'],
|
||||||
|
row['is_extension'],
|
||||||
row['arg_defaults'].strip('{}').split(',')
|
row['arg_defaults'].strip('{}').split(',')
|
||||||
if row['arg_defaults'] is not None
|
if row['arg_defaults'] is not None
|
||||||
else row['arg_defaults']
|
else row['arg_defaults']
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Using Completion class from
|
Using Completion class from
|
||||||
https://github.com/jonathanslenders/python-prompt-toolkit/
|
https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/completion/base.py
|
||||||
blob/master/prompt_toolkit/completion.py
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
@ -38,7 +37,7 @@ class Completion(object):
|
|||||||
assert self.start_position <= 0
|
assert self.start_position <= 0
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '%s(text=%r, start_position=%r)' % (
|
return "%s(text=%r, start_position=%r)" % (
|
||||||
self.__class__.__name__, self.text, self.start_position)
|
self.__class__.__name__, self.text, self.start_position)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
@ -1,295 +0,0 @@
|
|||||||
import re
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
import sqlparse
|
|
||||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
|
||||||
from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error
|
|
||||||
|
|
||||||
cleanup_regex = {
|
|
||||||
# This matches only alphanumerics and underscores.
|
|
||||||
'alphanum_underscore': re.compile(r'(\w+)$'),
|
|
||||||
# This matches everything except spaces, parens, colon, and comma
|
|
||||||
'many_punctuations': re.compile(r'([^():,\s]+)$'),
|
|
||||||
# This matches everything except spaces, parens, colon, comma, and period
|
|
||||||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
|
||||||
# This matches everything except a space.
|
|
||||||
'all_punctuations': re.compile('([^\s]+)$'),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def last_word(text, include='alphanum_underscore'):
|
|
||||||
"""
|
|
||||||
Find the last word in a sentence.
|
|
||||||
|
|
||||||
>>> last_word('abc')
|
|
||||||
'abc'
|
|
||||||
>>> last_word(' abc')
|
|
||||||
'abc'
|
|
||||||
>>> last_word('')
|
|
||||||
''
|
|
||||||
>>> last_word(' ')
|
|
||||||
''
|
|
||||||
>>> last_word('abc ')
|
|
||||||
''
|
|
||||||
>>> last_word('abc def')
|
|
||||||
'def'
|
|
||||||
>>> last_word('abc def ')
|
|
||||||
''
|
|
||||||
>>> last_word('abc def;')
|
|
||||||
''
|
|
||||||
>>> last_word('bac $def')
|
|
||||||
'def'
|
|
||||||
>>> last_word('bac $def', include='most_punctuations')
|
|
||||||
'$def'
|
|
||||||
>>> last_word('bac \def', include='most_punctuations')
|
|
||||||
'\\\\def'
|
|
||||||
>>> last_word('bac \def;', include='most_punctuations')
|
|
||||||
'\\\\def;'
|
|
||||||
>>> last_word('bac::def', include='most_punctuations')
|
|
||||||
'def'
|
|
||||||
>>> last_word('"foo*bar', include='most_punctuations')
|
|
||||||
'"foo*bar'
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not text: # Empty string
|
|
||||||
return ''
|
|
||||||
|
|
||||||
if text[-1].isspace():
|
|
||||||
return ''
|
|
||||||
else:
|
|
||||||
regex = cleanup_regex[include]
|
|
||||||
matches = regex.search(text)
|
|
||||||
if matches:
|
|
||||||
return matches.group(0)
|
|
||||||
else:
|
|
||||||
return ''
|
|
||||||
|
|
||||||
|
|
||||||
TableReference = namedtuple(
|
|
||||||
'TableReference', ['schema', 'name', 'alias', 'is_function']
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# This code is borrowed from sqlparse example script.
|
|
||||||
# <url>
|
|
||||||
def is_subselect(parsed):
|
|
||||||
if not parsed.is_group():
|
|
||||||
return False
|
|
||||||
sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
|
|
||||||
for item in parsed.tokens:
|
|
||||||
if item.ttype is DML and item.value.upper() in sql_type:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _identifier_is_function(identifier):
|
|
||||||
return any(isinstance(t, Function) for t in identifier.tokens)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_from_part(parsed, stop_at_punctuation=True):
|
|
||||||
tbl_prefix_seen = False
|
|
||||||
for item in parsed.tokens:
|
|
||||||
if tbl_prefix_seen:
|
|
||||||
if is_subselect(item):
|
|
||||||
for x in extract_from_part(item, stop_at_punctuation):
|
|
||||||
yield x
|
|
||||||
elif stop_at_punctuation and item.ttype is Punctuation:
|
|
||||||
raise StopIteration
|
|
||||||
# An incomplete nested select won't be recognized correctly as a
|
|
||||||
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This
|
|
||||||
# causes the second FROM to trigger this elif condition resulting
|
|
||||||
# in a StopIteration. So we need to ignore the keyword if the
|
|
||||||
# keyword FROM.
|
|
||||||
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
|
|
||||||
# condition. So we need to ignore the keyword JOIN and its
|
|
||||||
# variants INNER JOIN, FULL OUTER JOIN, etc.
|
|
||||||
elif item.ttype is Keyword and (
|
|
||||||
not item.value.upper() == 'FROM') and (
|
|
||||||
not item.value.upper().endswith('JOIN')):
|
|
||||||
tbl_prefix_seen = False
|
|
||||||
else:
|
|
||||||
yield item
|
|
||||||
elif item.ttype is Keyword or item.ttype is Keyword.DML:
|
|
||||||
item_val = item.value.upper()
|
|
||||||
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
|
|
||||||
item_val.endswith('JOIN')):
|
|
||||||
tbl_prefix_seen = True
|
|
||||||
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
|
||||||
# So this check here is necessary.
|
|
||||||
elif isinstance(item, IdentifierList):
|
|
||||||
for identifier in item.get_identifiers():
|
|
||||||
if (identifier.ttype is Keyword and
|
|
||||||
identifier.value.upper() == 'FROM'):
|
|
||||||
tbl_prefix_seen = True
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def extract_table_identifiers(token_stream, allow_functions=True):
|
|
||||||
"""yields tuples of TableReference namedtuples"""
|
|
||||||
|
|
||||||
for item in token_stream:
|
|
||||||
if isinstance(item, IdentifierList):
|
|
||||||
for identifier in item.get_identifiers():
|
|
||||||
# Sometimes Keywords (such as FROM ) are classified as
|
|
||||||
# identifiers which don't have the get_real_name() method.
|
|
||||||
try:
|
|
||||||
schema_name = identifier.get_parent_name()
|
|
||||||
real_name = identifier.get_real_name()
|
|
||||||
is_function = (allow_functions and
|
|
||||||
_identifier_is_function(identifier))
|
|
||||||
except AttributeError:
|
|
||||||
continue
|
|
||||||
if real_name:
|
|
||||||
yield TableReference(
|
|
||||||
schema_name, real_name, identifier.get_alias(),
|
|
||||||
is_function
|
|
||||||
)
|
|
||||||
elif isinstance(item, Identifier):
|
|
||||||
real_name = item.get_real_name()
|
|
||||||
schema_name = item.get_parent_name()
|
|
||||||
is_function = allow_functions and _identifier_is_function(item)
|
|
||||||
|
|
||||||
if real_name:
|
|
||||||
yield TableReference(
|
|
||||||
schema_name, real_name, item.get_alias(), is_function
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
name = item.get_name()
|
|
||||||
yield TableReference(
|
|
||||||
None, name, item.get_alias() or name, is_function
|
|
||||||
)
|
|
||||||
elif isinstance(item, Function):
|
|
||||||
yield TableReference(
|
|
||||||
None, item.get_real_name(), item.get_alias(), allow_functions
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# extract_tables is inspired from examples in the sqlparse lib.
|
|
||||||
def extract_tables(sql):
|
|
||||||
"""Extract the table names from an SQL statment.
|
|
||||||
|
|
||||||
Returns a list of TableReference namedtuples
|
|
||||||
|
|
||||||
"""
|
|
||||||
parsed = sqlparse.parse(sql)
|
|
||||||
if not parsed:
|
|
||||||
return ()
|
|
||||||
|
|
||||||
# INSERT statements must stop looking for tables at the sign of first
|
|
||||||
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
|
|
||||||
# abc is the table name, but if we don't stop at the first lparen, then
|
|
||||||
# we'll identify abc, col1 and col2 as table names.
|
|
||||||
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
|
|
||||||
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
|
|
||||||
|
|
||||||
# Kludge: sqlparse mistakenly identifies insert statements as
|
|
||||||
# function calls due to the parenthesized column list, e.g. interprets
|
|
||||||
# "insert into foo (bar, baz)" as a function call to foo with arguments
|
|
||||||
# (bar, baz). So don't allow any identifiers in insert statements
|
|
||||||
# to have is_function=True
|
|
||||||
identifiers = extract_table_identifiers(
|
|
||||||
stream, allow_functions=not insert_stmt
|
|
||||||
)
|
|
||||||
return tuple(identifiers)
|
|
||||||
|
|
||||||
|
|
||||||
def find_prev_keyword(sql):
|
|
||||||
""" Find the last sql keyword in an SQL statement
|
|
||||||
|
|
||||||
Returns the value of the last keyword, and the text of the query with
|
|
||||||
everything after the last keyword stripped
|
|
||||||
"""
|
|
||||||
if not sql.strip():
|
|
||||||
return None, ''
|
|
||||||
|
|
||||||
parsed = sqlparse.parse(sql)[0]
|
|
||||||
flattened = list(parsed.flatten())
|
|
||||||
|
|
||||||
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
|
|
||||||
|
|
||||||
for t in reversed(flattened):
|
|
||||||
if t.value == '(' or (t.is_keyword and (
|
|
||||||
t.value.upper() not in logical_operators)):
|
|
||||||
# Find the location of token t in the original parsed statement
|
|
||||||
# We can't use parsed.token_index(t) because t may be a child
|
|
||||||
# token inside a TokenList, in which case token_index thows an
|
|
||||||
# error Minimal example:
|
|
||||||
# p = sqlparse.parse('select * from foo where bar')
|
|
||||||
# t = list(p.flatten())[-3] # The "Where" token
|
|
||||||
# p.token_index(t) # Throws ValueError: not in list
|
|
||||||
idx = flattened.index(t)
|
|
||||||
|
|
||||||
# Combine the string values of all tokens in the original list
|
|
||||||
# up to and including the target keyword token t, to produce a
|
|
||||||
# query string with everything after the keyword token removed
|
|
||||||
text = ''.join(tok.value for tok in flattened[:idx + 1])
|
|
||||||
return t, text
|
|
||||||
|
|
||||||
return None, ''
|
|
||||||
|
|
||||||
|
|
||||||
# Postgresql dollar quote signs look like `$$` or `$tag$`
|
|
||||||
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
|
|
||||||
|
|
||||||
|
|
||||||
def is_open_quote(sql):
|
|
||||||
"""Returns true if the query contains an unclosed quote"""
|
|
||||||
|
|
||||||
# parsed can contain one or more semi-colon separated commands
|
|
||||||
parsed = sqlparse.parse(sql)
|
|
||||||
return any(_parsed_is_open_quote(p) for p in parsed)
|
|
||||||
|
|
||||||
|
|
||||||
def _parsed_is_open_quote(parsed):
|
|
||||||
tokens = list(parsed.flatten())
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
while i < len(tokens):
|
|
||||||
tok = tokens[i]
|
|
||||||
if tok.match(Token.Error, "'"):
|
|
||||||
# An unmatched single quote
|
|
||||||
return True
|
|
||||||
elif (tok.ttype in Token.Name.Builtin and
|
|
||||||
dollar_quote_regex.match(tok.value)):
|
|
||||||
# Find the matching closing dollar quote sign
|
|
||||||
for (j, tok2) in enumerate(tokens[i + 1:], i + 1):
|
|
||||||
if tok2.match(Token.Name.Builtin, tok.value):
|
|
||||||
# Found the matching closing quote - continue our scan
|
|
||||||
# for open quotes thereafter
|
|
||||||
i = j
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# No matching dollar sign quote
|
|
||||||
return True
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def parse_partial_identifier(word):
|
|
||||||
"""Attempt to parse a (partially typed) word as an identifier
|
|
||||||
|
|
||||||
word may include a schema qualification, like `schema_name.partial_name`
|
|
||||||
or `schema_name.` There may also be unclosed quotation marks, like
|
|
||||||
`"schema`, or `schema."partial_name`
|
|
||||||
|
|
||||||
:param word: string representing a (partially complete) identifier
|
|
||||||
:return: sqlparse.sql.Identifier, or None
|
|
||||||
"""
|
|
||||||
|
|
||||||
p = sqlparse.parse(word)[0]
|
|
||||||
n_tok = len(p.tokens)
|
|
||||||
if n_tok == 1 and isinstance(p.tokens[0], Identifier):
|
|
||||||
return p.tokens[0]
|
|
||||||
elif hasattr(p, 'token_next_match') and p.token_next_match(0, Error, '"'):
|
|
||||||
# An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
|
|
||||||
# Close the double quote, then reparse
|
|
||||||
return parse_partial_identifier(word + '"')
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
sql = 'select * from (select t. from tabl t'
|
|
||||||
print(extract_tables(sql))
|
|
@ -4,7 +4,7 @@ import sqlparse
|
|||||||
def query_starts_with(query, prefixes):
|
def query_starts_with(query, prefixes):
|
||||||
"""Check if the query starts with any item from *prefixes*."""
|
"""Check if the query starts with any item from *prefixes*."""
|
||||||
prefixes = [prefix.lower() for prefix in prefixes]
|
prefixes = [prefix.lower() for prefix in prefixes]
|
||||||
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
|
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
|
||||||
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
||||||
|
|
||||||
|
|
||||||
@ -18,5 +18,5 @@ def queries_start_with(queries, prefixes):
|
|||||||
|
|
||||||
def is_destructive(queries):
|
def is_destructive(queries):
|
||||||
"""Returns if any of the queries in *queries* is destructive."""
|
"""Returns if any of the queries in *queries* is destructive."""
|
||||||
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
|
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
|
||||||
return queries_start_with(queries, keywords)
|
return queries_start_with(queries, keywords)
|
||||||
|
@ -10,12 +10,11 @@ from .meta import TableMetadata, ColumnMetadata
|
|||||||
# columns: list of column names
|
# columns: list of column names
|
||||||
# start: index into the original string of the left parens starting the CTE
|
# start: index into the original string of the left parens starting the CTE
|
||||||
# stop: index into the original string of the right parens ending the CTE
|
# stop: index into the original string of the right parens ending the CTE
|
||||||
TableExpression = namedtuple('TableExpression', 'name columns start stop')
|
TableExpression = namedtuple("TableExpression", "name columns start stop")
|
||||||
|
|
||||||
|
|
||||||
def isolate_query_ctes(full_text, text_before_cursor):
|
def isolate_query_ctes(full_text, text_before_cursor):
|
||||||
"""Simplify a query by converting CTEs into table metadata objects
|
"""Simplify a query by converting CTEs into table metadata objects"""
|
||||||
"""
|
|
||||||
|
|
||||||
if not full_text:
|
if not full_text:
|
||||||
return full_text, text_before_cursor, tuple()
|
return full_text, text_before_cursor, tuple()
|
||||||
@ -30,8 +29,8 @@ def isolate_query_ctes(full_text, text_before_cursor):
|
|||||||
for cte in ctes:
|
for cte in ctes:
|
||||||
if cte.start < current_position < cte.stop:
|
if cte.start < current_position < cte.stop:
|
||||||
# Currently editing a cte - treat its body as the current full_text
|
# Currently editing a cte - treat its body as the current full_text
|
||||||
text_before_cursor = full_text[cte.start:current_position]
|
text_before_cursor = full_text[cte.start: current_position]
|
||||||
full_text = full_text[cte.start:cte.stop]
|
full_text = full_text[cte.start: cte.stop]
|
||||||
return full_text, text_before_cursor, meta
|
return full_text, text_before_cursor, meta
|
||||||
|
|
||||||
# Append this cte to the list of available table metadata
|
# Append this cte to the list of available table metadata
|
||||||
@ -40,19 +39,19 @@ def isolate_query_ctes(full_text, text_before_cursor):
|
|||||||
|
|
||||||
# Editing past the last cte (ie the main body of the query)
|
# Editing past the last cte (ie the main body of the query)
|
||||||
full_text = full_text[ctes[-1].stop:]
|
full_text = full_text[ctes[-1].stop:]
|
||||||
text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
|
text_before_cursor = text_before_cursor[ctes[-1].stop: current_position]
|
||||||
|
|
||||||
return full_text, text_before_cursor, tuple(meta)
|
return full_text, text_before_cursor, tuple(meta)
|
||||||
|
|
||||||
|
|
||||||
def extract_ctes(sql):
|
def extract_ctes(sql):
|
||||||
""" Extract constant table expresseions from a query
|
"""Extract constant table expresseions from a query
|
||||||
|
|
||||||
Returns tuple (ctes, remainder_sql)
|
Returns tuple (ctes, remainder_sql)
|
||||||
|
|
||||||
ctes is a list of TableExpression namedtuples
|
ctes is a list of TableExpression namedtuples
|
||||||
remainder_sql is the text from the original query after the CTEs have
|
remainder_sql is the text from the original query after the CTEs have
|
||||||
been stripped.
|
been stripped.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
p = parse(sql)[0]
|
p = parse(sql)[0]
|
||||||
@ -66,7 +65,7 @@ def extract_ctes(sql):
|
|||||||
# Get the next (meaningful) token, which should be the first CTE
|
# Get the next (meaningful) token, which should be the first CTE
|
||||||
idx, tok = p.token_next(idx)
|
idx, tok = p.token_next(idx)
|
||||||
if not tok:
|
if not tok:
|
||||||
return ([], '')
|
return ([], "")
|
||||||
start_pos = token_start_pos(p.tokens, idx)
|
start_pos = token_start_pos(p.tokens, idx)
|
||||||
ctes = []
|
ctes = []
|
||||||
|
|
||||||
@ -87,7 +86,7 @@ def extract_ctes(sql):
|
|||||||
idx = p.token_index(tok) + 1
|
idx = p.token_index(tok) + 1
|
||||||
|
|
||||||
# Collapse everything after the ctes into a remainder query
|
# Collapse everything after the ctes into a remainder query
|
||||||
remainder = ''.join(str(tok) for tok in p.tokens[idx:])
|
remainder = "".join(str(tok) for tok in p.tokens[idx:])
|
||||||
|
|
||||||
return ctes, remainder
|
return ctes, remainder
|
||||||
|
|
||||||
@ -112,15 +111,15 @@ def get_cte_from_token(tok, pos0):
|
|||||||
|
|
||||||
|
|
||||||
def extract_column_names(parsed):
|
def extract_column_names(parsed):
|
||||||
# Find the first DML token to check if it's a SELECT or
|
# Find the first DML token to check if it's a
|
||||||
# INSERT/UPDATE/DELETE
|
# SELECT or INSERT/UPDATE/DELETE
|
||||||
idx, tok = parsed.token_next_by(t=DML)
|
idx, tok = parsed.token_next_by(t=DML)
|
||||||
tok_val = tok and tok.value.lower()
|
tok_val = tok and tok.value.lower()
|
||||||
|
|
||||||
if tok_val in ('insert', 'update', 'delete'):
|
if tok_val in ("insert", "update", "delete"):
|
||||||
# Jump ahead to the RETURNING clause where the list of column names is
|
# Jump ahead to the RETURNING clause where the list of column names is
|
||||||
idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
|
idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
|
||||||
elif not tok_val == 'select':
|
elif not tok_val == "select":
|
||||||
# Must be invalid CTE
|
# Must be invalid CTE
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
|
@ -1,23 +1,29 @@
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
_ColumnMetadata = namedtuple(
|
_ColumnMetadata = namedtuple(
|
||||||
'ColumnMetadata',
|
"ColumnMetadata", ["name", "datatype", "foreignkeys", "default",
|
||||||
['name', 'datatype', 'foreignkeys', 'default', 'has_default']
|
"has_default"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def ColumnMetadata(
|
def ColumnMetadata(name, datatype, foreignkeys=None, default=None,
|
||||||
name, datatype, foreignkeys=None, default=None, has_default=False
|
has_default=False):
|
||||||
):
|
return _ColumnMetadata(name, datatype, foreignkeys or [], default,
|
||||||
return _ColumnMetadata(
|
has_default)
|
||||||
name, datatype, foreignkeys or [], default, has_default
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
|
ForeignKey = namedtuple(
|
||||||
'parentcolumn', 'childschema',
|
"ForeignKey",
|
||||||
'childtable', 'childcolumn'])
|
[
|
||||||
TableMetadata = namedtuple('TableMetadata', 'name columns')
|
"parentschema",
|
||||||
|
"parenttable",
|
||||||
|
"parentcolumn",
|
||||||
|
"childschema",
|
||||||
|
"childtable",
|
||||||
|
"childcolumn",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
TableMetadata = namedtuple("TableMetadata", "name columns")
|
||||||
|
|
||||||
|
|
||||||
def parse_defaults(defaults_string):
|
def parse_defaults(defaults_string):
|
||||||
@ -25,34 +31,42 @@ def parse_defaults(defaults_string):
|
|||||||
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
|
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
|
||||||
if not defaults_string:
|
if not defaults_string:
|
||||||
return
|
return
|
||||||
current = ''
|
current = ""
|
||||||
in_quote = None
|
in_quote = None
|
||||||
for char in defaults_string:
|
for char in defaults_string:
|
||||||
if current == '' and char == ' ':
|
if current == "" and char == " ":
|
||||||
# Skip space after comma separating default expressions
|
# Skip space after comma separating default expressions
|
||||||
continue
|
continue
|
||||||
if char == '"' or char == '\'':
|
if char == '"' or char == "'":
|
||||||
if in_quote and char == in_quote:
|
if in_quote and char == in_quote:
|
||||||
# End quote
|
# End quote
|
||||||
in_quote = None
|
in_quote = None
|
||||||
elif not in_quote:
|
elif not in_quote:
|
||||||
# Begin quote
|
# Begin quote
|
||||||
in_quote = char
|
in_quote = char
|
||||||
elif char == ',' and not in_quote:
|
elif char == "," and not in_quote:
|
||||||
# End of expression
|
# End of expression
|
||||||
yield current
|
yield current
|
||||||
current = ''
|
current = ""
|
||||||
continue
|
continue
|
||||||
current += char
|
current += char
|
||||||
yield current
|
yield current
|
||||||
|
|
||||||
|
|
||||||
class FunctionMetadata(object):
|
class FunctionMetadata(object):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, schema_name, func_name, arg_names, arg_types, arg_modes,
|
self,
|
||||||
return_type, is_aggregate, is_window, is_set_returning,
|
schema_name,
|
||||||
arg_defaults
|
func_name,
|
||||||
|
arg_names,
|
||||||
|
arg_types,
|
||||||
|
arg_modes,
|
||||||
|
return_type,
|
||||||
|
is_aggregate,
|
||||||
|
is_window,
|
||||||
|
is_set_returning,
|
||||||
|
is_extension,
|
||||||
|
arg_defaults,
|
||||||
):
|
):
|
||||||
"""Class for describing a postgresql function"""
|
"""Class for describing a postgresql function"""
|
||||||
|
|
||||||
@ -80,19 +94,29 @@ class FunctionMetadata(object):
|
|||||||
self.is_aggregate = is_aggregate
|
self.is_aggregate = is_aggregate
|
||||||
self.is_window = is_window
|
self.is_window = is_window
|
||||||
self.is_set_returning = is_set_returning
|
self.is_set_returning = is_set_returning
|
||||||
|
self.is_extension = bool(is_extension)
|
||||||
|
self.is_public = self.schema_name and self.schema_name == "public"
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (isinstance(other, self.__class__) and
|
return isinstance(other, self.__class__) and \
|
||||||
self.__dict__ == other.__dict__)
|
self.__dict__ == other.__dict__
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
def _signature(self):
|
def _signature(self):
|
||||||
return (
|
return (
|
||||||
self.schema_name, self.func_name, self.arg_names, self.arg_types,
|
self.schema_name,
|
||||||
self.arg_modes, self.return_type, self.is_aggregate,
|
self.func_name,
|
||||||
self.is_window, self.is_set_returning, self.arg_defaults
|
self.arg_names,
|
||||||
|
self.arg_types,
|
||||||
|
self.arg_modes,
|
||||||
|
self.return_type,
|
||||||
|
self.is_aggregate,
|
||||||
|
self.is_window,
|
||||||
|
self.is_set_returning,
|
||||||
|
self.is_extension,
|
||||||
|
self.arg_defaults,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
@ -100,26 +124,25 @@ class FunctionMetadata(object):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
(
|
"%s(schema_name=%r, func_name=%r, arg_names=%r, "
|
||||||
'%s(schema_name=%r, func_name=%r, arg_names=%r, '
|
"arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
|
||||||
'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
|
"is_window=%r, is_set_returning=%r, is_extension=%r, "
|
||||||
'is_window=%r, is_set_returning=%r, arg_defaults=%r)'
|
"arg_defaults=%r)"
|
||||||
) % (self.__class__.__name__,) + self._signature()
|
) % ((self.__class__.__name__,) + self._signature())
|
||||||
)
|
|
||||||
|
|
||||||
def has_variadic(self):
|
def has_variadic(self):
|
||||||
return self.arg_modes and any(
|
return self.arg_modes and \
|
||||||
arg_mode == 'v' for arg_mode in self.arg_modes)
|
any(arg_mode == "v" for arg_mode in self.arg_modes)
|
||||||
|
|
||||||
def args(self):
|
def args(self):
|
||||||
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
|
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
|
||||||
if not self.arg_names:
|
if not self.arg_names:
|
||||||
return []
|
return []
|
||||||
modes = self.arg_modes or ['i'] * len(self.arg_names)
|
modes = self.arg_modes or ["i"] * len(self.arg_names)
|
||||||
args = [
|
args = [
|
||||||
(name, typ)
|
(name, typ)
|
||||||
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
|
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
|
||||||
if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC
|
if mode in ("i", "b", "v") # IN, INOUT, VARIADIC
|
||||||
]
|
]
|
||||||
|
|
||||||
def arg(name, typ, num):
|
def arg(name, typ, num):
|
||||||
@ -127,7 +150,8 @@ class FunctionMetadata(object):
|
|||||||
num_defaults = len(self.arg_defaults)
|
num_defaults = len(self.arg_defaults)
|
||||||
has_default = num + num_defaults >= num_args
|
has_default = num + num_defaults >= num_args
|
||||||
default = (
|
default = (
|
||||||
self.arg_defaults[num - num_args + num_defaults] if has_default
|
self.arg_defaults[num - num_args + num_defaults]
|
||||||
|
if has_default
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
return ColumnMetadata(name, typ, [], default, has_default)
|
return ColumnMetadata(name, typ, [], default, has_default)
|
||||||
@ -137,7 +161,7 @@ class FunctionMetadata(object):
|
|||||||
def fields(self):
|
def fields(self):
|
||||||
"""Returns a list of output-field ColumnMetadata namedtuples"""
|
"""Returns a list of output-field ColumnMetadata namedtuples"""
|
||||||
|
|
||||||
if self.return_type.lower() == 'void':
|
if self.return_type.lower() == "void":
|
||||||
return []
|
return []
|
||||||
elif not self.arg_modes:
|
elif not self.arg_modes:
|
||||||
# For functions without output parameters, the function name
|
# For functions without output parameters, the function name
|
||||||
@ -145,7 +169,9 @@ class FunctionMetadata(object):
|
|||||||
# E.g. 'SELECT unnest FROM unnest(...);'
|
# E.g. 'SELECT unnest FROM unnest(...);'
|
||||||
return [ColumnMetadata(self.func_name, self.return_type, [])]
|
return [ColumnMetadata(self.func_name, self.return_type, [])]
|
||||||
|
|
||||||
return [ColumnMetadata(name, typ, [])
|
return [
|
||||||
for name, typ, mode in zip(
|
ColumnMetadata(name, typ, [])
|
||||||
self.arg_names, self.arg_types, self.arg_modes)
|
for name, typ, mode in zip(self.arg_names, self.arg_types,
|
||||||
if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE
|
self.arg_modes)
|
||||||
|
if mode in ("o", "b", "t")
|
||||||
|
] # OUT, INOUT, TABLE
|
||||||
|
@ -3,12 +3,15 @@ from collections import namedtuple
|
|||||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||||
from sqlparse.tokens import Keyword, DML, Punctuation
|
from sqlparse.tokens import Keyword, DML, Punctuation
|
||||||
|
|
||||||
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
|
TableReference = namedtuple(
|
||||||
'is_function'])
|
"TableReference", ["schema", "name", "alias", "is_function"]
|
||||||
|
)
|
||||||
TableReference.ref = property(
|
TableReference.ref = property(
|
||||||
lambda self: self.alias or (
|
lambda self: self.alias or (
|
||||||
self.name if self.name.islower() or self.name[0] == '"'
|
self.name
|
||||||
else '"' + self.name + '"')
|
if self.name.islower() or self.name[0] == '"'
|
||||||
|
else '"' + self.name + '"'
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -18,9 +21,13 @@ def is_subselect(parsed):
|
|||||||
if not parsed.is_group:
|
if not parsed.is_group:
|
||||||
return False
|
return False
|
||||||
for item in parsed.tokens:
|
for item in parsed.tokens:
|
||||||
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
|
if item.ttype is DML and item.value.upper() in (
|
||||||
'UPDATE', 'CREATE',
|
"SELECT",
|
||||||
'DELETE'):
|
"INSERT",
|
||||||
|
"UPDATE",
|
||||||
|
"CREATE",
|
||||||
|
"DELETE",
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -37,32 +44,42 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
|||||||
for x in extract_from_part(item, stop_at_punctuation):
|
for x in extract_from_part(item, stop_at_punctuation):
|
||||||
yield x
|
yield x
|
||||||
elif stop_at_punctuation and item.ttype is Punctuation:
|
elif stop_at_punctuation and item.ttype is Punctuation:
|
||||||
raise StopIteration
|
return
|
||||||
# An incomplete nested select won't be recognized correctly as a
|
# An incomplete nested select won't be recognized correctly as a
|
||||||
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
|
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
|
||||||
# the second FROM to trigger this elif condition resulting in a
|
# the second FROM to trigger this elif condition resulting in a
|
||||||
# StopIteration. So we need to ignore the keyword if the keyword
|
# `return`. So we need to ignore the keyword if the keyword
|
||||||
# FROM.
|
# FROM.
|
||||||
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
|
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
|
||||||
# condition. So we need to ignore the keyword JOIN and its variants
|
# condition. So we need to ignore the keyword JOIN and its variants
|
||||||
# INNER JOIN, FULL OUTER JOIN, etc.
|
# INNER JOIN, FULL OUTER JOIN, etc.
|
||||||
elif item.ttype is Keyword and (
|
elif (
|
||||||
not item.value.upper() == 'FROM') and (
|
item.ttype is Keyword and
|
||||||
not item.value.upper().endswith('JOIN')):
|
(not item.value.upper() == "FROM") and
|
||||||
|
(not item.value.upper().endswith("JOIN"))
|
||||||
|
):
|
||||||
tbl_prefix_seen = False
|
tbl_prefix_seen = False
|
||||||
else:
|
else:
|
||||||
yield item
|
yield item
|
||||||
elif item.ttype is Keyword or item.ttype is Keyword.DML:
|
elif item.ttype is Keyword or item.ttype is Keyword.DML:
|
||||||
item_val = item.value.upper()
|
item_val = item.value.upper()
|
||||||
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
|
if (
|
||||||
item_val.endswith('JOIN')):
|
item_val
|
||||||
|
in (
|
||||||
|
"COPY",
|
||||||
|
"FROM",
|
||||||
|
"INTO",
|
||||||
|
"UPDATE",
|
||||||
|
"TABLE",
|
||||||
|
) or item_val.endswith("JOIN")
|
||||||
|
):
|
||||||
tbl_prefix_seen = True
|
tbl_prefix_seen = True
|
||||||
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
||||||
# So this check here is necessary.
|
# So this check here is necessary.
|
||||||
elif isinstance(item, IdentifierList):
|
elif isinstance(item, IdentifierList):
|
||||||
for identifier in item.get_identifiers():
|
for identifier in item.get_identifiers():
|
||||||
if (identifier.ttype is Keyword and
|
if identifier.ttype is Keyword and \
|
||||||
identifier.value.upper() == 'FROM'):
|
identifier.value.upper() == "FROM":
|
||||||
tbl_prefix_seen = True
|
tbl_prefix_seen = True
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -94,29 +111,35 @@ def extract_table_identifiers(token_stream, allow_functions=True):
|
|||||||
name = name.lower()
|
name = name.lower()
|
||||||
return schema_name, name, alias
|
return schema_name, name, alias
|
||||||
|
|
||||||
for item in token_stream:
|
try:
|
||||||
if isinstance(item, IdentifierList):
|
for item in token_stream:
|
||||||
for identifier in item.get_identifiers():
|
if isinstance(item, IdentifierList):
|
||||||
# Sometimes Keywords (such as FROM ) are classified as
|
for identifier in item.get_identifiers():
|
||||||
# identifiers which don't have the get_real_name() method.
|
# Sometimes Keywords (such as FROM ) are classified as
|
||||||
try:
|
# identifiers which don't have the get_real_name() method.
|
||||||
schema_name = identifier.get_parent_name()
|
try:
|
||||||
real_name = identifier.get_real_name()
|
schema_name = identifier.get_parent_name()
|
||||||
is_function = (allow_functions and
|
real_name = identifier.get_real_name()
|
||||||
_identifier_is_function(identifier))
|
is_function = allow_functions and \
|
||||||
except AttributeError:
|
_identifier_is_function(identifier)
|
||||||
continue
|
except AttributeError:
|
||||||
if real_name:
|
continue
|
||||||
yield TableReference(schema_name, real_name,
|
if real_name:
|
||||||
identifier.get_alias(), is_function)
|
yield TableReference(
|
||||||
elif isinstance(item, Identifier):
|
schema_name, real_name, identifier.get_alias(),
|
||||||
schema_name, real_name, alias = parse_identifier(item)
|
is_function
|
||||||
is_function = allow_functions and _identifier_is_function(item)
|
)
|
||||||
|
elif isinstance(item, Identifier):
|
||||||
|
schema_name, real_name, alias = parse_identifier(item)
|
||||||
|
is_function = allow_functions and _identifier_is_function(item)
|
||||||
|
|
||||||
yield TableReference(schema_name, real_name, alias, is_function)
|
yield TableReference(schema_name, real_name, alias,
|
||||||
elif isinstance(item, Function):
|
is_function)
|
||||||
schema_name, real_name, alias = parse_identifier(item)
|
elif isinstance(item, Function):
|
||||||
yield TableReference(None, real_name, alias, allow_functions)
|
schema_name, real_name, alias = parse_identifier(item)
|
||||||
|
yield TableReference(None, real_name, alias, allow_functions)
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
# extract_tables is inspired from examples in the sqlparse lib.
|
# extract_tables is inspired from examples in the sqlparse lib.
|
||||||
@ -134,7 +157,7 @@ def extract_tables(sql):
|
|||||||
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
|
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
|
||||||
# abc is the table name, but if we don't stop at the first lparen, then
|
# abc is the table name, but if we don't stop at the first lparen, then
|
||||||
# we'll identify abc, col1 and col2 as table names.
|
# we'll identify abc, col1 and col2 as table names.
|
||||||
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
|
insert_stmt = parsed[0].token_first().value.lower() == "insert"
|
||||||
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
|
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
|
||||||
|
|
||||||
# Kludge: sqlparse mistakenly identifies insert statements as
|
# Kludge: sqlparse mistakenly identifies insert statements as
|
||||||
|
@ -5,17 +5,17 @@ from sqlparse.tokens import Token, Error
|
|||||||
|
|
||||||
cleanup_regex = {
|
cleanup_regex = {
|
||||||
# This matches only alphanumerics and underscores.
|
# This matches only alphanumerics and underscores.
|
||||||
'alphanum_underscore': re.compile(r'(\w+)$'),
|
"alphanum_underscore": re.compile(r"(\w+)$"),
|
||||||
# This matches everything except spaces, parens, colon, and comma
|
# This matches everything except spaces, parens, colon, and comma
|
||||||
'many_punctuations': re.compile(r'([^():,\s]+)$'),
|
"many_punctuations": re.compile(r"([^():,\s]+)$"),
|
||||||
# This matches everything except spaces, parens, colon, comma, and period
|
# This matches everything except spaces, parens, colon, comma, and period
|
||||||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
"most_punctuations": re.compile(r"([^\.():,\s]+)$"),
|
||||||
# This matches everything except a space.
|
# This matches everything except a space.
|
||||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
"all_punctuations": re.compile(r"([^\s]+)$"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def last_word(text, include='alphanum_underscore'):
|
def last_word(text, include="alphanum_underscore"):
|
||||||
r"""
|
r"""
|
||||||
Find the last word in a sentence.
|
Find the last word in a sentence.
|
||||||
|
|
||||||
@ -49,41 +49,42 @@ def last_word(text, include='alphanum_underscore'):
|
|||||||
'"foo*bar'
|
'"foo*bar'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not text: # Empty string
|
if not text: # Empty string
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
if text[-1].isspace():
|
if text[-1].isspace():
|
||||||
return ''
|
return ""
|
||||||
else:
|
else:
|
||||||
regex = cleanup_regex[include]
|
regex = cleanup_regex[include]
|
||||||
matches = regex.search(text)
|
matches = regex.search(text)
|
||||||
if matches:
|
if matches:
|
||||||
return matches.group(0)
|
return matches.group(0)
|
||||||
else:
|
else:
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def find_prev_keyword(sql, n_skip=0):
|
def find_prev_keyword(sql, n_skip=0):
|
||||||
""" Find the last sql keyword in an SQL statement
|
"""Find the last sql keyword in an SQL statement
|
||||||
|
|
||||||
Returns the value of the last keyword, and the text of the query with
|
Returns the value of the last keyword, and the text of the query with
|
||||||
everything after the last keyword stripped
|
everything after the last keyword stripped
|
||||||
"""
|
"""
|
||||||
if not sql.strip():
|
if not sql.strip():
|
||||||
return None, ''
|
return None, ""
|
||||||
|
|
||||||
parsed = sqlparse.parse(sql)[0]
|
parsed = sqlparse.parse(sql)[0]
|
||||||
flattened = list(parsed.flatten())
|
flattened = list(parsed.flatten())
|
||||||
flattened = flattened[:len(flattened) - n_skip]
|
flattened = flattened[: len(flattened) - n_skip]
|
||||||
|
|
||||||
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
|
logical_operators = ("AND", "OR", "NOT", "BETWEEN")
|
||||||
|
|
||||||
for t in reversed(flattened):
|
for t in reversed(flattened):
|
||||||
if t.value == '(' or (t.is_keyword and (
|
if t.value == "(" or (
|
||||||
t.value.upper() not in logical_operators)):
|
t.is_keyword and (t.value.upper() not in logical_operators)
|
||||||
|
):
|
||||||
# Find the location of token t in the original parsed statement
|
# Find the location of token t in the original parsed statement
|
||||||
# We can't use parsed.token_index(t) because t may be a child token
|
# We can't use parsed.token_index(t) because t may be a child token
|
||||||
# inside a TokenList, in which case token_index thows an error
|
# inside a TokenList, in which case token_index throws an error
|
||||||
# Minimal example:
|
# Minimal example:
|
||||||
# p = sqlparse.parse('select * from foo where bar')
|
# p = sqlparse.parse('select * from foo where bar')
|
||||||
# t = list(p.flatten())[-3] # The "Where" token
|
# t = list(p.flatten())[-3] # The "Where" token
|
||||||
@ -93,14 +94,14 @@ def find_prev_keyword(sql, n_skip=0):
|
|||||||
# Combine the string values of all tokens in the original list
|
# Combine the string values of all tokens in the original list
|
||||||
# up to and including the target keyword token t, to produce a
|
# up to and including the target keyword token t, to produce a
|
||||||
# query string with everything after the keyword token removed
|
# query string with everything after the keyword token removed
|
||||||
text = ''.join(tok.value for tok in flattened[:idx + 1])
|
text = "".join(tok.value for tok in flattened[: idx + 1])
|
||||||
return t, text
|
return t, text
|
||||||
|
|
||||||
return None, ''
|
return None, ""
|
||||||
|
|
||||||
|
|
||||||
# Postgresql dollar quote signs look like `$$` or `$tag$`
|
# Postgresql dollar quote signs look like `$$` or `$tag$`
|
||||||
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
|
dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
|
||||||
|
|
||||||
|
|
||||||
def is_open_quote(sql):
|
def is_open_quote(sql):
|
||||||
|
@ -4,13 +4,13 @@ from sqlparse.tokens import Name
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
white_space_regex = re.compile(r'\\s+', re.MULTILINE)
|
white_space_regex = re.compile("\\s+", re.MULTILINE)
|
||||||
|
|
||||||
|
|
||||||
def _compile_regex(keyword):
|
def _compile_regex(keyword):
|
||||||
# Surround the keyword with word boundaries and replace interior whitespace
|
# Surround the keyword with word boundaries and replace interior whitespace
|
||||||
# with whitespace wildcards
|
# with whitespace wildcards
|
||||||
pattern = r'\\b' + white_space_regex.sub(r'\\s+', keyword) + r'\\b'
|
pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
|
||||||
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
|
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,28 +3,29 @@ import re
|
|||||||
import sqlparse
|
import sqlparse
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from sqlparse.sql import Comparison, Identifier, Where
|
from sqlparse.sql import Comparison, Identifier, Where
|
||||||
from .parseutils.utils import (
|
from .parseutils.utils import last_word, find_prev_keyword,\
|
||||||
last_word, find_prev_keyword, parse_partial_identifier)
|
parse_partial_identifier
|
||||||
from .parseutils.tables import extract_tables
|
from .parseutils.tables import extract_tables
|
||||||
from .parseutils.ctes import isolate_query_ctes
|
from .parseutils.ctes import isolate_query_ctes
|
||||||
|
|
||||||
Special = namedtuple('Special', [])
|
|
||||||
Database = namedtuple('Database', [])
|
Special = namedtuple("Special", [])
|
||||||
Schema = namedtuple('Schema', ['quoted'])
|
Database = namedtuple("Database", [])
|
||||||
|
Schema = namedtuple("Schema", ["quoted"])
|
||||||
Schema.__new__.__defaults__ = (False,)
|
Schema.__new__.__defaults__ = (False,)
|
||||||
# FromClauseItem is a table/view/function used in the FROM clause
|
# FromClauseItem is a table/view/function used in the FROM clause
|
||||||
# `table_refs` contains the list of tables/... already in the statement,
|
# `table_refs` contains the list of tables/... already in the statement,
|
||||||
# used to ensure that the alias we suggest is unique
|
# used to ensure that the alias we suggest is unique
|
||||||
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
|
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
|
||||||
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
|
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
|
||||||
TableFormat = namedtuple('TableFormat', [])
|
TableFormat = namedtuple("TableFormat", [])
|
||||||
View = namedtuple('View', ['schema', 'table_refs'])
|
View = namedtuple("View", ["schema", "table_refs"])
|
||||||
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
|
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
|
||||||
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
|
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
|
||||||
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
|
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
|
||||||
Join = namedtuple('Join', ['table_refs', 'schema'])
|
Join = namedtuple("Join", ["table_refs", "schema"])
|
||||||
|
|
||||||
Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
|
Function = namedtuple("Function", ["schema", "table_refs", "usage"])
|
||||||
# For convenience, don't require the `usage` argument in Function constructor
|
# For convenience, don't require the `usage` argument in Function constructor
|
||||||
Function.__new__.__defaults__ = (None, tuple(), None)
|
Function.__new__.__defaults__ = (None, tuple(), None)
|
||||||
Table.__new__.__defaults__ = (None, tuple(), tuple())
|
Table.__new__.__defaults__ = (None, tuple(), tuple())
|
||||||
@ -32,31 +33,33 @@ View.__new__.__defaults__ = (None, tuple())
|
|||||||
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
|
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
|
||||||
|
|
||||||
Column = namedtuple(
|
Column = namedtuple(
|
||||||
'Column',
|
"Column",
|
||||||
['table_refs', 'require_last_table', 'local_tables',
|
["table_refs", "require_last_table", "local_tables", "qualifiable",
|
||||||
'qualifiable', 'context']
|
"context"],
|
||||||
)
|
)
|
||||||
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
|
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
|
||||||
|
|
||||||
Keyword = namedtuple('Keyword', ['last_token'])
|
Keyword = namedtuple("Keyword", ["last_token"])
|
||||||
Keyword.__new__.__defaults__ = (None,)
|
Keyword.__new__.__defaults__ = (None,)
|
||||||
NamedQuery = namedtuple('NamedQuery', [])
|
NamedQuery = namedtuple("NamedQuery", [])
|
||||||
Datatype = namedtuple('Datatype', ['schema'])
|
Datatype = namedtuple("Datatype", ["schema"])
|
||||||
Alias = namedtuple('Alias', ['aliases'])
|
Alias = namedtuple("Alias", ["aliases"])
|
||||||
|
|
||||||
Path = namedtuple('Path', [])
|
Path = namedtuple("Path", [])
|
||||||
|
|
||||||
|
|
||||||
class SqlStatement(object):
|
class SqlStatement(object):
|
||||||
def __init__(self, full_text, text_before_cursor):
|
def __init__(self, full_text, text_before_cursor):
|
||||||
self.identifier = None
|
self.identifier = None
|
||||||
self.word_before_cursor = word_before_cursor = last_word(
|
self.word_before_cursor = word_before_cursor = last_word(
|
||||||
text_before_cursor, include='many_punctuations')
|
text_before_cursor, include="many_punctuations"
|
||||||
|
)
|
||||||
full_text = _strip_named_query(full_text)
|
full_text = _strip_named_query(full_text)
|
||||||
text_before_cursor = _strip_named_query(text_before_cursor)
|
text_before_cursor = _strip_named_query(text_before_cursor)
|
||||||
|
|
||||||
full_text, text_before_cursor, self.local_tables = \
|
full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
|
||||||
isolate_query_ctes(full_text, text_before_cursor)
|
full_text, text_before_cursor
|
||||||
|
)
|
||||||
|
|
||||||
self.text_before_cursor_including_last_word = text_before_cursor
|
self.text_before_cursor_including_last_word = text_before_cursor
|
||||||
|
|
||||||
@ -67,39 +70,41 @@ class SqlStatement(object):
|
|||||||
# completion useless because it will always return the list of
|
# completion useless because it will always return the list of
|
||||||
# keywords as completion.
|
# keywords as completion.
|
||||||
if self.word_before_cursor:
|
if self.word_before_cursor:
|
||||||
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
|
if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
|
||||||
parsed = sqlparse.parse(text_before_cursor)
|
parsed = sqlparse.parse(text_before_cursor)
|
||||||
else:
|
else:
|
||||||
text_before_cursor = \
|
text_before_cursor = \
|
||||||
text_before_cursor[:-len(word_before_cursor)]
|
text_before_cursor[: -len(word_before_cursor)]
|
||||||
parsed = sqlparse.parse(text_before_cursor)
|
parsed = sqlparse.parse(text_before_cursor)
|
||||||
self.identifier = parse_partial_identifier(word_before_cursor)
|
self.identifier = parse_partial_identifier(word_before_cursor)
|
||||||
else:
|
else:
|
||||||
parsed = sqlparse.parse(text_before_cursor)
|
parsed = sqlparse.parse(text_before_cursor)
|
||||||
|
|
||||||
full_text, text_before_cursor, parsed = \
|
full_text, text_before_cursor, parsed = _split_multiple_statements(
|
||||||
_split_multiple_statements(full_text, text_before_cursor, parsed)
|
full_text, text_before_cursor, parsed
|
||||||
|
)
|
||||||
|
|
||||||
self.full_text = full_text
|
self.full_text = full_text
|
||||||
self.text_before_cursor = text_before_cursor
|
self.text_before_cursor = text_before_cursor
|
||||||
self.parsed = parsed
|
self.parsed = parsed
|
||||||
|
|
||||||
self.last_token = \
|
self.last_token = parsed and \
|
||||||
parsed and parsed.token_prev(len(parsed.tokens))[1] or ''
|
parsed.token_prev(len(parsed.tokens))[1] or ""
|
||||||
|
|
||||||
def is_insert(self):
|
def is_insert(self):
|
||||||
return self.parsed.token_first().value.lower() == 'insert'
|
return self.parsed.token_first().value.lower() == "insert"
|
||||||
|
|
||||||
def get_tables(self, scope='full'):
|
def get_tables(self, scope="full"):
|
||||||
""" Gets the tables available in the statement.
|
"""Gets the tables available in the statement.
|
||||||
param `scope:` possible values: 'full', 'insert', 'before'
|
param `scope:` possible values: 'full', 'insert', 'before'
|
||||||
If 'insert', only the first table is returned.
|
If 'insert', only the first table is returned.
|
||||||
If 'before', only tables before the cursor are returned.
|
If 'before', only tables before the cursor are returned.
|
||||||
If not 'insert' and the stmt is an insert, the first table is skipped.
|
If not 'insert' and the stmt is an insert, the first table is skipped.
|
||||||
"""
|
"""
|
||||||
tables = extract_tables(
|
tables = extract_tables(
|
||||||
self.full_text if scope == 'full' else self.text_before_cursor)
|
self.full_text if scope == "full" else self.text_before_cursor
|
||||||
if scope == 'insert':
|
)
|
||||||
|
if scope == "insert":
|
||||||
tables = tables[:1]
|
tables = tables[:1]
|
||||||
elif self.is_insert():
|
elif self.is_insert():
|
||||||
tables = tables[1:]
|
tables = tables[1:]
|
||||||
@ -118,8 +123,9 @@ class SqlStatement(object):
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
def reduce_to_prev_keyword(self, n_skip=0):
|
def reduce_to_prev_keyword(self, n_skip=0):
|
||||||
prev_keyword, self.text_before_cursor = \
|
prev_keyword, self.text_before_cursor = find_prev_keyword(
|
||||||
find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
|
self.text_before_cursor, n_skip=n_skip
|
||||||
|
)
|
||||||
return prev_keyword
|
return prev_keyword
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +137,7 @@ def suggest_type(full_text, text_before_cursor):
|
|||||||
A scope for a column category will be a list of tables.
|
A scope for a column category will be a list of tables.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if full_text.startswith('\\i '):
|
if full_text.startswith("\\i "):
|
||||||
return (Path(),)
|
return (Path(),)
|
||||||
|
|
||||||
# This is a temporary hack; the exception handling
|
# This is a temporary hack; the exception handling
|
||||||
@ -144,7 +150,7 @@ def suggest_type(full_text, text_before_cursor):
|
|||||||
return suggest_based_on_last_token(stmt.last_token, stmt)
|
return suggest_based_on_last_token(stmt.last_token, stmt)
|
||||||
|
|
||||||
|
|
||||||
named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
|
named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
|
||||||
|
|
||||||
|
|
||||||
def _strip_named_query(txt):
|
def _strip_named_query(txt):
|
||||||
@ -155,11 +161,11 @@ def _strip_named_query(txt):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if named_query_regex.match(txt):
|
if named_query_regex.match(txt):
|
||||||
txt = named_query_regex.sub('', txt)
|
txt = named_query_regex.sub("", txt)
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
|
|
||||||
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
|
function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
|
||||||
|
|
||||||
|
|
||||||
def _find_function_body(text):
|
def _find_function_body(text):
|
||||||
@ -205,12 +211,12 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
|
|||||||
return full_text, text_before_cursor, None
|
return full_text, text_before_cursor, None
|
||||||
|
|
||||||
token2 = None
|
token2 = None
|
||||||
if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
|
if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
|
||||||
token1 = statement.token_first()
|
token1 = statement.token_first()
|
||||||
if token1:
|
if token1:
|
||||||
token1_idx = statement.token_index(token1)
|
token1_idx = statement.token_index(token1)
|
||||||
token2 = statement.token_next(token1_idx)[1]
|
token2 = statement.token_next(token1_idx)[1]
|
||||||
if token2 and token2.value.upper() == 'FUNCTION':
|
if token2 and token2.value.upper() == "FUNCTION":
|
||||||
full_text, text_before_cursor, statement = _statement_from_function(
|
full_text, text_before_cursor, statement = _statement_from_function(
|
||||||
full_text, text_before_cursor, statement
|
full_text, text_before_cursor, statement
|
||||||
)
|
)
|
||||||
@ -246,9 +252,9 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
# SELECT Identifier <CURSOR>
|
# SELECT Identifier <CURSOR>
|
||||||
# SELECT foo FROM Identifier <CURSOR>
|
# SELECT foo FROM Identifier <CURSOR>
|
||||||
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
|
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
|
||||||
if prev_keyword and prev_keyword.value == '(':
|
if prev_keyword and prev_keyword.value == "(":
|
||||||
# Suggest datatypes
|
# Suggest datatypes
|
||||||
return suggest_based_on_last_token('type', stmt)
|
return suggest_based_on_last_token("type", stmt)
|
||||||
else:
|
else:
|
||||||
return (Keyword(),)
|
return (Keyword(),)
|
||||||
else:
|
else:
|
||||||
@ -256,7 +262,7 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
return (Keyword(),)
|
return (Keyword(),)
|
||||||
elif token_v.endswith('('):
|
elif token_v.endswith("("):
|
||||||
p = sqlparse.parse(stmt.text_before_cursor)[0]
|
p = sqlparse.parse(stmt.text_before_cursor)[0]
|
||||||
|
|
||||||
if p.tokens and isinstance(p.tokens[-1], Where):
|
if p.tokens and isinstance(p.tokens[-1], Where):
|
||||||
@ -268,10 +274,10 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
# 3 - Subquery expression like "WHERE EXISTS ("
|
# 3 - Subquery expression like "WHERE EXISTS ("
|
||||||
# Suggest keywords, in order to do a subquery
|
# Suggest keywords, in order to do a subquery
|
||||||
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
|
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
|
||||||
# Suggest columns/functions AND keywords. (If we wanted to be
|
# Suggest columns/functions AND keywords. (If we wanted to
|
||||||
# really fancy, we could suggest only array-typed columns)
|
# be really fancy, we could suggest only array-typed columns)
|
||||||
|
|
||||||
column_suggestions = suggest_based_on_last_token('where', stmt)
|
column_suggestions = suggest_based_on_last_token("where", stmt)
|
||||||
|
|
||||||
# Check for a subquery expression (cases 3 & 4)
|
# Check for a subquery expression (cases 3 & 4)
|
||||||
where = p.tokens[-1]
|
where = p.tokens[-1]
|
||||||
@ -282,7 +288,7 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
prev_tok = prev_tok.tokens[-1]
|
prev_tok = prev_tok.tokens[-1]
|
||||||
|
|
||||||
prev_tok = prev_tok.value.lower()
|
prev_tok = prev_tok.value.lower()
|
||||||
if prev_tok == 'exists':
|
if prev_tok == "exists":
|
||||||
return (Keyword(),)
|
return (Keyword(),)
|
||||||
else:
|
else:
|
||||||
return column_suggestions
|
return column_suggestions
|
||||||
@ -292,59 +298,47 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
prev_tok and prev_tok.value and
|
prev_tok and prev_tok.value and
|
||||||
prev_tok.value.lower().split(' ')[-1] == 'using'
|
prev_tok.value.lower().split(" ")[-1] == "using"
|
||||||
):
|
):
|
||||||
# tbl1 INNER JOIN tbl2 USING (col1, col2)
|
# tbl1 INNER JOIN tbl2 USING (col1, col2)
|
||||||
tables = stmt.get_tables('before')
|
tables = stmt.get_tables("before")
|
||||||
|
|
||||||
# suggest columns that are present in more than one table
|
# suggest columns that are present in more than one table
|
||||||
return (Column(table_refs=tables,
|
|
||||||
require_last_table=True,
|
|
||||||
local_tables=stmt.local_tables),)
|
|
||||||
|
|
||||||
# If the lparen is preceeded by a space chances are we're about to
|
|
||||||
# do a sub-select.
|
|
||||||
elif p.token_first().value.lower() == 'select' and \
|
|
||||||
last_word(stmt.text_before_cursor,
|
|
||||||
'all_punctuations').startswith('('):
|
|
||||||
return (Keyword(),)
|
|
||||||
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
|
|
||||||
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
|
|
||||||
return (
|
return (
|
||||||
Column(table_refs=stmt.get_tables('insert'), context='insert'),
|
Column(
|
||||||
|
table_refs=tables,
|
||||||
|
require_last_table=True,
|
||||||
|
local_tables=stmt.local_tables,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif p.token_first().value.lower() == "select":
|
||||||
|
# If the lparen is preceeded by a space chances are we're about to
|
||||||
|
# do a sub-select.
|
||||||
|
if last_word(stmt.text_before_cursor,
|
||||||
|
"all_punctuations").startswith("("):
|
||||||
|
return (Keyword(),)
|
||||||
|
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
|
||||||
|
if prev_prev_tok and prev_prev_tok.normalized == "INTO":
|
||||||
|
return (Column(table_refs=stmt.get_tables("insert"),
|
||||||
|
context="insert"),)
|
||||||
# We're probably in a function argument list
|
# We're probably in a function argument list
|
||||||
return (Column(table_refs=extract_tables(stmt.full_text),
|
return _suggest_expression(token_v, stmt)
|
||||||
local_tables=stmt.local_tables, qualifiable=True),)
|
elif token_v == "set":
|
||||||
elif token_v == 'set':
|
|
||||||
return (Column(table_refs=stmt.get_tables(),
|
return (Column(table_refs=stmt.get_tables(),
|
||||||
local_tables=stmt.local_tables),)
|
local_tables=stmt.local_tables),)
|
||||||
elif token_v in ('select', 'where', 'having', 'by', 'distinct'):
|
elif token_v in ("select", "where", "having", "order by", "distinct"):
|
||||||
# Check for a table alias or schema qualification
|
return _suggest_expression(token_v, stmt)
|
||||||
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
|
elif token_v == "as":
|
||||||
tables = stmt.get_tables()
|
|
||||||
if parent:
|
|
||||||
tables = tuple(t for t in tables if identifies(parent, t))
|
|
||||||
return (Column(table_refs=tables, local_tables=stmt.local_tables),
|
|
||||||
Table(schema=parent),
|
|
||||||
View(schema=parent),
|
|
||||||
Function(schema=parent),)
|
|
||||||
else:
|
|
||||||
return (Column(table_refs=tables, local_tables=stmt.local_tables,
|
|
||||||
qualifiable=True),
|
|
||||||
Function(schema=None),
|
|
||||||
Keyword(token_v.upper()),)
|
|
||||||
elif token_v == 'as':
|
|
||||||
# Don't suggest anything for aliases
|
# Don't suggest anything for aliases
|
||||||
return ()
|
return ()
|
||||||
elif (
|
elif (token_v.endswith("join") and token.is_keyword) or (
|
||||||
(token_v.endswith('join') and token.is_keyword) or
|
token_v in ("copy", "from", "update", "into", "describe", "truncate")
|
||||||
(token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'))
|
|
||||||
):
|
):
|
||||||
|
|
||||||
schema = stmt.get_identifier_schema()
|
schema = stmt.get_identifier_schema()
|
||||||
tables = extract_tables(stmt.text_before_cursor)
|
tables = extract_tables(stmt.text_before_cursor)
|
||||||
is_join = token_v.endswith('join') and token.is_keyword
|
is_join = token_v.endswith("join") and token.is_keyword
|
||||||
|
|
||||||
# Suggest tables from either the currently-selected schema or the
|
# Suggest tables from either the currently-selected schema or the
|
||||||
# public schema if no schema has been specified
|
# public schema if no schema has been specified
|
||||||
@ -354,60 +348,77 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
# Suggest schemas
|
# Suggest schemas
|
||||||
suggest.insert(0, Schema())
|
suggest.insert(0, Schema())
|
||||||
|
|
||||||
if token_v == 'from' or is_join:
|
if token_v == "from" or is_join:
|
||||||
suggest.append(FromClauseItem(schema=schema,
|
suggest.append(
|
||||||
table_refs=tables,
|
FromClauseItem(
|
||||||
local_tables=stmt.local_tables))
|
schema=schema, table_refs=tables,
|
||||||
elif token_v == 'truncate':
|
local_tables=stmt.local_tables
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif token_v == "truncate":
|
||||||
suggest.append(Table(schema))
|
suggest.append(Table(schema))
|
||||||
else:
|
else:
|
||||||
suggest.extend((Table(schema), View(schema)))
|
suggest.extend((Table(schema), View(schema)))
|
||||||
|
|
||||||
if is_join and _allow_join(stmt.parsed):
|
if is_join and _allow_join(stmt.parsed):
|
||||||
tables = stmt.get_tables('before')
|
tables = stmt.get_tables("before")
|
||||||
suggest.append(Join(table_refs=tables, schema=schema))
|
suggest.append(Join(table_refs=tables, schema=schema))
|
||||||
|
|
||||||
return tuple(suggest)
|
return tuple(suggest)
|
||||||
|
|
||||||
elif token_v == 'function':
|
elif token_v == "function":
|
||||||
schema = stmt.get_identifier_schema()
|
schema = stmt.get_identifier_schema()
|
||||||
|
|
||||||
# stmt.get_previous_token will fail for e.g.
|
# stmt.get_previous_token will fail for e.g.
|
||||||
# `SELECT 1 FROM functions WHERE function:`
|
# `SELECT 1 FROM functions WHERE function:`
|
||||||
try:
|
try:
|
||||||
prev = stmt.get_previous_token(token).value.lower()
|
prev = stmt.get_previous_token(token).value.lower()
|
||||||
if prev in ('drop', 'alter', 'create', 'create or replace'):
|
if prev in ("drop", "alter", "create", "create or replace"):
|
||||||
return (Function(schema=schema, usage='signature'),)
|
|
||||||
|
# Suggest functions from either the currently-selected schema
|
||||||
|
# or the public schema if no schema has been specified
|
||||||
|
suggest = []
|
||||||
|
|
||||||
|
if not schema:
|
||||||
|
# Suggest schemas
|
||||||
|
suggest.insert(0, Schema())
|
||||||
|
|
||||||
|
suggest.append(Function(schema=schema, usage="signature"))
|
||||||
|
return tuple(suggest)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
return tuple()
|
return tuple()
|
||||||
|
|
||||||
elif token_v in ('table', 'view'):
|
elif token_v in ("table", "view"):
|
||||||
# E.g. 'ALTER TABLE <tablname>'
|
# E.g. 'ALTER TABLE <tablname>'
|
||||||
rel_type = \
|
rel_type = \
|
||||||
{'table': Table, 'view': View, 'function': Function}[token_v]
|
{"table": Table, "view": View, "function": Function}[token_v]
|
||||||
schema = stmt.get_identifier_schema()
|
schema = stmt.get_identifier_schema()
|
||||||
if schema:
|
if schema:
|
||||||
return (rel_type(schema=schema),)
|
return (rel_type(schema=schema),)
|
||||||
else:
|
else:
|
||||||
return (Schema(), rel_type(schema=schema))
|
return (Schema(), rel_type(schema=schema))
|
||||||
|
|
||||||
elif token_v == 'column':
|
elif token_v == "column":
|
||||||
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
|
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
|
||||||
return (Column(table_refs=stmt.get_tables()),)
|
return (Column(table_refs=stmt.get_tables()),)
|
||||||
|
|
||||||
elif token_v == 'on':
|
elif token_v == "on":
|
||||||
tables = stmt.get_tables('before')
|
tables = stmt.get_tables("before")
|
||||||
parent = \
|
parent = \
|
||||||
(stmt.identifier and stmt.identifier.get_parent_name()) or None
|
(stmt.identifier and stmt.identifier.get_parent_name()) or None
|
||||||
if parent:
|
if parent:
|
||||||
# "ON parent.<suggestion>"
|
# "ON parent.<suggestion>"
|
||||||
# parent can be either a schema name or table alias
|
# parent can be either a schema name or table alias
|
||||||
filteredtables = tuple(t for t in tables if identifies(parent, t))
|
filteredtables = tuple(t for t in tables if identifies(parent, t))
|
||||||
sugs = [Column(table_refs=filteredtables,
|
sugs = [
|
||||||
local_tables=stmt.local_tables),
|
Column(table_refs=filteredtables,
|
||||||
Table(schema=parent),
|
local_tables=stmt.local_tables),
|
||||||
View(schema=parent),
|
Table(schema=parent),
|
||||||
Function(schema=parent)]
|
View(schema=parent),
|
||||||
|
Function(schema=parent),
|
||||||
|
]
|
||||||
if filteredtables and _allow_join_condition(stmt.parsed):
|
if filteredtables and _allow_join_condition(stmt.parsed):
|
||||||
sugs.append(JoinCondition(table_refs=tables,
|
sugs.append(JoinCondition(table_refs=tables,
|
||||||
parent=filteredtables[-1]))
|
parent=filteredtables[-1]))
|
||||||
@ -417,38 +428,39 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
# Use table alias if there is one, otherwise the table name
|
# Use table alias if there is one, otherwise the table name
|
||||||
aliases = tuple(t.ref for t in tables)
|
aliases = tuple(t.ref for t in tables)
|
||||||
if _allow_join_condition(stmt.parsed):
|
if _allow_join_condition(stmt.parsed):
|
||||||
return (Alias(aliases=aliases), JoinCondition(
|
return (
|
||||||
table_refs=tables, parent=None))
|
Alias(aliases=aliases),
|
||||||
|
JoinCondition(table_refs=tables, parent=None),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return (Alias(aliases=aliases),)
|
return (Alias(aliases=aliases),)
|
||||||
|
|
||||||
elif token_v in ('c', 'use', 'database', 'template'):
|
elif token_v in ("c", "use", "database", "template"):
|
||||||
# "\c <db", "use <db>", "DROP DATABASE <db>",
|
# "\c <db", "use <db>", "DROP DATABASE <db>",
|
||||||
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
|
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
|
||||||
return (Database(),)
|
return (Database(),)
|
||||||
elif token_v == 'schema':
|
elif token_v == "schema":
|
||||||
# DROP SCHEMA schema_name, SET SCHEMA schema name
|
# DROP SCHEMA schema_name, SET SCHEMA schema name
|
||||||
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
|
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
|
||||||
quoted = prev_keyword and prev_keyword.value.lower() == 'set'
|
quoted = prev_keyword and prev_keyword.value.lower() == "set"
|
||||||
return (Schema(quoted),)
|
return (Schema(quoted),)
|
||||||
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
|
elif token_v.endswith(",") or token_v in ("=", "and", "or"):
|
||||||
prev_keyword = stmt.reduce_to_prev_keyword()
|
prev_keyword = stmt.reduce_to_prev_keyword()
|
||||||
if prev_keyword:
|
if prev_keyword:
|
||||||
return suggest_based_on_last_token(prev_keyword, stmt)
|
return suggest_based_on_last_token(prev_keyword, stmt)
|
||||||
else:
|
else:
|
||||||
return ()
|
return ()
|
||||||
elif token_v in ('type', '::'):
|
elif token_v in ("type", "::"):
|
||||||
# ALTER TABLE foo SET DATA TYPE bar
|
# ALTER TABLE foo SET DATA TYPE bar
|
||||||
# SELECT foo::bar
|
# SELECT foo::bar
|
||||||
# Note that tables are a form of composite type in postgresql, so
|
# Note that tables are a form of composite type in postgresql, so
|
||||||
# they're suggested here as well
|
# they're suggested here as well
|
||||||
schema = stmt.get_identifier_schema()
|
schema = stmt.get_identifier_schema()
|
||||||
suggestions = [Datatype(schema=schema),
|
suggestions = [Datatype(schema=schema), Table(schema=schema)]
|
||||||
Table(schema=schema)]
|
|
||||||
if not schema:
|
if not schema:
|
||||||
suggestions.append(Schema())
|
suggestions.append(Schema())
|
||||||
return tuple(suggestions)
|
return tuple(suggestions)
|
||||||
elif token_v in ['alter', 'create', 'drop']:
|
elif token_v in {"alter", "create", "drop"}:
|
||||||
return (Keyword(token_v.upper()),)
|
return (Keyword(token_v.upper()),)
|
||||||
elif token.is_keyword:
|
elif token.is_keyword:
|
||||||
# token is a keyword we haven't implemented any special handling for
|
# token is a keyword we haven't implemented any special handling for
|
||||||
@ -462,11 +474,38 @@ def suggest_based_on_last_token(token, stmt):
|
|||||||
return (Keyword(),)
|
return (Keyword(),)
|
||||||
|
|
||||||
|
|
||||||
|
def _suggest_expression(token_v, stmt):
|
||||||
|
"""
|
||||||
|
Return suggestions for an expression, taking account of any partially-typed
|
||||||
|
identifier's parent, which may be a table alias or schema name.
|
||||||
|
"""
|
||||||
|
parent = stmt.identifier.get_parent_name() if stmt.identifier else []
|
||||||
|
tables = stmt.get_tables()
|
||||||
|
|
||||||
|
if parent:
|
||||||
|
tables = tuple(t for t in tables if identifies(parent, t))
|
||||||
|
return (
|
||||||
|
Column(table_refs=tables, local_tables=stmt.local_tables),
|
||||||
|
Table(schema=parent),
|
||||||
|
View(schema=parent),
|
||||||
|
Function(schema=parent),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
Column(table_refs=tables, local_tables=stmt.local_tables,
|
||||||
|
qualifiable=True),
|
||||||
|
Function(schema=None),
|
||||||
|
Keyword(token_v.upper()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def identifies(id, ref):
|
def identifies(id, ref):
|
||||||
"""Returns true if string `id` matches TableReference `ref`"""
|
"""Returns true if string `id` matches TableReference `ref`"""
|
||||||
|
|
||||||
return id == ref.alias or id == ref.name or (
|
return (
|
||||||
ref.schema and (id == ref.schema + '.' + ref.name))
|
id == ref.alias or id == ref.name or
|
||||||
|
(ref.schema and (id == ref.schema + "." + ref.name))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _allow_join_condition(statement):
|
def _allow_join_condition(statement):
|
||||||
@ -486,7 +525,7 @@ def _allow_join_condition(statement):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
last_tok = statement.token_prev(len(statement.tokens))[1]
|
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||||
return last_tok.value.lower() in ('on', 'and', 'or')
|
return last_tok.value.lower() in ("on", "and", "or")
|
||||||
|
|
||||||
|
|
||||||
def _allow_join(statement):
|
def _allow_join(statement):
|
||||||
@ -505,7 +544,5 @@ def _allow_join(statement):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
last_tok = statement.token_prev(len(statement.tokens))[1]
|
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||||
return (
|
return last_tok.value.lower().endswith("join") and \
|
||||||
last_tok.value.lower().endswith('join') and
|
last_tok.value.lower() not in ("cross join", "natural join",)
|
||||||
last_tok.value.lower() not in ('cross join', 'natural join')
|
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user