Update pgcli to latest release 3.4.1. Fixes #7411

This commit is contained in:
Akshay Joshi 2022-06-02 17:29:58 +05:30
parent 7066841467
commit 4a17ad312f
9 changed files with 185 additions and 94 deletions

View File

@ -11,6 +11,7 @@ notes for it.
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
release_notes_6_11
release_notes_6_10 release_notes_6_10
release_notes_6_9 release_notes_6_9
release_notes_6_8 release_notes_6_8

View File

@ -0,0 +1,20 @@
************
Version 6.11
************
Release date: 2022-06-30
This release contains a number of bug fixes and new features since the release of pgAdmin 4 v6.10.
New features
************
Housekeeping
************
| `Issue #7411 <https://redmine.postgresql.org/issues/7411>`_ - Update pgcli to latest release 3.4.1.
Bug fixes
*********

View File

@ -16,8 +16,20 @@ from .completion import Completion
from collections import namedtuple, defaultdict, OrderedDict from collections import namedtuple, defaultdict, OrderedDict
from .sqlcompletion import ( from .sqlcompletion import (
FromClauseItem, suggest_type, Database, Schema, Table, FromClauseItem,
Function, Column, View, Keyword, Datatype, Alias, JoinCondition, Join) suggest_type,
Database,
Schema,
Table,
Function,
Column,
View,
Keyword,
Datatype,
Alias,
JoinCondition,
Join
)
from .parseutils.meta import FunctionMetadata, ColumnMetadata, ForeignKey from .parseutils.meta import FunctionMetadata, ColumnMetadata, ForeignKey
from .parseutils.utils import last_word from .parseutils.utils import last_word
from .parseutils.tables import TableReference from .parseutils.tables import TableReference
@ -228,7 +240,7 @@ class SQLAutoComplete(object):
:return: :return:
""" """
# casing should be a dict {lowercasename:PreferredCasingName} # casing should be a dict {lowercasename:PreferredCasingName}
self.casing = dict((word.lower(), word) for word in words) self.casing = {word.lower(): word for word in words}
def extend_relations(self, data, kind): def extend_relations(self, data, kind):
"""extend metadata for tables or views. """extend metadata for tables or views.
@ -305,13 +317,15 @@ class SQLAutoComplete(object):
# would result if we'd recalculate the arg lists each time we suggest # would result if we'd recalculate the arg lists each time we suggest
# functions (in large DBs) # functions (in large DBs)
self._arg_list_cache = \ self._arg_list_cache = {
dict((usage, usage: {
dict((meta, self._arg_list(meta, usage)) meta: self._arg_list(meta, usage)
for sch, funcs in self.dbmetadata["functions"].items() for sch, funcs in self.dbmetadata["functions"].items()
for func, metas in funcs.items() for func, metas in funcs.items()
for meta in metas)) for meta in metas
for usage in ("call", "call_display", "signature")) }
for usage in ("call", "call_display", "signature")
}
def extend_foreignkeys(self, fk_data): def extend_foreignkeys(self, fk_data):
@ -341,8 +355,8 @@ class SQLAutoComplete(object):
parentschema, parenttable, parcol, parentschema, parenttable, parcol,
childschema, childtable, childcol childschema, childtable, childcol
) )
childcolmeta.foreignkeys.append((fk)) childcolmeta.foreignkeys.append(fk)
parcolmeta.foreignkeys.append((fk)) parcolmeta.foreignkeys.append(fk)
def extend_datatypes(self, type_data): def extend_datatypes(self, type_data):
@ -383,11 +397,6 @@ class SQLAutoComplete(object):
yields prompt_toolkit Completion instances for any matches found yields prompt_toolkit Completion instances for any matches found
in the collection of available completions. in the collection of available completions.
Args:
text:
collection:
mode:
meta:
""" """
if not collection: if not collection:
return [] return []
@ -487,8 +496,11 @@ class SQLAutoComplete(object):
# We also use the unescape_name to make sure quoted names have # We also use the unescape_name to make sure quoted names have
# the same priority as unquoted names. # the same priority as unquoted names.
lexical_priority = ( lexical_priority = (
tuple(0 if c in (" _") else -ord(c) tuple(
for c in self.unescape_name(item.lower())) + (1,) + 0 if c in " _" else -ord(c)
for c in self.unescape_name(item.lower())
) +
(1,) +
tuple(c for c in item) tuple(c for c in item)
) )
@ -551,11 +563,14 @@ class SQLAutoComplete(object):
self.fetch_schema_objects(schema, 'views') self.fetch_schema_objects(schema, 'views')
tables = suggestion.table_refs tables = suggestion.table_refs
do_qualify = suggestion.qualifiable and { do_qualify = (
"always": True, suggestion.qualifiable and
"never": False, {
"if_more_than_one_table": len(tables) > 1, "always": True,
}[self.qualify_columns] "never": False,
"if_more_than_one_table": len(tables) > 1,
}[self.qualify_columns]
)
def qualify(col, tbl): def qualify(col, tbl):
return (tbl + '.' + col) if do_qualify else col return (tbl + '.' + col) if do_qualify else col
@ -579,13 +594,15 @@ class SQLAutoComplete(object):
# (...' which should # (...' which should
# suggest only columns that appear in the last table and one more # suggest only columns that appear in the last table and one more
ltbl = tables[-1].ref ltbl = tables[-1].ref
other_tbl_cols = set( other_tbl_cols = {
c.name for t, cs in scoped_cols.items() if t.ref != ltbl c.name for t, cs in scoped_cols.items()
for c in cs if t.ref != ltbl for c in cs
) }
scoped_cols = \ scoped_cols = {
dict((t, [col for col in cols if col.name in other_tbl_cols]) t: [col for col in cols if col.name in other_tbl_cols]
for t, cols in scoped_cols.items() if t.ref == ltbl) for t, cols in scoped_cols.items()
if t.ref == ltbl
}
lastword = last_word(word_before_cursor, include="most_punctuations") lastword = last_word(word_before_cursor, include="most_punctuations")
if lastword == "*": if lastword == "*":
@ -598,9 +615,11 @@ class SQLAutoComplete(object):
p.match(col.default) p.match(col.default)
for p in self.insert_col_skip_patterns for p in self.insert_col_skip_patterns
) )
scoped_cols = \
dict((t, [col for col in cols if filter(col)]) scoped_cols = {
for t, cols in scoped_cols.items()) t: [col for col in cols if filter(col)]
for t, cols in scoped_cols.items()
}
if self.asterisk_column_order == "alphabetic": if self.asterisk_column_order == "alphabetic":
for cols in scoped_cols.values(): for cols in scoped_cols.values():
cols.sort(key=operator.attrgetter("name")) cols.sort(key=operator.attrgetter("name"))
@ -650,10 +669,10 @@ class SQLAutoComplete(object):
tbls = suggestion.table_refs tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls) cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access # Set up some data structures for efficient access
qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) qualified = {normalize_ref(t.ref): t.schema for t in tbls}
ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
refs = set(normalize_ref(t.ref) for t in tbls) refs = {normalize_ref(t.ref) for t in tbls}
other_tbls = set((t.schema, t.name) for t in list(cols)[:-1]) other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
joins = [] joins = []
# Iterate over FKs in existing tables to find potential joins # Iterate over FKs in existing tables to find potential joins
fks = ( fks = (
@ -689,10 +708,11 @@ class SQLAutoComplete(object):
] ]
# Schema-qualify if (1) new table in same schema as old, and old # Schema-qualify if (1) new table in same schema as old, and old
# is schema-qualified, or (2) new in other schema, except public # is schema-qualified, or (2) new in other schema, except public
if not suggestion.schema and \ if not suggestion.schema and (
(qualified[normalize_ref(rtbl.ref)] and qualified[normalize_ref(rtbl.ref)] and
left.schema == right.schema or left.schema == right.schema or
left.schema not in (right.schema, "public")): left.schema not in (right.schema, "public")
):
join = left.schema + "." + join join = left.schema + "." + join
prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
0 if (left.schema, left.tbl) in other_tbls else 1 0 if (left.schema, left.tbl) in other_tbls else 1
@ -726,8 +746,8 @@ class SQLAutoComplete(object):
return d return d
# Tables that are closer to the cursor get higher prio # Tables that are closer to the cursor get higher prio
ref_prio = dict((tbl.ref, num) ref_prio = \
for num, tbl in enumerate(suggestion.table_refs)) {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
# Map (schema, table, col) to tables # Map (schema, table, col) to tables
coldict = list_dict( coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
@ -761,9 +781,14 @@ class SQLAutoComplete(object):
def filt(f): def filt(f):
return ( return (
not f.is_aggregate and not f.is_window and not f.is_aggregate and
not f.is_window and
not f.is_extension and not f.is_extension and
(f.is_public or f.schema_name == suggestion.schema) (
f.is_public or
f.schema_name in self.search_path or
f.schema_name == suggestion.schema
)
) )
else: else:
@ -781,13 +806,19 @@ class SQLAutoComplete(object):
# Function overloading means we way have multiple functions of the same # Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only # name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt) all_functions = self.populate_functions(suggestion.schema, filt)
funcs = set( funcs = {self._make_cand(f, alias, suggestion, arg_mode)
self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}
for f in all_functions
)
matches = self.find_matches(word_before_cursor, funcs, meta="function") matches = self.find_matches(word_before_cursor, funcs, meta="function")
if not suggestion.schema and not suggestion.usage:
# also suggest hardcoded functions using startswith matching
predefined_funcs = self.find_matches(
word_before_cursor, self.functions, mode="strict",
meta="function"
)
matches.extend(predefined_funcs)
return matches return matches
def get_schema_matches(self, suggestion, word_before_cursor): def get_schema_matches(self, suggestion, word_before_cursor):
@ -930,6 +961,16 @@ class SQLAutoComplete(object):
types = self.populate_schema_objects(suggestion.schema, "datatypes") types = self.populate_schema_objects(suggestion.schema, "datatypes")
types = [self._make_cand(t, False, suggestion) for t in types] types = [self._make_cand(t, False, suggestion) for t in types]
matches = self.find_matches(word_before_cursor, types, meta="datatype") matches = self.find_matches(word_before_cursor, types, meta="datatype")
if not suggestion.schema:
# Also suggest hardcoded types
matches.extend(
self.find_matches(
word_before_cursor, self.datatypes, mode="strict",
meta="datatype"
)
)
return matches return matches
def get_word_before_cursor(self, word=False): def get_word_before_cursor(self, word=False):
@ -995,7 +1036,7 @@ class SQLAutoComplete(object):
:return: {TableReference:{colname:ColumnMetaData}} :return: {TableReference:{colname:ColumnMetaData}}
""" """
ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
columns = OrderedDict() columns = OrderedDict()
meta = self.dbmetadata meta = self.dbmetadata

View File

@ -1,6 +1,6 @@
""" """
Using Completion class from Using Completion class from
https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/completion/base.py https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/src/prompt_toolkit/completion/base.py
""" """
__all__ = ( __all__ = (
@ -8,7 +8,7 @@ __all__ = (
) )
class Completion(object): class Completion:
""" """
:param text: The new string that will be inserted into the document. :param text: The new string that will be inserted into the document.
:param start_position: Position relative to the cursor_position where the :param start_position: Position relative to the cursor_position where the
@ -36,31 +36,52 @@ class Completion(object):
assert self.start_position <= 0 assert self.start_position <= 0
def __repr__(self): def __repr__(self) -> str:
return "%s(text=%r, start_position=%r)" % ( if isinstance(self.display, str) and self.display == self.text:
self.__class__.__name__, self.text, self.start_position) return "{}(text={!r}, start_position={!r})".format(
self.__class__.__name__,
self.text,
self.start_position,
)
else:
return "{}(text={!r}, start_position={!r}, display={!r})".format(
self.__class__.__name__,
self.text,
self.start_position,
self.display,
)
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, Completion):
return False
return ( return (
self.text == other.text and self.text == other.text and
self.start_position == other.start_position and self.start_position == other.start_position and
self.display == other.display and self.display == other.display and
self.display_meta == other.display_meta) self._display_meta == other._display_meta
def __hash__(self):
return hash(
(self.text, self.start_position, self.display, self.display_meta)
) )
def __hash__(self) -> int:
return hash((self.text, self.start_position, self.display,
self._display_meta))
@property @property
def display_meta(self): def display_meta(self):
# Return meta-text. (This is lazy when using "get_display_meta".) # Return meta-text. (This is lazy when using "get_display_meta".)
if self._display_meta is not None: if self._display_meta is not None:
return self._display_meta return self._display_meta
elif self._get_display_meta: def new_completion_from_position(self, position: int) -> "Completion":
self._display_meta = self._get_display_meta() """
return self._display_meta (Only for internal use!)
Get a new completion by splitting this one. Used by `Application` when
it needs to have a list of new completions after inserting the common
prefix.
"""
assert position - self.start_position >= 0
else: return Completion(
return '' text=self.text[position - self.start_position:],
display=self.display,
display_meta=self._display_meta,
)

View File

@ -1,22 +1,36 @@
import sqlparse import sqlparse
def query_starts_with(query, prefixes): def query_starts_with(formatted_sql, prefixes):
"""Check if the query starts with any item from *prefixes*.""" """Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes] prefixes = [prefix.lower() for prefix in prefixes]
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
def queries_start_with(queries, prefixes): def query_is_unconditional_update(formatted_sql):
"""Check if any queries start with any item from *prefixes*.""" """Check if the query starts with UPDATE and contains no WHERE."""
for query in sqlparse.split(queries): tokens = formatted_sql.split()
if query and query_starts_with(query, prefixes) is True: return bool(tokens) and tokens[0] == "update" and "where" not in tokens
return True
return False
def is_destructive(queries): def query_is_simple_update(formatted_sql):
"""Check if the query starts with UPDATE."""
tokens = formatted_sql.split()
return bool(tokens) and tokens[0] == "update"
def is_destructive(queries, warning_level="all"):
"""Returns if any of the queries in *queries* is destructive.""" """Returns if any of the queries in *queries* is destructive."""
keywords = ("drop", "shutdown", "delete", "truncate", "alter") keywords = ("drop", "shutdown", "delete", "truncate", "alter")
return queries_start_with(queries, keywords) for query in sqlparse.split(queries):
if query:
formatted_sql = sqlparse.format(query.lower(),
strip_comments=True).strip()
if query_starts_with(formatted_sql, keywords):
return True
if query_is_unconditional_update(formatted_sql):
return True
if warning_level == "all" and \
query_is_simple_update(formatted_sql):
return True
return False

View File

@ -53,7 +53,7 @@ def parse_defaults(defaults_string):
yield current yield current
class FunctionMetadata(object): class FunctionMetadata:
def __init__( def __init__(
self, self,
schema_name, schema_name,

View File

@ -41,8 +41,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for item in parsed.tokens: for item in parsed.tokens:
if tbl_prefix_seen: if tbl_prefix_seen:
if is_subselect(item): if is_subselect(item):
for x in extract_from_part(item, stop_at_punctuation): yield from extract_from_part(item, stop_at_punctuation)
yield x
elif stop_at_punctuation and item.ttype is Punctuation: elif stop_at_punctuation and item.ttype is Punctuation:
return return
# An incomplete nested select won't be recognized correctly as a # An incomplete nested select won't be recognized correctly as a
@ -63,16 +62,13 @@ def extract_from_part(parsed, stop_at_punctuation=True):
yield item yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML: elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper() item_val = item.value.upper()
if ( if item_val in (
item_val "COPY",
in ( "FROM",
"COPY", "INTO",
"FROM", "UPDATE",
"INTO", "TABLE",
"UPDATE", ) or item_val.endswith("JOIN"):
"TABLE",
) or item_val.endswith("JOIN")
):
tbl_prefix_seen = True tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list. # 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary. # So this check here is necessary.

View File

@ -14,7 +14,7 @@ def _compile_regex(keyword):
return re.compile(pattern, re.MULTILINE | re.IGNORECASE) return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
class PrevalenceCounter(object): class PrevalenceCounter:
def __init__(self, keywords): def __init__(self, keywords):
self.keyword_counts = defaultdict(int) self.keyword_counts = defaultdict(int)
self.name_counts = defaultdict(int) self.name_counts = defaultdict(int)

View File

@ -1,4 +1,3 @@
import sys
import re import re
import sqlparse import sqlparse
from collections import namedtuple from collections import namedtuple
@ -9,7 +8,6 @@ from .parseutils.tables import extract_tables
from .parseutils.ctes import isolate_query_ctes from .parseutils.ctes import isolate_query_ctes
Special = namedtuple("Special", [])
Database = namedtuple("Database", []) Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"]) Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,) Schema.__new__.__defaults__ = (False,)
@ -48,7 +46,7 @@ Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple("Path", []) Path = namedtuple("Path", [])
class SqlStatement(object): class SqlStatement:
def __init__(self, full_text, text_before_cursor): def __init__(self, full_text, text_before_cursor):
self.identifier = None self.identifier = None
self.word_before_cursor = word_before_cursor = last_word( self.word_before_cursor = word_before_cursor = last_word(