mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2025-01-27 00:36:52 -06:00
1210 lines
47 KiB
Python
1210 lines
47 KiB
Python
##########################################################################
|
|
#
|
|
# pgAdmin 4 - PostgreSQL Tools
|
|
#
|
|
# Copyright (C) 2013 - 2020, The pgAdmin Development Team
|
|
# This software is released under the PostgreSQL Licence
|
|
#
|
|
##########################################################################
|
|
|
|
"""A blueprint module implementing the sql auto complete feature."""
|
|
|
|
import re
|
|
import operator
|
|
import sys
|
|
from itertools import count, repeat, chain
|
|
from .completion import Completion
|
|
from collections import namedtuple, defaultdict, OrderedDict
|
|
|
|
from .sqlcompletion import (
|
|
FromClauseItem, suggest_type, Database, Schema, Table,
|
|
Function, Column, View, Keyword, Datatype, Alias, JoinCondition, Join)
|
|
from .parseutils.meta import FunctionMetadata, ColumnMetadata, ForeignKey
|
|
from .parseutils.utils import last_word
|
|
from .parseutils.tables import TableReference
|
|
from .prioritization import PrevalenceCounter
|
|
from flask import render_template
|
|
from pgadmin.utils.driver import get_driver
|
|
from config import PG_DEFAULT_DRIVER
|
|
from pgadmin.utils.preferences import Preferences
|
|
|
|
# A completion paired with the priority tuple used to rank it.
Match = namedtuple('Match', ['completion', 'priority'])

# Record describing a schema-scoped object (table, view, function, type).
_SchemaObject = namedtuple('SchemaObject', 'name schema meta')


def SchemaObject(name, schema=None, meta=None):
    """Build a SchemaObject record, defaulting schema and meta to None."""
    return _SchemaObject(name=name, schema=schema, meta=meta)
|
|
|
|
|
|
# Regex for finding "words" in documents.
|
|
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
|
|
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
|
|
|
|
_Candidate = namedtuple(
|
|
'Candidate', 'completion prio meta synonyms prio2 display'
|
|
)
|
|
|
|
|
|
def Candidate(
|
|
completion, prio=None, meta=None, synonyms=None, prio2=None,
|
|
display=None
|
|
):
|
|
return _Candidate(
|
|
completion, prio, meta, synonyms or [completion], prio2,
|
|
display or completion
|
|
)
|
|
|
|
|
|
# Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$')


def normalize_ref(ref):
    """Normalize a table reference for comparison: already-quoted refs
    pass through; anything else is lower-cased and double-quoted."""
    if ref[0] == '"':
        return ref
    return '"' + ref.lower() + '"'
|
|
|
|
|
|
def generate_alias(tbl):
    """ Generate a table alias, consisting of all upper-case letters in
    the table name, or, if there are no upper-case letters, the first letter +
    all letters preceded by _
    param tbl - unescaped name of the table to alias
    """
    uppercase_letters = [ch for ch in tbl if ch.isupper()]
    if uppercase_letters:
        return ''.join(uppercase_letters)
    # Pair each character with its predecessor ('_' sentinel for the
    # first char) and keep word-initial characters.
    initials = [ch for ch, prev in zip(tbl, '_' + tbl)
                if prev == '_' and ch != '_']
    return ''.join(initials)
|
|
|
|
|
|
class SQLAutoComplete(object):
    """
    class SQLAutoComplete

    This class is used to provide the postgresql's autocomplete feature.
    This class used sqlparse to parse the given sql and psycopg2 to make
    the connection and get the tables, schemas, functions etc. based on
    the query.

    Metadata (schemas, tables, columns, functions, datatypes) is cached
    in ``self.dbmetadata`` and matched against the text before the cursor
    to produce ranked completion suggestions.
    """
|
|
def __init__(self, **kwargs):
    """
    Initialize the auto-complete engine for one server connection.

    Reads keywords, schema names and the search path from the server
    (when connected) and seeds the completion metadata caches.

    Args:
        **kwargs: 'sid' (server id) and 'conn' (connection object)
            are read; both default to None when absent.
    """
    self.sid = kwargs.get('sid')
    self.conn = kwargs.get('conn')
    self.keywords = []
    self.databases = []
    self.functions = []
    self.datatypes = []
    self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
                       'datatypes': {}}
    self.text_before_cursor = None
    # Pattern for identifiers that may appear unquoted; anything that
    # does not match (or is a reserved word) must be double-quoted.
    # Bug fix: use a raw string -- '\$' is an invalid escape sequence
    # in a regular string literal (SyntaxWarning / W605).
    self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")

    manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)

    # we will set template path for sql scripts
    self.sql_path = 'sqlautocomplete/sql/#{0}#'.format(manager.version)

    self.search_path = []
    schema_names = []
    if self.conn.connected():
        # Fetch the search path
        query = render_template(
            "/".join([self.sql_path, 'schema.sql']), search_path=True)
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                self.search_path.append(record['schema'])

        # Fetch the schema names
        query = render_template("/".join([self.sql_path, 'schema.sql']))
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                schema_names.append(record['schema'])

        pref = Preferences.module('sqleditor')
        keywords_in_uppercase = \
            pref.preference('keywords_in_uppercase').get()

        # Fetch the keywords
        query = render_template("/".join([self.sql_path, 'keywords.sql']))
        # If setting 'Keywords in uppercase' is set to True in
        # Preferences then fetch the keywords in upper case.
        if keywords_in_uppercase:
            query = render_template(
                "/".join([self.sql_path, 'keywords.sql']), upper_case=True)
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                # 'public' is a keyword in EPAS database server. Don't add
                # this into the list of keywords.
                # This is a hack to fix the issue in autocomplete.
                if record['word'].lower() == 'public':
                    continue
                self.keywords.append(record['word'])

    self.prioritizer = PrevalenceCounter(self.keywords)

    # Individual words of multi-word keywords are also reserved.
    self.reserved_words = set()
    for x in self.keywords:
        self.reserved_words.update(x.split())

    self.all_completions = set(self.keywords)
    self.extend_schemata(schema_names)

    # Below are the configurable options in pgcli which we don't have
    # in pgAdmin4 at the moment. Setting the default value from the pgcli's
    # config file.
    self.signature_arg_style = '{arg_name} {arg_type}'
    self.call_arg_style = '{arg_name: <{max_arg_len}} := {arg_default}'
    self.call_arg_display_style = '{arg_name}'
    self.call_arg_oneliner_max = 2
    self.search_path_filter = True
    self.generate_aliases = False
    # Columns whose defaults match these patterns are skipped when
    # expanding '*' in an INSERT column list.
    self.insert_col_skip_patterns = [
        re.compile(r'^now\(\)$'),
        re.compile(r'^nextval\(')]
    self.qualify_columns = 'if_more_than_one_table'
    self.asterisk_column_order = 'table_order'
|
|
|
|
def escape_name(self, name):
    """Double-quote name when it is not a valid bare identifier or
    collides with a reserved word; otherwise return it unchanged."""
    if not name:
        return name
    needs_quoting = (not self.name_pattern.match(name) or
                     name.upper() in self.reserved_words)
    if needs_quoting:
        return '"%s"' % name
    return name
|
|
|
|
def escape_schema(self, name):
    """Return the unquoted schema name wrapped in single quotes."""
    unescaped = self.unescape_name(name)
    return "'{}'".format(unescaped)
|
|
|
|
def unescape_name(self, name):
    """Strip one pair of surrounding double quotes, if present."""
    if name and name.startswith('"') and name.endswith('"'):
        return name[1:-1]
    return name
|
|
|
|
def escaped_names(self, names):
    """Apply escape_name to every name in names, preserving order."""
    return list(map(self.escape_name, names))
|
|
|
|
def extend_database_names(self, databases):
    """Append the given database names to the known list."""
    self.databases += databases
|
|
|
|
def extend_keywords(self, additional_keywords):
    """Register extra keywords as completion candidates."""
    for keyword in additional_keywords:
        self.keywords.append(keyword)
        self.all_completions.add(keyword)
|
|
|
|
def extend_schemata(self, schemata):
    """Register schema names under every metadata kind.

    :param schemata: list of schema names (escaped here)
    """
    schemata = self.escaped_names(schemata)

    # Create an empty object-dict for each schema under every metadata
    # kind ('tables', 'views', 'functions', 'datatypes').  The original
    # code additionally pre-initialized dbmetadata['tables'] in a
    # separate loop, which this loop repeats anyway -- the redundant
    # first pass has been removed.
    for metadata in self.dbmetadata.values():
        for schema in schemata:
            metadata[schema] = {}

    self.all_completions.update(schemata)
|
|
|
|
def extend_casing(self, words):
    """ extend casing data

    Builds {lowercase_name: PreferredCasingName} from words.
    :return:
    """
    self.casing = {word.lower(): word for word in words}
|
|
|
|
def extend_relations(self, data, kind):
    """extend metadata for tables or views.

    :param data: list of (schema_name, rel_name) tuples
    :param kind: either 'tables' or 'views'

    :return:

    """
    data = [self.escaped_names(d) for d in data]

    # dbmetadata['tables']['schema_name']['table_name'] should be an
    # OrderedDict {column_name:ColumnMetaData}.
    metadata = self.dbmetadata[kind]
    for schema, relname, in data:
        try:
            metadata[schema][relname] = OrderedDict()
        except KeyError:
            # Bug fix: the original passed the arguments to print()
            # logging-style ("print(fmt, a, b, c)"), so the %r
            # placeholders were never interpolated.  Interpolate them.
            print('%r %r listed in unrecognized schema %r' %
                  (kind, relname, schema))

        self.all_completions.add(relname)
|
|
|
|
def extend_columns(self, column_data, kind):
    """extend column metadata.

    :param column_data: list of (schema_name, rel_name, column_name,
        column_type, has_default, default) tuples
    :param kind: either 'tables' or 'views'

    :return:

    """
    metadata = self.dbmetadata[kind]
    for schema, relname, colname, datatype, \
            has_default, default in column_data:
        schema, relname, colname = self.escaped_names(
            [schema, relname, colname])
        # Store one ColumnMetadata per column, keyed by escaped name.
        metadata[schema][relname][colname] = ColumnMetadata(
            name=colname,
            datatype=datatype,
            has_default=has_default,
            default=default,
        )
        self.all_completions.add(colname)
|
|
|
|
def extend_functions(self, func_data):
    """Add function metadata and refresh the cached arg-list strings.

    func_data is a list of function metadata namedtuples;
    dbmetadata['functions'][schema][function_name] holds the list of
    metadata records for that name (one per overload).
    """
    metadata = self.dbmetadata['functions']

    for f in func_data:
        schema, func = self.escaped_names([f.schema_name, f.func_name])
        # Overloads accumulate under the same function name.
        metadata[schema].setdefault(func, []).append(f)
        self.all_completions.add(func)

    self._refresh_arg_list_cache()
|
|
|
|
def _refresh_arg_list_cache(self):
    # We keep a cache of
    # {function_usage:{function_metadata: function_arg_list_string}}
    # This is used when suggesting functions, to avoid the latency that
    # would result if we'd recalculate the arg lists each time we suggest
    # functions (in large DBs)
    self._arg_list_cache = {
        usage: {
            meta: self._arg_list(meta, usage)
            for funcs in self.dbmetadata['functions'].values()
            for metas in funcs.values()
            for meta in metas
        }
        for usage in ('call', 'call_display', 'signature')
    }
|
|
|
|
def extend_foreignkeys(self, fk_data):
    """Attach ForeignKey records to both ends of each foreign key.

    fk_data is a list of ForeignKey namedtuples, with fields
    parentschema, childschema, parenttable, childtable,
    parentcolumns, childcolumns.  Each FK is appended to the
    .foreignkeys list of both the child and parent ColumnMetadata.
    """
    meta = self.dbmetadata['tables']

    escape = self.escaped_names
    for fk in fk_data:
        parentschema, childschema = escape(
            [fk.parentschema, fk.childschema])
        parenttable, childtable = escape([fk.parenttable, fk.childtable])
        childcol, parcol = escape([fk.childcolumn, fk.parentcolumn])

        # Skip FKs whose tables/columns we have no metadata for.
        unknown = (childtable not in meta[childschema] or
                   parenttable not in meta[parentschema] or
                   childcol not in meta[childschema][childtable] or
                   parcol not in meta[parentschema][parenttable])
        if unknown:
            continue

        childcolmeta = meta[childschema][childtable][childcol]
        parcolmeta = meta[parentschema][parenttable][parcol]
        link = ForeignKey(
            parentschema, parenttable, parcol,
            childschema, childtable, childcol
        )
        childcolmeta.foreignkeys.append(link)
        parcolmeta.foreignkeys.append(link)
|
|
|
|
def extend_datatypes(self, type_data):
    """Register custom datatype names.

    dbmetadata['datatypes'][schema_name][type_name] stores type
    metadata (e.g. composite field names); nothing beyond the name is
    tracked yet, so the stored value is None.
    """
    meta = self.dbmetadata['datatypes']

    for entry in type_data:
        schema, type_name = self.escaped_names(entry)
        meta[schema][type_name] = None
        self.all_completions.add(type_name)
|
|
|
|
def set_search_path(self, search_path):
    """Replace the cached search path with its escaped names."""
    self.search_path = [self.escape_name(p) for p in search_path]
|
|
|
|
def reset_completions(self):
    """Drop all cached metadata; keep only keywords and functions."""
    self.databases = []
    self.special_commands = []
    self.search_path = []
    self.dbmetadata = {kind: {} for kind in
                       ('tables', 'views', 'functions', 'datatypes')}
    self.all_completions = set(self.keywords + self.functions)
|
|
|
|
def find_matches(self, text, collection, mode='fuzzy', meta=None):
    """Find completion matches for the given text.

    Given the user's input text and a collection of available
    completions, find completions matching the last word of the
    text.

    `collection` can be either a list of strings or a list of Candidate
    namedtuples.
    `mode` can be either 'fuzzy', or 'strict'
        'fuzzy': fuzzy matching, ties broken by name prevalance
        `keyword`: start only matching, ties broken by keyword prevalance

    yields prompt_toolkit Completion instances for any matches found
    in the collection of available completions.
    (Concretely, returns a list of Match namedtuples whose .priority
    tuple is used later for sorting.)

    Args:
        text:
        collection:
        mode:
        meta:
        meta_collection:
    """
    if not collection:
        return []
    # Object types earlier in this list rank lower; meta not listed
    # here gets -1 (lowest type priority).
    prio_order = [
        'keyword', 'function', 'view', 'table', 'datatype', 'database',
        'schema', 'column', 'table alias', 'join', 'name join', 'fk join',
        'table format'
    ]
    type_priority = prio_order.index(meta) if meta in prio_order else -1
    text = last_word(text, include='most_punctuations').lower()
    text_len = len(text)

    if text and text[0] == '"':
        # text starts with double quote; user is manually escaping a name
        # Match on everything that follows the double-quote. Note that
        # text_len is calculated before removing the quote, so the
        # Completion.position value is correct
        text = text[1:]

    if mode == 'fuzzy':
        fuzzy = True
        priority_func = self.prioritizer.name_count
    else:
        fuzzy = False
        priority_func = self.prioritizer.keyword_count

    # Construct a `_match` function for either fuzzy or non-fuzzy matching
    # The match function returns a 2-tuple used for sorting the matches,
    # or None if the item doesn't match
    # Note: higher priority values mean more important, so use negative
    # signs to flip the direction of the tuple
    if fuzzy:
        # Subsequence match: each typed character, in order, with
        # anything in between (non-greedy).
        regex = '.*?'.join(map(re.escape, text))
        pat = re.compile('(%s)' % regex)

        def _match(item):
            if item.lower()[:len(text) + 1] in (text, text + ' '):
                # Exact match of first word in suggestion
                # This is to get exact alias matches to the top
                # E.g. for input `e`, 'Entries E' should be on top
                # (before e.g. `EndUsers EU`)
                return float('Infinity'), -1
            r = pat.search(self.unescape_name(item.lower()))
            if r:
                # Shorter match group and earlier start rank higher.
                return -len(r.group()), -r.start()
    else:
        match_end_limit = len(text)

        def _match(item):
            # Text starts with double quote; Remove quoting and
            # match on everything that follows the double-quote.
            item = self.unescape_name(item.lower())
            match_point = item.find(text, 0, match_end_limit)
            if match_point >= 0:
                # Use negative infinity to force keywords to sort after all
                # fuzzy matches
                return -float('Infinity'), -match_point

    matches = []
    for cand in collection:
        if isinstance(cand, _Candidate):
            item, prio, display_meta, synonyms, prio2, display = cand
            if display_meta is None:
                display_meta = meta
            # A candidate matches if any of its synonyms match; keep
            # the best-scoring synonym.
            syn_matches = (_match(x) for x in synonyms)
            # Nones need to be removed to avoid max() crashing in Python 3
            syn_matches = [m for m in syn_matches if m]
            sort_key = max(syn_matches) if syn_matches else None
        else:
            # Plain string candidate: fill in defaults.
            item, display_meta, prio, prio2, display = \
                cand, meta, 0, 0, cand
            sort_key = _match(cand)

        if sort_key:
            if display_meta and len(display_meta) > 50:
                # Truncate meta-text to 50 characters, if necessary
                display_meta = display_meta[:47] + u'...'

            # Lexical order of items in the collection, used for
            # tiebreaking items with the same match group length and start
            # position. Since we use *higher* priority to mean "more
            # important," we use -ord(c) to prioritize "aa" > "ab" and end
            # with 1 to prioritize shorter strings (ie "user" > "users").
            # We first do a case-insensitive sort and then a
            # case-sensitive one as a tie breaker.
            # We also use the unescape_name to make sure quoted names have
            # the same priority as unquoted names.
            lexical_priority = (
                tuple(0 if c in(' _') else -ord(c)
                      for c in self.unescape_name(item.lower())) + (1,) +
                tuple(c for c in item)
            )

            priority = (
                sort_key, type_priority, prio, priority_func(item),
                prio2, lexical_priority
            )
            matches.append(
                Match(
                    completion=Completion(
                        text=item,
                        start_position=-text_len,
                        display_meta=display_meta,
                        display=display
                    ),
                    priority=priority
                )
            )
    return matches
|
|
|
|
def get_completions(self, text, text_before_cursor):
    """Return completions as {display_name: {'object_type': meta}}.

    Dispatches each suggestion produced by suggest_type() to the
    matching method in suggestion_matchers and merges the results,
    highest priority first.
    """
    self.text_before_cursor = text_before_cursor

    word_before_cursor = self.get_word_before_cursor(word=True)
    matches = []
    suggestions = suggest_type(text, text_before_cursor)

    for suggestion in suggestions:
        suggestion_type = type(suggestion)

        # Map suggestion type to method
        # e.g. 'table' -> self.get_table_matches
        matcher = self.suggestion_matchers[suggestion_type]
        matches.extend(matcher(self, suggestion, word_before_cursor))

    # Sort matches so highest priorities are first
    matches = sorted(matches, key=operator.attrgetter('priority'),
                     reverse=True)

    result = dict()
    for m in matches:
        name = m.completion.display
        # Duplicate display names collapse to one entry; later
        # (lower-priority) matches overwrite earlier ones.
        result[name] = {'object_type': m.completion.display_meta}

    return result
|
|
|
|
def get_column_matches(self, suggestion, word_before_cursor):
    """Suggest column completions for the tables in scope.

    Populates table/view metadata for the referenced schema first,
    then builds (optionally table-qualified) column candidates.  A
    lone '*' before the cursor expands to the full column list.
    """
    # Use the schema of the first table reference, when present.
    schema = None
    if len(suggestion.table_refs) > 0 and \
            hasattr(suggestion.table_refs[0], 'schema') and \
            suggestion.table_refs[0].schema != '':
        schema = suggestion.table_refs[0].schema

    # Tables and Views should be populated first.
    self.fetch_schema_objects(schema, 'tables')
    self.fetch_schema_objects(schema, 'views')

    tables = suggestion.table_refs
    # Whether to prefix columns with their table ref, per the
    # qualify_columns setting.
    do_qualify = suggestion.qualifiable and {
        'always': True, 'never': False,
        'if_more_than_one_table': len(tables) > 1}[self.qualify_columns]

    def qualify(col, tbl):
        return (tbl + '.' + col) if do_qualify else col

    scoped_cols = self.populate_scoped_cols(
        tables, suggestion.local_tables
    )

    def make_cand(name, ref):
        synonyms = (name, generate_alias(name))
        return Candidate(qualify(name, ref), 0, 'column', synonyms)

    def flat_cols():
        return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items()
                for c in cols]
    if suggestion.require_last_table:
        # require_last_table is used for 'tb11 JOIN tbl2 USING
        # (...' which should
        # suggest only columns that appear in the last table and one more
        ltbl = tables[-1].ref
        other_tbl_cols = set(
            c.name for t, cs in scoped_cols.items() if t.ref != ltbl
            for c in cs
        )
        scoped_cols = \
            dict((t, [col for col in cols if col.name in other_tbl_cols])
                 for t, cols in scoped_cols.items() if t.ref == ltbl)

    lastword = last_word(word_before_cursor, include='most_punctuations')
    if lastword == '*':
        if suggestion.context == 'insert':
            # NOTE(review): this local 'filter' shadows the builtin;
            # kept as-is since its scope ends with this branch.
            def filter(col):
                # Keep columns without defaults, and columns whose
                # default doesn't match the configured skip patterns
                # (e.g. now(), nextval(...)).
                if not col.has_default:
                    return True
                return not any(
                    p.match(col.default)
                    for p in self.insert_col_skip_patterns
                )
            scoped_cols = \
                dict((t, [col for col in cols if filter(col)])
                     for t, cols in scoped_cols.items())
        if self.asterisk_column_order == 'alphabetic':
            for cols in scoped_cols.values():
                cols.sort(key=operator.attrgetter('name'))
        if (
            lastword != word_before_cursor and
            len(tables) == 1 and
            word_before_cursor[-len(lastword) - 1] == '.'
        ):
            # User typed x.*; replicate "x." for all columns except the
            # first, which gets the original (as we only replace the "*"")
            sep = ', ' + word_before_cursor[:-1]
            collist = sep.join(c.completion for c in flat_cols())
        else:
            collist = ', '.join(qualify(c.name, t.ref)
                                for t, cs in scoped_cols.items()
                                for c in cs)

        # Single synthetic match replacing the '*' with the full list.
        return [Match(
            completion=Completion(
                collist,
                -1,
                display_meta='columns',
                display='*'
            ),
            priority=(1, 1, 1)
        )]

    return self.find_matches(word_before_cursor, flat_cols(),
                             mode='strict', meta='column')
|
|
|
|
def alias(self, tbl, tbls):
    """ Generate a unique table alias
    tbl - name of the table to alias, quoted if it needs to be
    tbls - TableReference iterable of tables already in query
    """
    taken = set(normalize_ref(t.ref) for t in tbls)
    if self.generate_aliases:
        tbl = generate_alias(self.unescape_name(tbl))
    if normalize_ref(tbl) not in taken:
        return tbl
    # Collision: append 2, 3, ... inside the quotes when quoted,
    # after the name otherwise, until the alias is unique.
    if tbl[0] == '"':
        numbered = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
    else:
        numbered = (tbl + str(i) for i in count(2))
    return next(a for a in numbered if normalize_ref(a) not in taken)
|
|
|
|
def get_join_matches(self, suggestion, word_before_cursor):
    """Suggest complete JOIN clauses derived from foreign keys of the
    tables already in scope."""
    tbls = suggestion.table_refs
    cols = self.populate_scoped_cols(tbls)
    # Set up some data structures for efficient access
    qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
    # Tables appearing later in the query get higher numbers.
    ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
    refs = set(normalize_ref(t.ref) for t in tbls)
    other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
    joins = []
    # Iterate over FKs in existing tables to find potential joins
    fks = (
        (fk, rtbl, rcol) for rtbl, rcols in cols.items()
        for rcol in rcols for fk in rcol.foreignkeys
    )
    col = namedtuple('col', 'schema tbl col')
    for fk, rtbl, rcol in fks:
        right = col(rtbl.schema, rtbl.name, rcol.name)
        child = col(fk.childschema, fk.childtable, fk.childcolumn)
        parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
        # Join against whichever FK end is NOT the current table.
        left = child if parent == right else parent
        if suggestion.schema and left.schema != suggestion.schema:
            continue
        if self.generate_aliases or normalize_ref(left.tbl) in refs:
            # Alias needed (either configured, or name already used).
            lref = self.alias(left.tbl, suggestion.table_refs)
            join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
                left.tbl, left.col, rtbl.ref, right.col, lref)
        else:
            join = '{0} ON {0}.{1} = {2}.{3}'.format(
                left.tbl, left.col, rtbl.ref, right.col)
        alias = generate_alias(left.tbl)
        # The aliased spelling is a synonym so typing the alias matches.
        synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format(
            alias, left.col, rtbl.ref, right.col)]
        # Schema-qualify if (1) new table in same schema as old, and old
        # is schema-qualified, or (2) new in other schema, except public
        if not suggestion.schema and \
                (qualified[normalize_ref(rtbl.ref)] and
                 left.schema == right.schema or
                 left.schema not in(right.schema, 'public')):
            join = left.schema + '.' + join
        prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
            0 if (left.schema, left.tbl) in other_tbls else 1)
        joins.append(Candidate(join, prio, 'join', synonyms=synonyms))

    return self.find_matches(word_before_cursor, joins,
                             mode='strict', meta='join')
|
|
|
|
def get_join_condition_matches(self, suggestion, word_before_cursor):
    """Suggest ON-clause join conditions for the last/parent table,
    from foreign keys first and then from column name+type matches."""
    col = namedtuple('col', 'schema tbl col')
    tbls = self.populate_scoped_cols(suggestion.table_refs).items
    cols = [(t, c) for t, cs in tbls() for c in cs]
    try:
        # The table being joined: explicit parent or the last ref.
        lref = (suggestion.parent or suggestion.table_refs[-1]).ref
        ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
    except IndexError:  # The user typed an incorrect table qualifier
        return []
    conds, found_conds = [], set()

    def add_cond(lcol, rcol, rref, prio, meta):
        # Record each condition once; weight it by table proximity
        # via ref_prio (bound later, before any call).
        prefix = '' if suggestion.parent else ltbl.ref + '.'
        cond = prefix + lcol + ' = ' + rref + '.' + rcol
        if cond not in found_conds:
            found_conds.add(cond)
            conds.append(Candidate(cond, prio + ref_prio[rref], meta))

    def list_dict(pairs):  # Turns [(a, b), (a, c)] into {a: [b, c]}
        d = defaultdict(list)
        for pair in pairs:
            d[pair[0]].append(pair[1])
        return d

    # Tables that are closer to the cursor get higher prio
    ref_prio = dict((tbl.ref, num)
                    for num, tbl in enumerate(suggestion.table_refs))
    # Map (schema, table, col) to tables
    coldict = list_dict(
        ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
    )
    # For each fk from the left table, generate a join condition if
    # the other table is also in the scope
    fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
    for fk, lcol in fks:
        left = col(ltbl.schema, ltbl.name, lcol)
        child = col(fk.childschema, fk.childtable, fk.childcolumn)
        par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
        # Orient the pair so 'left' is always our table's end.
        left, right = (child, par) if left == child else (par, child)
        for rtbl in coldict[right]:
            add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join')
    # For name matching, use a {(colname, coltype): TableReference} dict
    coltyp = namedtuple('coltyp', 'name datatype')
    col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
    # Find all name-match join conditions
    for c in (coltyp(c.name, c.datatype) for c in lcols):
        for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
            # Integer-typed columns are the most likely join keys.
            prio = 1000 if c.datatype in (
                'integer', 'bigint', 'smallint') else 0
            add_cond(c.name, c.name, rtbl.ref, prio, 'name join')

    return self.find_matches(word_before_cursor, conds,
                             mode='strict', meta='join')
|
|
|
|
def get_function_matches(self, suggestion, word_before_cursor,
                         alias=False):
    """Suggest function completions, restricting to FROM-safe
    functions (non-aggregate, non-window) when used in a FROM clause."""
    if suggestion.usage == 'from':
        # Only suggest functions allowed in FROM clause
        def filt(f):
            return not f.is_aggregate and not f.is_window
    else:
        alias = False

        def filt(f):
            return True

    # 'signature' keeps the signature arg list, 'special' gets none,
    # anything else gets the call-style arg list.
    arg_mode = {'signature': 'signature', 'special': None}.get(
        suggestion.usage, 'call')
    # Function overloading means we way have multiple functions of the same
    # name at this point, so keep unique names only
    funcs = {
        self._make_cand(f, alias, suggestion, arg_mode)
        for f in self.populate_functions(suggestion.schema, filt)
    }

    return self.find_matches(word_before_cursor, funcs,
                             mode='strict', meta='function')
|
|
|
|
def get_schema_matches(self, suggestion, word_before_cursor):
    """Suggest schema names, hiding pg_* schemas unless typed."""
    schema_names = self.dbmetadata['tables'].keys()

    # Unless we're sure the user really wants them, hide schema names
    # starting with pg_, which are mostly temporary schemas
    if not word_before_cursor.startswith('pg_'):
        schema_names = [name for name in schema_names
                        if not name.startswith('pg_')]

    if suggestion.quoted:
        schema_names = [self.escape_schema(name) for name in schema_names]

    return self.find_matches(word_before_cursor, schema_names,
                             mode='strict', meta='schema')
|
|
|
|
def get_from_clause_item_matches(self, suggestion, word_before_cursor):
    """Combine table, view and function suggestions for a FROM clause."""
    alias = self.generate_aliases
    s = suggestion
    matches = []
    matches.extend(self.get_table_matches(
        Table(s.schema, s.table_refs, s.local_tables),
        word_before_cursor, alias))
    matches.extend(self.get_view_matches(
        View(s.schema, s.table_refs), word_before_cursor, alias))
    matches.extend(self.get_function_matches(
        Function(s.schema, s.table_refs, usage='from'),
        word_before_cursor, alias))
    return matches
|
|
|
|
def _arg_list(self, func, usage):
    """Returns a an arg list string, e.g. `(_foo:=23)` for a func.

    :param func is a FunctionMetadata object
    :param usage is 'call', 'call_display' or 'signature'

    """
    template = {
        'call': self.call_arg_style,
        'call_display': self.call_arg_display_style,
        'signature': self.signature_arg_style
    }[usage]
    args = func.args()
    # Fall back to bare parens when there is no template, or for call
    # usage with fewer than two args or a variadic signature.
    if not template or (
            usage == 'call' and (len(args) < 2 or func.has_variadic())):
        return '()'
    multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max
    max_arg_len = max(len(a.name) for a in args) if multiline else 0
    rendered = (
        self._format_arg(template, arg, position + 1, max_arg_len)
        for position, arg in enumerate(args)
    )
    if multiline:
        return '(' + ','.join('\n ' + a for a in rendered if a) + '\n)'
    return '(' + ', '.join(a for a in rendered if a) + ')'
|
|
|
|
def _format_arg(self, template, arg, arg_num, max_arg_len):
    """Render one function argument through template, or None when the
    template is empty."""
    if not template:
        return None
    arg_default = ''
    if arg.has_default:
        arg_default = 'NULL' if arg.default is None else arg.default
        # Remove trailing ::(schema.)type
        arg_default = arg_default_type_strip_regex.sub('', arg_default)
    return template.format(
        max_arg_len=max_arg_len,
        arg_name=arg.name,
        arg_num=arg_num,
        arg_type=arg.datatype,
        arg_default=arg_default
    )
|
|
|
|
def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
    """Returns a Candidate namedtuple.

    :param tbl is a SchemaObject
    :param arg_mode determines what type of arg list to suffix for
        functions.
        Possible values: call, signature

    """
    cased_tbl = tbl.name
    if do_alias:
        alias = self.alias(cased_tbl, suggestion.table_refs)
    synonyms = (cased_tbl, generate_alias(cased_tbl))
    maybe_alias = (' ' + alias) if do_alias else ''
    maybe_schema = (tbl.schema + '.') if tbl.schema else ''
    # Functions get their cached arg-list suffix (see
    # _refresh_arg_list_cache); other objects get none.
    suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ''
    if arg_mode == 'call':
        display_suffix = self._arg_list_cache['call_display'][tbl.meta]
    elif arg_mode == 'signature':
        display_suffix = self._arg_list_cache['signature'][tbl.meta]
    else:
        display_suffix = ''
    item = maybe_schema + cased_tbl + suffix + maybe_alias
    display = maybe_schema + cased_tbl + display_suffix + maybe_alias
    # Schema-qualified names rank ahead of unqualified ones (prio2).
    prio2 = 0 if tbl.schema else 1
    return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
|
|
|
|
def get_table_matches(self, suggestion, word_before_cursor, alias=False):
    """Suggest table completions, including CTEs/local tables."""
    tables = self.populate_schema_objects(suggestion.schema, 'tables')
    tables.extend(SchemaObject(tbl.name)
                  for tbl in suggestion.local_tables)

    # Unless we're sure the user really wants them, don't suggest the
    # pg_catalog tables that are implicitly on the search path
    hide_pg = (not suggestion.schema and
               not word_before_cursor.startswith('pg_'))
    if hide_pg:
        tables = [t for t in tables if not t.name.startswith('pg_')]
    candidates = [self._make_cand(t, alias, suggestion) for t in tables]
    return self.find_matches(word_before_cursor, candidates,
                             mode='strict', meta='table')
|
|
|
|
def get_view_matches(self, suggestion, word_before_cursor, alias=False):
    """Suggest view completions, hiding pg_* views unless typed."""
    views = self.populate_schema_objects(suggestion.schema, 'views')

    hide_pg = (not suggestion.schema and
               not word_before_cursor.startswith('pg_'))
    if hide_pg:
        views = [v for v in views if not v.name.startswith('pg_')]
    candidates = [self._make_cand(v, alias, suggestion) for v in views]
    return self.find_matches(word_before_cursor, candidates,
                             mode='strict', meta='view')
|
|
|
|
def get_alias_matches(self, suggestion, word_before_cursor):
    """Suggest table aliases already present in the query."""
    return self.find_matches(word_before_cursor, suggestion.aliases,
                             mode='strict', meta='table alias')
|
|
|
|
def get_database_matches(self, _, word_before_cursor):
    """Suggest known database names (the suggestion arg is unused)."""
    databases = self.databases
    return self.find_matches(word_before_cursor, databases,
                             mode='strict', meta='database')
|
|
|
|
def get_keyword_matches(self, suggestion, word_before_cursor):
    """Suggest SQL keywords fetched from the server."""
    keywords = self.keywords
    return self.find_matches(word_before_cursor, keywords,
                             mode='strict', meta='keyword')
|
|
|
|
def get_datatype_matches(self, suggestion, word_before_cursor):
    """Suggest custom datatypes from the schema metadata."""
    types = self.populate_schema_objects(suggestion.schema, 'datatypes')
    candidates = [self._make_cand(t, False, suggestion) for t in types]
    return self.find_matches(word_before_cursor, candidates,
                             mode='strict', meta='datatype')
|
|
|
|
def get_word_before_cursor(self, word=False):
    """
    Return the word immediately preceding the cursor.

    An empty string is returned when the character just before the
    cursor is whitespace.

    Args:
        word: forwarded to find_start_of_previous_word; when True a
            "big word" (any run of non-whitespace) is considered.
    """
    if self.text_before_cursor[-1:].isspace():
        return ''
    start = self.find_start_of_previous_word(word=word)
    return self.text_before_cursor[start:]
def find_start_of_previous_word(self, count=1, word=False):
    """
    Return a negative index, relative to the cursor position, pointing
    to the start of the count-th word before the cursor. Return `None`
    if fewer than `count` words precede the cursor.

    Args:
        count: how many words to skip backwards (1 = nearest word).
        word: when True, match "big words" (runs of non-whitespace);
            otherwise match identifier or symbol runs.
    """

    # Reverse the text before the cursor, in order to do an efficient
    # backwards search.
    text_before_cursor = self.text_before_cursor[::-1]

    regex = _FIND_BIG_WORD_RE if word else _FIND_WORD_RE

    # NOTE: the original wrapped this loop in try/except StopIteration,
    # but a `for` loop consumes the iterator's StopIteration itself, so
    # that handler was unreachable dead code and has been removed.
    for i, match in enumerate(regex.finditer(text_before_cursor)):
        if i + 1 == count:
            # end(1) in the reversed string equals the distance from
            # the cursor back to the word start, hence the negation.
            return -match.end(1)
    return None
# Dispatch table: maps each suggestion type produced by suggest_type()
# to the method that generates completion matches for it.
suggestion_matchers = {
    FromClauseItem: get_from_clause_item_matches,
    JoinCondition: get_join_condition_matches,
    Join: get_join_matches,
    Column: get_column_matches,
    Function: get_function_matches,
    Schema: get_schema_matches,
    Table: get_table_matches,
    View: get_view_matches,
    Alias: get_alias_matches,
    Database: get_database_matches,
    Keyword: get_keyword_matches,
    Datatype: get_datatype_matches,
}
def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
    """Find all columns in a set of scoped_tables.

    :param scoped_tbls: list of TableReference namedtuples
    :param local_tbls: tuple(TableMetadata)
    :return: {TableReference:{colname:ColumnMetaData}}

    """
    ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
    columns = OrderedDict()
    meta = self.dbmetadata

    def addcols(schema, rel, alias, reltype, cols):
        # Record `cols` under a TableReference key; only functions are
        # flagged as such on the reference.
        tbl = TableReference(schema, rel, alias, reltype == 'functions')
        if tbl not in columns:
            columns[tbl] = []
        columns[tbl].extend(cols)

    for tbl in scoped_tbls:
        # Local tables should shadow database tables
        if tbl.schema is None and normalize_ref(tbl.name) in ctes:
            cols = ctes[normalize_ref(tbl.name)]
            # BUG FIX: the alias and reltype arguments were swapped
            # here ('CTE' was passed as the alias and tbl.alias as the
            # reltype), producing a TableReference whose alias was the
            # literal string 'CTE' instead of the parsed alias — unlike
            # the two addcols() calls below, which pass tbl.alias third.
            addcols(None, tbl.name, tbl.alias, 'CTE', cols)
            continue
        schemas = [tbl.schema] if tbl.schema else self.search_path
        for schema in schemas:
            relname = self.escape_name(tbl.name)
            schema = self.escape_name(schema)
            if tbl.is_function:
                # Return column names from a set-returning function
                # Get an array of FunctionMetadata objects
                functions = meta['functions'].get(schema, {}).get(relname)
                for func in (functions or []):
                    # func is a FunctionMetadata object
                    cols = func.fields()
                    addcols(schema, relname, tbl.alias, 'functions', cols)
            else:
                for reltype in ('tables', 'views'):
                    cols = meta[reltype].get(schema, {}).get(relname)
                    if cols:
                        cols = cols.values()
                        addcols(schema, relname, tbl.alias, reltype, cols)
                        break

    return columns
def _get_schemas(self, obj_typ, schema):
    """Return the list of schemas from which to suggest objects.

    :param schema: the schema qualification typed by the user, if any
    """
    metadata = self.dbmetadata[obj_typ]
    if schema:
        escaped = self.escape_name(schema)
        return [escaped] if escaped in metadata else []
    if self.search_path_filter:
        return self.search_path
    return metadata.keys()
def _maybe_schema(self, schema, parent):
    """Return the schema qualifier to display, or None when omissible.

    The qualifier is dropped if the object already carries a parent
    qualification, or if the schema lies on the search path.
    """
    if parent or schema in self.search_path:
        return None
    return schema
def populate_schema_objects(self, schema, obj_type):
    """Return SchemaObjects for the tables/views of the given schema.

    :param schema: the schema qualification typed by the user, if any
    """
    # Refresh the metadata cache before reading from it
    self.fetch_schema_objects(schema, obj_type)

    objects = []
    for sch in self._get_schemas(obj_type, schema):
        display_schema = self._maybe_schema(schema=sch, parent=schema)
        for obj in self.dbmetadata[obj_type][sch].keys():
            objects.append(SchemaObject(name=obj, schema=display_schema))
    return objects
def populate_functions(self, schema, filter_func):
    """Return a list of function SchemaObjects.

    :param filter_func: a callable taking a FunctionMetadata namedtuple
        and returning True to keep that function, False to discard it
    """

    # Refresh the function metadata cache first
    self.fetch_functions(schema)

    # Because of multiple dispatch there can be several metadata
    # entries per function name, hence the inner iteration over metas.
    funcs = []
    for sch in self._get_schemas('functions', schema):
        display_schema = self._maybe_schema(schema=sch, parent=schema)
        for func, metas in self.dbmetadata['functions'][sch].items():
            funcs.extend(
                SchemaObject(name=func, schema=display_schema, meta=m)
                for m in metas if filter_func(m)
            )
    return funcs
def fetch_schema_objects(self, schema, obj_type):
    """
    Fetch schema objects (tables, views or datatypes) from the server
    and merge them into the completer's metadata cache.

    :param schema: schema qualification typed by the user, if any
    :param obj_type: one of 'tables', 'views', 'datatypes'
    :return: None (updates metadata via the extend_* helpers)
    """
    # Build the quoted, comma-separated schema list for the SQL IN
    # clause, from the user's qualifier or the server search_path.
    # NOTE(review): names are interpolated as SQL literals; embedded
    # single quotes are assumed escaped upstream — verify.
    if schema:
        in_clause = '\'' + schema + '\''
    else:
        in_clause = ','.join('\'' + r + '\'' for r in self.search_path)

    if obj_type in ('tables', 'views'):
        # Same template serves both relations; only object_name varies
        # (the original rendered two identical branches).
        query = render_template("/".join([self.sql_path, 'tableview.sql']),
                                schema_names=in_clause,
                                object_name=obj_type)
    elif obj_type == 'datatypes':
        query = render_template("/".join([self.sql_path, 'datatypes.sql']),
                                schema_names=in_clause)
    else:
        # Unknown object type: nothing to fetch. (Previously an empty
        # query string would have been sent to the server.)
        return

    data = []
    if self.conn.connected():
        status, res = self.conn.execute_dict(query)
        if status:
            data = [(record['schema_name'], record['object_name'])
                    for record in res['rows']]

    if not data:
        return

    if obj_type in ('tables', 'views'):
        self.extend_relations(data, obj_type)
        self.extend_columns(
            self.fetch_columns(in_clause, obj_type), obj_type
        )
        if obj_type == 'tables':
            self.extend_foreignkeys(
                self.fetch_foreign_keys(in_clause, obj_type)
            )
    elif obj_type == 'datatypes':
        self.extend_datatypes(data)
def fetch_functions(self, schema):
    """
    Fetch the list of functions for the given schema (or for the
    search path) and merge them into the metadata cache.

    :param schema: schema qualification typed by the user, if any
    :return: None (updates metadata via extend_functions)
    """

    def _as_list(arr):
        # The server returns array columns as '{a,b,c}' strings;
        # None is passed through unchanged.
        if arr is None:
            return arr
        return arr.strip('{}').split(',')

    if schema:
        in_clause = '\'' + schema + '\''
    else:
        in_clause = ','.join('\'' + r + '\'' for r in self.search_path)

    query = render_template("/".join([self.sql_path, 'functions.sql']),
                            schema_names=in_clause)

    data = []
    if self.conn.connected():
        status, res = self.conn.execute_dict(query)
        if status:
            for row in res['rows']:
                data.append(FunctionMetadata(
                    row['schema_name'],
                    row['func_name'],
                    _as_list(row['arg_names']),
                    _as_list(row['arg_types']),
                    _as_list(row['arg_modes']),
                    row['return_type'],
                    row['is_aggregate'],
                    row['is_window'],
                    row['is_set_returning'],
                    _as_list(row['arg_defaults'])
                ))

    if len(data) > 0:
        self.extend_functions(data)
def fetch_columns(self, schemas, obj_type):
    """
    Fetch column metadata for the tables or views of the given schemas.

    :param schemas: quoted, comma-separated schema names (SQL IN list)
    :param obj_type: 'tables' or 'views'
    :return: list of (schema, table, column, type, has_default, default)
    """
    object_name = 'view' if obj_type == 'views' else 'table'
    query = render_template("/".join([self.sql_path, 'columns.sql']),
                            schema_names=schemas,
                            object_name=object_name)

    data = []
    if self.conn.connected():
        status, res = self.conn.execute_dict(query)
        if status:
            data = [
                (row['schema_name'], row['table_name'],
                 row['column_name'], row['type_name'],
                 row['has_default'], row['default'])
                for row in res['rows']
            ]
    return data
def fetch_foreign_keys(self, schemas, obj_type):
    """
    Fetch the foreign-key relationships for tables in the given schemas.

    :param schemas: quoted, comma-separated schema names (SQL IN list)
    :param obj_type: unused; kept for call-site symmetry with
        fetch_columns
    :return: list of ForeignKey namedtuples
    """
    query = render_template("/".join([self.sql_path, 'foreign_keys.sql']),
                            schema_names=schemas)

    fkeys = []
    if self.conn.connected():
        status, res = self.conn.execute_dict(query)
        if status:
            fkeys = [
                ForeignKey(
                    row['parentschema'], row['parenttable'],
                    row['parentcolumn'], row['childschema'],
                    row['childtable'], row['childcolumn']
                )
                for row in res['rows']
            ]
    return fkeys