PEP8 fixes for the utils module. Fixes #3076

Murtuza Zabuawala 2018-01-31 13:58:55 +00:00 committed by Dave Page
parent c6e405ce72
commit c3ddb7df38
23 changed files with 1075 additions and 461 deletions
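For context, changes of this kind are typically verified with a PEP8 checker. A minimal sketch of reproducing the check locally, assuming the pycodestyle package and the web/pgadmin/utils path (pgAdmin's actual tooling may differ):

    # Hedged sketch: run a PEP8 check over the utils module.
    # Assumes `pip install pycodestyle`; the path below is an assumption.
    import pycodestyle

    checker = pycodestyle.StyleGuide(max_line_length=79)
    report = checker.check_files(['web/pgadmin/utils'])
    print("PEP8 violations remaining: %d" % report.total_errors)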

View File

@ -184,6 +184,7 @@ def file_quote(_p):
return _p.encode(fs_encoding)
return _p
if IS_WIN:
import ctypes
from ctypes import wintypes
@ -198,7 +199,7 @@ if IS_WIN:
if n == 0:
return None
buf= ctypes.create_unicode_buffer(u'\0'*n)
buf = ctypes.create_unicode_buffer(u'\0' * n)
ctypes.windll.kernel32.GetEnvironmentVariableW(name, buf, n)
return buf.value
@ -279,4 +280,3 @@ def get_complete_file_path(file):
file = fs_short_path(file)
return file if os.path.isfile(file) else None

View File

@ -29,6 +29,7 @@ class DataTypeJSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, obj)
def get_no_cache_header():
"""
Prevent browser from caching data every time an
@ -36,7 +37,8 @@ def get_no_cache_header():
Returns: headers
"""
headers = {}
headers["Cache-Control"] = "no-cache, no-store, must-revalidate" # HTTP 1.1.
# HTTP 1.1.
headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
headers["Pragma"] = "no-cache" # HTTP 1.0.
headers["Expires"] = "0" # Proxies.
return headers
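A hedged usage sketch of the helper above; the /ping route is hypothetical, not part of pgAdmin:

    from flask import Flask, Response

    app = Flask(__name__)

    def get_no_cache_header():
        # Mirror of the helper above; values copied from the diff.
        return {
            "Cache-Control": "no-cache, no-store, must-revalidate",  # HTTP 1.1.
            "Pragma": "no-cache",  # HTTP 1.0.
            "Expires": "0",        # Proxies.
        }

    @app.route('/ping')
    def ping():
        # Attach the no-cache headers to a JSON response.
        return Response(response='{"ok": true}',
                        mimetype="application/json",
                        headers=get_no_cache_header())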
@ -55,7 +57,8 @@ def make_json_response(
doc['data'] = data
return Response(
response=json.dumps(doc, cls=DataTypeJSONEncoder, separators=(',',':')),
response=json.dumps(doc, cls=DataTypeJSONEncoder,
separators=(',', ':')),
status=status,
mimetype="application/json",
headers=get_no_cache_header()
@ -65,7 +68,8 @@ def make_json_response(
def make_response(response=None, status=200):
"""Create a JSON response handled by the backbone models."""
return Response(
response=json.dumps(response, cls=DataTypeJSONEncoder, separators=(',',':')),
response=json.dumps(
response, cls=DataTypeJSONEncoder, separators=(',', ':')),
status=status,
mimetype="application/json",
headers=get_no_cache_header()
@ -135,7 +139,8 @@ def gone(errormsg=''):
)
def not_implemented(errormsg=_('Not implemented.'), info='', result=None, data=None):
def not_implemented(errormsg=_('Not implemented.'), info='',
result=None, data=None):
"""Create a response with HTTP status code 501 - Not Implemented."""
return make_json_response(
status=501,
@ -147,7 +152,8 @@ def not_implemented(errormsg=_('Not implemented.'), info='', result=None, data=N
)
def service_unavailable(errormsg=_("Service Unavailable"), info='', result=None, data=None):
def service_unavailable(errormsg=_("Service Unavailable"), info='',
result=None, data=None):
"""Create a response with HTTP status code 503 - Server Unavailable."""
return make_json_response(
status=503,

View File

@ -95,7 +95,8 @@ class BaseConnection(object):
datum result.
* execute_async(query, params, formatted_exception_msg)
- Implement this method to execute the given query asynchronously and returns result.
- Implement this method to execute the given query asynchronously and
return the result.
* execute_void(query, params, formatted_exception_msg)
- Implement this method to execute the given query with no result.
@ -173,27 +174,33 @@ class BaseConnection(object):
pass
@abstractmethod
def execute_scalar(self, query, params=None, formatted_exception_msg=False):
def execute_scalar(self, query, params=None,
formatted_exception_msg=False):
pass
@abstractmethod
def execute_async(self, query, params=None, formatted_exception_msg=True):
def execute_async(self, query, params=None,
formatted_exception_msg=True):
pass
@abstractmethod
def execute_void(self, query, params=None, formatted_exception_msg=False):
def execute_void(self, query, params=None,
formatted_exception_msg=False):
pass
@abstractmethod
def execute_2darray(self, query, params=None, formatted_exception_msg=False):
def execute_2darray(self, query, params=None,
formatted_exception_msg=False):
pass
@abstractmethod
def execute_dict(self, query, params=None, formatted_exception_msg=False):
def execute_dict(self, query, params=None,
formatted_exception_msg=False):
pass
@abstractmethod
def async_fetchmany_2darray(self, records=-1, formatted_exception_msg=False):
def async_fetchmany_2darray(self, records=-1,
formatted_exception_msg=False):
pass
@abstractmethod
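The rewrapped signatures above all follow the abstract-method pattern; a self-contained sketch with illustrative names, not pgAdmin's real classes (which target Python 2 via ABCMeta):

    from abc import ABC, abstractmethod

    class ExampleConnection(ABC):
        @abstractmethod
        def execute_scalar(self, query, params=None,
                           formatted_exception_msg=False):
            """Run the query and return a single scalar result."""

    class DummyConnection(ExampleConnection):
        def execute_scalar(self, query, params=None,
                           formatted_exception_msg=False):
            return 1  # stand-in for a real driver round-trip

    print(DummyConnection().execute_scalar("SELECT 1"))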

View File

@ -34,9 +34,9 @@ from pgadmin.utils import get_complete_file_path
from .keywords import ScanKeyword
from ..abstract import BaseDriver, BaseConnection
from .cursor import DictCursor
from .typecast import register_global_typecasters, register_string_typecasters,\
register_binary_typecasters, register_array_to_string_typecasters,\
ALL_JSON_TYPES
from .typecast import register_global_typecasters, \
register_string_typecasters, register_binary_typecasters, \
register_array_to_string_typecasters, ALL_JSON_TYPES
if sys.version_info < (3,):
@ -67,7 +67,8 @@ class Connection(BaseConnection):
Methods:
-------
* connect(**kwargs)
- Connect the PostgreSQL/EDB Postgres Advanced Server using the psycopg2 driver
- Connect the PostgreSQL/EDB Postgres Advanced Server using the psycopg2
driver
* execute_scalar(query, params, formatted_exception_msg)
- Execute the given query and returns single datum result
@ -118,10 +119,12 @@ class Connection(BaseConnection):
connection.
* status_message()
- Returns the status message returned by the last command executed on the server.
- Returns the status message returned by the last command executed on
the server.
* rows_affected()
- Returns the no of rows affected by the last command executed on the server.
- Returns the number of rows affected by the last command executed on
the server.
* cancel_transaction(conn_id, did=None)
- This method is used to cancel the transaction for the
@ -241,9 +244,9 @@ class Connection(BaseConnection):
except Exception as e:
current_app.logger.exception(e)
return False, \
_("Failed to decrypt the saved password.\nError: {0}").format(
str(e)
)
_(
"Failed to decrypt the saved password.\nError: {0}"
).format(str(e))
# If no password credential is found then connect request might
# come from Query tool, ViewData grid, debugger etc tools.
@ -263,7 +266,8 @@ class Connection(BaseConnection):
conn_id = self.conn_id
import os
os.environ['PGAPPNAME'] = '{0} - {1}'.format(config.APP_NAME, conn_id)
os.environ['PGAPPNAME'] = '{0} - {1}'.format(
config.APP_NAME, conn_id)
pg_conn = psycopg2.connect(
host=mgr.host,
@ -294,15 +298,15 @@ class Connection(BaseConnection):
msg = e.diag.message_detail
else:
msg = str(e)
current_app.logger.info(u"""
Failed to connect to the database server(#{server_id}) for connection ({conn_id}) with error message as below:
{msg}""".format(
server_id=self.manager.sid,
conn_id=conn_id,
msg=msg.decode('utf-8') if hasattr(str, 'decode') else msg
current_app.logger.info(
u"Failed to connect to the database server(#{server_id}) for "
u"connection ({conn_id}) with error message as below"
u":{msg}".format(
server_id=self.manager.sid,
conn_id=conn_id,
msg=msg.decode('utf-8') if hasattr(str, 'decode') else msg
)
)
)
return False, msg
self.conn = pg_conn
@ -348,7 +352,7 @@ Failed to connect to the database server(#{server_id}) for connection ({conn_id}
# autocommit flag does not work with asynchronous connections.
# By default asynchronous connection runs in autocommit mode.
if self.async == 0:
if 'autocommit' in kwargs and kwargs['autocommit'] == False:
if 'autocommit' in kwargs and kwargs['autocommit'] is False:
self.conn.autocommit = False
else:
self.conn.autocommit = True
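The `== False` to `is False` change above is pycodestyle's E712. The two are not always interchangeable; a quick illustration:

    value = 0
    print(value == False)  # True  -- bool is a subclass of int
    print(value is False)  # False -- identity against the False singleton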
@ -363,11 +367,10 @@ Failed to connect to the database server(#{server_id}) for connection ({conn_id}
if self.use_binary_placeholder:
register_binary_typecasters(self.conn)
status = _execute(cur, """
SET DateStyle=ISO;
SET client_min_messages=notice;
SET bytea_output=escape;
SET client_encoding='UNICODE';""")
status = _execute(cur, "SET DateStyle=ISO;"
"SET client_min_messages=notice;"
"SET bytea_output=escape;"
"SET client_encoding='UNICODE';")
if status is not None:
self.conn.close()
@ -381,19 +384,19 @@ SET client_encoding='UNICODE';""")
if status is not None:
self.conn.close()
self.conn = None
current_app.logger.error("""
Connect to the database server (#{server_id}) for connection ({conn_id}), but - failed to setup the role with error message as below:
{msg}
""".format(
server_id=self.manager.sid,
conn_id=conn_id,
msg=status
)
current_app.logger.error(
"Connect to the database server (#{server_id}) for "
"connection ({conn_id}), but - failed to setup the role "
"with error message as below:{msg}".format(
server_id=self.manager.sid,
conn_id=conn_id,
msg=status
)
)
return False, \
_("Failed to setup the role with error message:\n{0}").format(
status
)
_(
"Failed to setup the role with error message:\n{0}"
).format(status)
if mgr.ver is None:
status = _execute(cur, "SELECT version()")
@ -402,14 +405,14 @@ Connect to the database server (#{server_id}) for connection ({conn_id}), but -
self.conn.close()
self.conn = None
self.wasConnected = False
current_app.logger.error("""
Failed to fetch the version information on the established connection to the database server (#{server_id}) for '{conn_id}' with below error message:
{msg}
""".format(
server_id=self.manager.sid,
conn_id=conn_id,
msg=status
)
current_app.logger.error(
"Failed to fetch the version information on the "
"established connection to the database server "
"(#{server_id}) for '{conn_id}' with below error "
"message:{msg}".format(
server_id=self.manager.sid,
conn_id=conn_id,
msg=status)
)
return False, status
@ -420,7 +423,8 @@ Failed to fetch the version information on the established connection to the dat
status = _execute(cur, """
SELECT
db.oid as did, db.datname, db.datallowconn, pg_encoding_to_char(db.encoding) AS serverencoding,
db.oid as did, db.datname, db.datallowconn,
pg_encoding_to_char(db.encoding) AS serverencoding,
has_database_privilege(db.oid, 'CREATE') as cancreate, datlastsysoid
FROM
pg_database db
@ -454,7 +458,8 @@ WHERE
mgr.password = kwargs['password']
server_types = None
if 'server_types' in kwargs and isinstance(kwargs['server_types'], list):
if 'server_types' in kwargs and isinstance(
kwargs['server_types'], list):
server_types = mgr.server_types = kwargs['server_types']
if server_types is None:
@ -491,7 +496,8 @@ WHERE
errmsg = ""
current_app.logger.warning(
"Connection to database server (#{server_id}) for the connection - '{conn_id}' has been lost.".format(
"Connection to database server (#{server_id}) for the "
"connection - '{conn_id}' has been lost.".format(
server_id=self.manager.sid,
conn_id=self.conn_id
)
@ -518,7 +524,8 @@ WHERE
except psycopg2.Error as pe:
current_app.logger.exception(pe)
errmsg = gettext(
"Failed to create cursor for psycopg2 connection with error message for the server#{1}:{2}:\n{0}"
"Failed to create cursor for psycopg2 connection with error "
"message for the server#{1}:{2}:\n{0}"
).format(
str(pe), self.manager.sid, self.db
)
@ -529,7 +536,8 @@ WHERE
if self.auto_reconnect and not self.reconnecting:
current_app.logger.info(
gettext(
"Attempting to reconnect to the database server (#{server_id}) for the connection - '{conn_id}'."
"Attempting to reconnect to the database server "
"(#{server_id}) for the connection - '{conn_id}'."
).format(
server_id=self.manager.sid,
conn_id=self.conn_id
@ -542,7 +550,8 @@ WHERE
raise ConnectionLost(
self.manager.sid,
self.db,
None if self.conn_id[0:3] == u'DB:' else self.conn_id[5:]
None if self.conn_id[0:3] == u'DB:'
else self.conn_id[5:]
)
setattr(
@ -602,15 +611,15 @@ WHERE
query = query.encode('utf-8')
current_app.logger.log(
25,
u"Execute (with server cursor) for server #{server_id} - {conn_id} "
u"(Query-id: {query_id}):\n{query}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query.decode('utf-8') if
sys.version_info < (3,) else query,
query_id=query_id
)
25,
u"Execute (with server cursor) for server #{server_id} - "
u"{conn_id} (Query-id: {query_id}):\n{query}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query.decode('utf-8') if
sys.version_info < (3,) else query,
query_id=query_id
)
)
try:
self.__internal_blocking_execute(cur, query, params)
@ -673,7 +682,8 @@ WHERE
new_results = []
for row in results:
new_results.append(
dict([(k.decode(conn_encoding), v) for k, v in row.items()])
dict([(k.decode(conn_encoding), v)
for k, v in row.items()])
)
return new_results
@ -715,13 +725,14 @@ WHERE
# Decode the field_separator
try:
field_separator = field_separator.decode('utf-8')
except:
pass
except Exception as e:
current_app.logger.error(e)
# Decode the quote_char
try:
quote_char = quote_char.decode('utf-8')
except:
pass
except Exception as e:
current_app.logger.error(e)
csv_writer = csv.DictWriter(
res_io, fieldnames=header, delimiter=field_separator,
@ -759,7 +770,8 @@ WHERE
return True, gen
def execute_scalar(self, query, params=None, formatted_exception_msg=False):
def execute_scalar(self, query, params=None,
formatted_exception_msg=False):
status, cur = self.__cursor()
self.row_count = 0
@ -769,7 +781,8 @@ WHERE
current_app.logger.log(
25,
u"Execute (scalar) for server #{server_id} - {conn_id} (Query-id: {query_id}):\n{query}".format(
u"Execute (scalar) for server #{server_id} - {conn_id} (Query-id: "
u"{query_id}):\n{query}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
@ -783,7 +796,8 @@ WHERE
if not self.connected():
if self.auto_reconnect and not self.reconnecting:
return self.__attempt_execution_reconnect(
self.execute_dict, query, params, formatted_exception_msg
self.execute_dict, query, params,
formatted_exception_msg
)
raise ConnectionLost(
self.manager.sid,
@ -792,7 +806,9 @@ WHERE
)
errmsg = self._formatted_exception_msg(pe, formatted_exception_msg)
current_app.logger.error(
u"Failed to execute query (execute_scalar) for the server #{server_id} - {conn_id} (Query-id: {query_id}):\nError Message:{errmsg}".format(
u"Failed to execute query (execute_scalar) for the server "
u"#{server_id} - {conn_id} (Query-id: {query_id}):\n"
u"Error Message:{errmsg}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
@ -812,12 +828,14 @@ WHERE
def execute_async(self, query, params=None, formatted_exception_msg=True):
"""
This function executes the given query asynchronously and returns result.
This function executes the given query asynchronously and returns the
result.
Args:
query: SQL query to run.
params: extra parameters to the function
formatted_exception_msg: if True then function return the formatted exception message
formatted_exception_msg: if True then the function returns the
formatted exception message
"""
if sys.version_info < (3,):
@ -835,7 +853,8 @@ WHERE
current_app.logger.log(
25,
u"Execute (async) for server #{server_id} - {conn_id} (Query-id: {query_id}):\n{query}".format(
u"Execute (async) for server #{server_id} - {conn_id} (Query-id: "
u"{query_id}):\n{query}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query.decode('utf-8'),
@ -850,16 +869,16 @@ WHERE
res = self._wait_timeout(cur.connection)
except psycopg2.Error as pe:
errmsg = self._formatted_exception_msg(pe, formatted_exception_msg)
current_app.logger.error(u"""
Failed to execute query (execute_async) for the server #{server_id} - {conn_id}
(Query-id: {query_id}):\nError Message:{errmsg}
""".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query.decode('utf-8'),
errmsg=errmsg,
query_id=query_id
)
current_app.logger.error(
u"Failed to execute query (execute_async) for the server "
u"#{server_id} - {conn_id}(Query-id: {query_id}):\n"
u"Error Message:{errmsg}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query.decode('utf-8'),
errmsg=errmsg,
query_id=query_id
)
)
return False, errmsg
@ -875,7 +894,8 @@ Failed to execute query (execute_async) for the server #{server_id} - {conn_id}
Args:
query: SQL query to run.
params: extra parameters to the function
formatted_exception_msg: if True then function return the formatted exception message
formatted_exception_msg: if True then the function returns the
formatted exception message
"""
status, cur = self.__cursor()
self.row_count = 0
@ -886,7 +906,8 @@ Failed to execute query (execute_async) for the server #{server_id} - {conn_id}
current_app.logger.log(
25,
u"Execute (void) for server #{server_id} - {conn_id} (Query-id: {query_id}):\n{query}".format(
u"Execute (void) for server #{server_id} - {conn_id} (Query-id: "
u"{query_id}):\n{query}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
@ -901,7 +922,8 @@ Failed to execute query (execute_async) for the server #{server_id} - {conn_id}
if not self.connected():
if self.auto_reconnect and not self.reconnecting:
return self.__attempt_execution_reconnect(
self.execute_void, query, params, formatted_exception_msg
self.execute_void, query, params,
formatted_exception_msg
)
raise ConnectionLost(
self.manager.sid,
@ -909,16 +931,16 @@ Failed to execute query (execute_async) for the server #{server_id} - {conn_id}
None if self.conn_id[0:3] == u'DB:' else self.conn_id[5:]
)
errmsg = self._formatted_exception_msg(pe, formatted_exception_msg)
current_app.logger.error(u"""
Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
(Query-id: {query_id}):\nError Message:{errmsg}
""".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
errmsg=errmsg,
query_id=query_id
)
current_app.logger.error(
u"Failed to execute query (execute_void) for the server "
u"#{server_id} - {conn_id}(Query-id: {query_id}):\n"
u"Error Message:{errmsg}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
errmsg=errmsg,
query_id=query_id
)
)
return False, errmsg
@ -944,7 +966,8 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
self.reconnecting = False
current_app.warning(
"Failed to reconnect the database server (#{server_id})".format(
"Failed to reconnect the database server "
"(#{server_id})".format(
server_id=self.manager.sid,
conn_id=self.conn_id
)
@ -956,7 +979,8 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
None if self.conn_id[0:3] == u'DB:' else self.conn_id[5:]
)
def execute_2darray(self, query, params=None, formatted_exception_msg=False):
def execute_2darray(self, query, params=None,
formatted_exception_msg=False):
status, cur = self.__cursor()
self.row_count = 0
@ -966,7 +990,8 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
query_id = random.randint(1, 9999999)
current_app.logger.log(
25,
u"Execute (2darray) for server #{server_id} - {conn_id} (Query-id: {query_id}):\n{query}".format(
u"Execute (2darray) for server #{server_id} - {conn_id} "
u"(Query-id: {query_id}):\n{query}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
@ -981,11 +1006,14 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
if self.auto_reconnect and \
not self.reconnecting:
return self.__attempt_execution_reconnect(
self.execute_2darray, query, params, formatted_exception_msg
self.execute_2darray, query, params,
formatted_exception_msg
)
errmsg = self._formatted_exception_msg(pe, formatted_exception_msg)
current_app.logger.error(
u"Failed to execute query (execute_2darray) for the server #{server_id} - {conn_id} (Query-id: {query_id}):\nError Message:{errmsg}".format(
u"Failed to execute query (execute_2darray) for the server "
u"#{server_id} - {conn_id} (Query-id: {query_id}):\n"
u"Error Message:{errmsg}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
@ -998,7 +1026,7 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
# Get Resultset Column Name, Type and size
columns = cur.description and [
desc.to_dict() for desc in cur.ordered_description()
] or []
] or []
rows = []
self.row_count = cur.rowcount
@ -1017,7 +1045,8 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
query_id = random.randint(1, 9999999)
current_app.logger.log(
25,
u"Execute (dict) for server #{server_id} - {conn_id} (Query-id: {query_id}):\n{query}".format(
u"Execute (dict) for server #{server_id} - {conn_id} (Query-id: "
u"{query_id}):\n{query}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query=query,
@ -1041,7 +1070,9 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
)
errmsg = self._formatted_exception_msg(pe, formatted_exception_msg)
current_app.logger.error(
u"Failed to execute query (execute_dict) for the server #{server_id}- {conn_id} (Query-id: {query_id}):\nError Message:{errmsg}".format(
u"Failed to execute query (execute_dict) for the server "
u"#{server_id}- {conn_id} (Query-id: {query_id}):\n"
u"Error Message:{errmsg}".format(
server_id=self.manager.sid,
conn_id=self.conn_id,
query_id=query_id,
@ -1053,7 +1084,7 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
# Get Resultset Column Name, Type and size
columns = cur.description and [
desc.to_dict() for desc in cur.ordered_description()
] or []
] or []
rows = []
self.row_count = cur.rowcount
@ -1063,7 +1094,8 @@ Failed to execute query (execute_void) for the server #{server_id} - {conn_id}
return True, {'columns': columns, 'rows': rows}
def async_fetchmany_2darray(self, records=2000, formatted_exception_msg=False):
def async_fetchmany_2darray(self, records=2000,
formatted_exception_msg=False):
"""
User should poll and check if status is ASYNC_OK before calling this
function
@ -1209,7 +1241,8 @@ Failed to reset the connection to the server due to following error:
elif state == psycopg2.extensions.POLL_READ:
select.select([conn.fileno()], [], [])
else:
raise psycopg2.OperationalError("poll() returned %s from _wait function" % state)
raise psycopg2.OperationalError(
"poll() returned %s from _wait function" % state)
def _wait_timeout(self, conn):
"""
@ -1248,10 +1281,10 @@ Failed to reset the connection to the server due to following error:
# select.select timeout option works only if we provide
# empty [] [] [] file descriptor in select.select() function
# and that also works only on UNIX based system, it do not support Windows
# Hence we have wrote our own pooling mechanism to read data fast
# each call conn.poll() reads chunks of data from connection object
# more we poll more we read data from connection
# and that also works only on UNIX-based systems; it does not support
# Windows. Hence we have written our own polling mechanism to read
# data fast: each call to conn.poll() reads a chunk of data from the
# connection object, so the more we poll, the more data we read.
cnt = 0
while cnt < 1000:
# poll again to check the state if it is still POLL_READ
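The polling described above follows psycopg2's documented asynchronous wait loop; a generic version for reference:

    import select
    import psycopg2
    import psycopg2.extensions

    def wait_select(conn):
        # Block until the async operation on `conn` completes.
        while True:
            state = conn.poll()
            if state == psycopg2.extensions.POLL_OK:
                break
            elif state == psycopg2.extensions.POLL_WRITE:
                select.select([], [conn.fileno()], [])
            elif state == psycopg2.extensions.POLL_READ:
                select.select([conn.fileno()], [], [])
            else:
                raise psycopg2.OperationalError(
                    "poll() returned %s" % state)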
@ -1337,9 +1370,9 @@ Failed to reset the connection to the server due to following error:
result = []
# For DDL operation, we may not have result.
#
# Because - there is not direct way to differentiate DML and
# DDL operations, we need to rely on exception to figure
# that out at the moment.
# Because - there is not direct way to differentiate DML
# and DDL operations, we need to rely on exception to
# figure that out at the moment.
try:
for row in cur:
new_row = []
@ -1354,11 +1387,14 @@ Failed to reset the connection to the server due to following error:
def status_message(self):
"""
This function will return the status message returned by the last command executed on the server.
This function will return the status message returned by the last
command executed on the server.
"""
cur = self.__async_cursor
if not cur:
return gettext("Cursor could not be found for the async connection.")
return gettext(
"Cursor could not be found for the async connection."
)
current_app.logger.log(
25,
@ -1396,7 +1432,8 @@ Failed to reset the connection to the server due to following error:
did: Database id (optional)
"""
cancel_conn = self.manager.connection(did=did, conn_id=conn_id)
query = """SELECT pg_cancel_backend({0});""".format(cancel_conn.__backend_pid)
query = """SELECT pg_cancel_backend({0});""".format(
cancel_conn.__backend_pid)
status = True
msg = ''
@ -1494,9 +1531,9 @@ Failed to reset the connection to the server due to following error:
value = value.decode(pref_encoding)\
.encode('utf-8')\
.decode('utf-8')
except:
except Exception:
is_error = True
except:
except Exception:
is_error = True
# If still not able to decode then
@ -1513,7 +1550,8 @@ Failed to reset the connection to the server due to following error:
Args:
exception_obj: exception object
formatted_msg: if True then function return the formatted exception message
formatted_msg: if True then function return the formatted exception
message
"""
if exception_obj.pgerror:
@ -1704,7 +1742,7 @@ class ServerManager(object):
database = database.decode('utf-8')
if did is not None:
if did in self.db_info:
self.db_info[did]['datname']=database
self.db_info[did]['datname'] = database
else:
if did is None:
database = self.db
@ -1774,8 +1812,8 @@ WHERE db.oid = {0}""".format(did))
try:
if 'password' in data and data['password']:
data['password'] = data['password'].encode('utf-8')
except:
pass
except Exception as e:
current_app.logger.exception(e)
connections = data['connections']
for conn_id in connections:
@ -2111,9 +2149,9 @@ class Driver(BaseDriver):
'tinyint': 3,
'tinytext': 3,
'varchar2': 3
};
}
return (key in extraKeywords and extraKeywords[key]) or ScanKeyword(key)
return extraKeywords.get(key, None) or ScanKeyword(key)
@staticmethod
def needsQuoting(key, forTypes):
@ -2123,7 +2161,8 @@ class Driver(BaseDriver):
# check if the string is number or not
if isinstance(value, int):
return True
# certain types should not be quoted even though it contains a space. Evilness.
# certain types should not be quoted even though they contain a space.
# Evilness.
elif forTypes and value[-2:] == u"[]":
valNoArray = value[:-2]
@ -2173,7 +2212,8 @@ class Driver(BaseDriver):
@staticmethod
def qtTypeIdent(conn, *args):
# We're not using the conn object at the moment, but - we will modify the
# We're not using the conn object at the moment, but - we will
# modify the
# logic to use the server version specific keywords later.
res = None
value = None
@ -2200,8 +2240,8 @@ class Driver(BaseDriver):
@staticmethod
def qtIdent(conn, *args):
# We're not using the conn object at the moment, but - we will modify the
# logic to use the server version specific keywords later.
# We're not using the conn object at the moment, but - we will
# modify the logic to use the server version specific keywords later.
res = None
value = None

View File

@ -59,4 +59,5 @@ if __name__ == '__main__':
)
idx += 1
keywords_file.write('\n }\n')
keywords_file.write(' return (key in keywordDict and keywordDict[key]) or None')
keywords_file.write(
' return (key in keywordDict and keywordDict[key]) or None')

View File

@ -9,60 +9,424 @@
# ScanKeyword function for PostgreSQL 9.5rc1
def ScanKeyword(key):
keywordDict = {
"abort": 0, "absolute": 0, "access": 0, "action": 0, "add": 0, "admin": 0, "after": 0, "aggregate": 0, "all": 3,
"also": 0, "alter": 0, "always": 0, "analyze": 3, "and": 3, "any": 3, "array": 3, "as": 3, "asc": 3,
"assertion": 0, "assignment": 0, "asymmetric": 3, "at": 0, "attribute": 0, "authorization": 2, "backward": 0,
"before": 0, "begin": 0, "between": 1, "bigint": 1, "binary": 2, "bit": 1, "boolean": 1, "both": 3, "by": 0,
"cache": 0, "called": 0, "cascade": 0, "cascaded": 0, "case": 3, "cast": 3, "catalog": 0, "chain": 0, "char": 1,
"character": 1, "characteristics": 0, "check": 3, "checkpoint": 0, "class": 0, "close": 0, "cluster": 0,
"coalesce": 1, "collate": 3, "collation": 2, "column": 3, "comment": 0, "comments": 0, "commit": 0,
"committed": 0, "concurrently": 2, "configuration": 0, "conflict": 0, "connection": 0, "constraint": 3,
"constraints": 0, "content": 0, "continue": 0, "conversion": 0, "copy": 0, "cost": 0, "create": 3, "cross": 2,
"csv": 0, "cube": 0, "current": 0, "current_catalog": 3, "current_date": 3, "current_role": 3,
"current_schema": 2, "current_time": 3, "current_timestamp": 3, "current_user": 3, "cursor": 0, "cycle": 0,
"data": 0, "database": 0, "day": 0, "deallocate": 0, "dec": 1, "decimal": 1, "declare": 0, "default": 3,
"defaults": 0, "deferrable": 3, "deferred": 0, "definer": 0, "delete": 0, "delimiter": 0, "delimiters": 0,
"desc": 3, "dictionary": 0, "disable": 0, "discard": 0, "distinct": 3, "do": 3, "document": 0, "domain": 0,
"double": 0, "drop": 0, "each": 0, "else": 3, "enable": 0, "encoding": 0, "encrypted": 0, "end": 3, "enum": 0,
"escape": 0, "event": 0, "except": 3, "exclude": 0, "excluding": 0, "exclusive": 0, "execute": 0, "exists": 1,
"explain": 0, "extension": 0, "external": 0, "extract": 1, "false": 3, "family": 0, "fetch": 3, "filter": 0,
"first": 0, "float": 1, "following": 0, "for": 3, "force": 0, "foreign": 3, "forward": 0, "freeze": 2,
"from": 3, "full": 2, "function": 0, "functions": 0, "global": 0, "grant": 3, "granted": 0, "greatest": 1,
"group": 3, "grouping": 1, "handler": 0, "having": 3, "header": 0, "hold": 0, "hour": 0, "identity": 0, "if": 0,
"ilike": 2, "immediate": 0, "immutable": 0, "implicit": 0, "import": 0, "in": 3, "including": 0, "increment": 0,
"index": 0, "indexes": 0, "inherit": 0, "inherits": 0, "initially": 3, "inline": 0, "inner": 2, "inout": 1,
"input": 0, "insensitive": 0, "insert": 0, "instead": 0, "int": 1, "integer": 1, "intersect": 3, "interval": 1,
"into": 3, "invoker": 0, "is": 2, "isnull": 2, "isolation": 0, "join": 2, "key": 0, "label": 0, "language": 0,
"large": 0, "last": 0, "lateral": 3, "leading": 3, "leakproof": 0, "least": 1, "left": 2, "level": 0, "like": 2,
"limit": 3, "listen": 0, "load": 0, "local": 0, "localtime": 3, "localtimestamp": 3, "location": 0, "lock": 0,
"locked": 0, "logged": 0, "mapping": 0, "match": 0, "materialized": 0, "maxvalue": 0, "minute": 0,
"minvalue": 0, "mode": 0, "month": 0, "move": 0, "name": 0, "names": 0, "national": 1, "natural": 2, "nchar": 1,
"next": 0, "no": 0, "none": 1, "not": 3, "nothing": 0, "notify": 0, "notnull": 2, "nowait": 0, "null": 3,
"nullif": 1, "nulls": 0, "numeric": 1, "object": 0, "of": 0, "off": 0, "offset": 3, "oids": 0, "on": 3,
"only": 3, "operator": 0, "option": 0, "options": 0, "or": 3, "order": 3, "ordinality": 0, "out": 1, "outer": 2,
"over": 0, "overlaps": 2, "overlay": 1, "owned": 0, "owner": 0, "parser": 0, "partial": 0, "partition": 0,
"passing": 0, "password": 0, "placing": 3, "plans": 0, "policy": 0, "position": 1, "preceding": 0,
"precision": 1, "prepare": 0, "prepared": 0, "preserve": 0, "primary": 3, "prior": 0, "privileges": 0,
"procedural": 0, "procedure": 0, "program": 0, "quote": 0, "range": 0, "read": 0, "real": 1, "reassign": 0,
"recheck": 0, "recursive": 0, "ref": 0, "references": 3, "refresh": 0, "reindex": 0, "relative": 0,
"release": 0, "rename": 0, "repeatable": 0, "replace": 0, "replica": 0, "reset": 0, "restart": 0, "restrict": 0,
"returning": 3, "returns": 0, "revoke": 0, "right": 2, "role": 0, "rollback": 0, "rollup": 0, "row": 1,
"rows": 0, "rule": 0, "savepoint": 0, "schema": 0, "scroll": 0, "search": 0, "second": 0, "security": 0,
"select": 3, "sequence": 0, "sequences": 0, "serializable": 0, "server": 0, "session": 0, "session_user": 3,
"set": 0, "setof": 1, "sets": 0, "share": 0, "show": 0, "similar": 2, "simple": 0, "skip": 0, "smallint": 1,
"snapshot": 0, "some": 3, "sql": 0, "stable": 0, "standalone": 0, "start": 0, "statement": 0, "statistics": 0,
"stdin": 0, "stdout": 0, "storage": 0, "strict": 0, "strip": 0, "substring": 1, "symmetric": 3, "sysid": 0,
"system": 0, "table": 3, "tables": 0, "tablesample": 2, "tablespace": 0, "temp": 0, "template": 0,
"temporary": 0, "text": 0, "then": 3, "time": 1, "timestamp": 1, "to": 3, "trailing": 3, "transaction": 0,
"transform": 0, "treat": 1, "trigger": 0, "trim": 1, "true": 3, "truncate": 0, "trusted": 0, "type": 0,
"types": 0, "unbounded": 0, "uncommitted": 0, "unencrypted": 0, "union": 3, "unique": 3, "unknown": 0,
"unlisten": 0, "unlogged": 0, "until": 0, "update": 0, "user": 3, "using": 3, "vacuum": 0, "valid": 0,
"validate": 0, "validator": 0, "value": 0, "values": 1, "varchar": 1, "variadic": 3, "varying": 0, "verbose": 2,
"version": 0, "view": 0, "views": 0, "volatile": 0, "when": 3, "where": 3, "whitespace": 0, "window": 3,
"with": 3, "within": 0, "without": 0, "work": 0, "wrapper": 0, "write": 0, "xml": 0, "xmlattributes": 1,
"xmlconcat": 1, "xmlelement": 1, "xmlexists": 1, "xmlforest": 1, "xmlparse": 1, "xmlpi": 1, "xmlroot": 1,
"xmlserialize": 1, "year": 0, "yes": 0, "zone": 0
'abort': 0,
'absolute': 0,
'access': 0,
'action': 0,
'add': 0,
'admin': 0,
'after': 0,
'aggregate': 0,
'all': 3,
'also': 0,
'alter': 0,
'always': 0,
'analyze': 3,
'and': 3,
'any': 3,
'array': 3,
'as': 3,
'asc': 3,
'assertion': 0,
'assignment': 0,
'asymmetric': 3,
'at': 0,
'attribute': 0,
'authorization': 2,
'backward': 0,
'before': 0,
'begin': 0,
'between': 1,
'bigint': 1,
'binary': 2,
'bit': 1,
'boolean': 1,
'both': 3,
'by': 0,
'cache': 0,
'called': 0,
'cascade': 0,
'cascaded': 0,
'case': 3,
'cast': 3,
'catalog': 0,
'chain': 0,
'char': 1,
'character': 1,
'characteristics': 0,
'check': 3,
'checkpoint': 0,
'class': 0,
'close': 0,
'cluster': 0,
'coalesce': 1,
'collate': 3,
'collation': 2,
'column': 3,
'comment': 0,
'comments': 0,
'commit': 0,
'committed': 0,
'concurrently': 2,
'configuration': 0,
'conflict': 0,
'connection': 0,
'constraint': 3,
'constraints': 0,
'content': 0,
'continue': 0,
'conversion': 0,
'copy': 0,
'cost': 0,
'create': 3,
'cross': 2,
'csv': 0,
'cube': 0,
'current': 0,
'current_catalog': 3,
'current_date': 3,
'current_role': 3,
'current_schema': 2,
'current_time': 3,
'current_timestamp': 3,
'current_user': 3,
'cursor': 0,
'cycle': 0,
'data': 0,
'database': 0,
'day': 0,
'deallocate': 0,
'dec': 1,
'decimal': 1,
'declare': 0,
'default': 3,
'defaults': 0,
'deferrable': 3,
'deferred': 0,
'definer': 0,
'delete': 0,
'delimiter': 0,
'delimiters': 0,
'desc': 3,
'dictionary': 0,
'disable': 0,
'discard': 0,
'distinct': 3,
'do': 3,
'document': 0,
'domain': 0,
'double': 0,
'drop': 0,
'each': 0,
'else': 3,
'enable': 0,
'encoding': 0,
'encrypted': 0,
'end': 3,
'enum': 0,
'escape': 0,
'event': 0,
'except': 3,
'exclude': 0,
'excluding': 0,
'exclusive': 0,
'execute': 0,
'exists': 1,
'explain': 0,
'extension': 0,
'external': 0,
'extract': 1,
'false': 3,
'family': 0,
'fetch': 3,
'filter': 0,
'first': 0,
'float': 1,
'following': 0,
'for': 3,
'force': 0,
'foreign': 3,
'forward': 0,
'freeze': 2,
'from': 3,
'full': 2,
'function': 0,
'functions': 0,
'global': 0,
'grant': 3,
'granted': 0,
'greatest': 1,
'group': 3,
'grouping': 1,
'handler': 0,
'having': 3,
'header': 0,
'hold': 0,
'hour': 0,
'identity': 0,
'if': 0,
'ilike': 2,
'immediate': 0,
'immutable': 0,
'implicit': 0,
'import': 0,
'in': 3,
'including': 0,
'increment': 0,
'index': 0,
'indexes': 0,
'inherit': 0,
'inherits': 0,
'initially': 3,
'inline': 0,
'inner': 2,
'inout': 1,
'input': 0,
'insensitive': 0,
'insert': 0,
'instead': 0,
'int': 1,
'integer': 1,
'intersect': 3,
'interval': 1,
'into': 3,
'invoker': 0,
'is': 2,
'isnull': 2,
'isolation': 0,
'join': 2,
'key': 0,
'label': 0,
'language': 0,
'large': 0,
'last': 0,
'lateral': 3,
'leading': 3,
'leakproof': 0,
'least': 1,
'left': 2,
'level': 0,
'like': 2,
'limit': 3,
'listen': 0,
'load': 0,
'local': 0,
'localtime': 3,
'localtimestamp': 3,
'location': 0,
'lock': 0,
'locked': 0,
'logged': 0,
'mapping': 0,
'match': 0,
'materialized': 0,
'maxvalue': 0,
'minute': 0,
'minvalue': 0,
'mode': 0,
'month': 0,
'move': 0,
'name': 0,
'names': 0,
'national': 1,
'natural': 2,
'nchar': 1,
'next': 0,
'no': 0,
'none': 1,
'not': 3,
'nothing': 0,
'notify': 0,
'notnull': 2,
'nowait': 0,
'null': 3,
'nullif': 1,
'nulls': 0,
'numeric': 1,
'object': 0,
'of': 0,
'off': 0,
'offset': 3,
'oids': 0,
'on': 3,
'only': 3,
'operator': 0,
'option': 0,
'options': 0,
'or': 3,
'order': 3,
'ordinality': 0,
'out': 1,
'outer': 2,
'over': 0,
'overlaps': 2,
'overlay': 1,
'owned': 0,
'owner': 0,
'parser': 0,
'partial': 0,
'partition': 0,
'passing': 0,
'password': 0,
'placing': 3,
'plans': 0,
'policy': 0,
'position': 1,
'preceding': 0,
'precision': 1,
'prepare': 0,
'prepared': 0,
'preserve': 0,
'primary': 3,
'prior': 0,
'privileges': 0,
'procedural': 0,
'procedure': 0,
'program': 0,
'quote': 0,
'range': 0,
'read': 0,
'real': 1,
'reassign': 0,
'recheck': 0,
'recursive': 0,
'ref': 0,
'references': 3,
'refresh': 0,
'reindex': 0,
'relative': 0,
'release': 0,
'rename': 0,
'repeatable': 0,
'replace': 0,
'replica': 0,
'reset': 0,
'restart': 0,
'restrict': 0,
'returning': 3,
'returns': 0,
'revoke': 0,
'right': 2,
'role': 0,
'rollback': 0,
'rollup': 0,
'row': 1,
'rows': 0,
'rule': 0,
'savepoint': 0,
'schema': 0,
'scroll': 0,
'search': 0,
'second': 0,
'security': 0,
'select': 3,
'sequence': 0,
'sequences': 0,
'serializable': 0,
'server': 0,
'session': 0,
'session_user': 3,
'set': 0,
'setof': 1,
'sets': 0,
'share': 0,
'show': 0,
'similar': 2,
'simple': 0,
'skip': 0,
'smallint': 1,
'snapshot': 0,
'some': 3,
'sql': 0,
'stable': 0,
'standalone': 0,
'start': 0,
'statement': 0,
'statistics': 0,
'stdin': 0,
'stdout': 0,
'storage': 0,
'strict': 0,
'strip': 0,
'substring': 1,
'symmetric': 3,
'sysid': 0,
'system': 0,
'table': 3,
'tables': 0,
'tablesample': 2,
'tablespace': 0,
'temp': 0,
'template': 0,
'temporary': 0,
'text': 0,
'then': 3,
'time': 1,
'timestamp': 1,
'to': 3,
'trailing': 3,
'transaction': 0,
'transform': 0,
'treat': 1,
'trigger': 0,
'trim': 1,
'true': 3,
'truncate': 0,
'trusted': 0,
'type': 0,
'types': 0,
'unbounded': 0,
'uncommitted': 0,
'unencrypted': 0,
'union': 3,
'unique': 3,
'unknown': 0,
'unlisten': 0,
'unlogged': 0,
'until': 0,
'update': 0,
'user': 3,
'using': 3,
'vacuum': 0,
'valid': 0,
'validate': 0,
'validator': 0,
'value': 0,
'values': 1,
'varchar': 1,
'variadic': 3,
'varying': 0,
'verbose': 2,
'version': 0,
'view': 0,
'views': 0,
'volatile': 0,
'when': 3,
'where': 3,
'whitespace': 0,
'window': 3,
'with': 3,
'within': 0,
'without': 0,
'work': 0,
'wrapper': 0,
'write': 0,
'xml': 0,
'xmlattributes': 1,
'xmlconcat': 1,
'xmlelement': 1,
'xmlexists': 1,
'xmlforest': 1,
'xmlparse': 1,
'xmlpi': 1,
'xmlroot': 1,
'xmlserialize': 1,
'year': 0,
'yes': 0,
'zone': 0,
}
return (key in keywordDict and keywordDict[key]) or None
return keywordDict.get(key, None)
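Worth noting: `keywordDict.get(key, None)` and `(key in keywordDict and keywordDict[key]) or None` differ for falsy values such as the 0 scores in this dictionary; whether that behavior change is intended is beyond this diff:

    d = {'abort': 0}
    print(('abort' in d and d['abort']) or None)  # None -- 0 is falsy
    print(d.get('abort', None))                   # 0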

View File

@ -27,14 +27,14 @@ from psycopg2.extensions import encodings
# string.
TO_STRING_DATATYPES = (
# To cast bytea, interval type
17, 1186,
# To cast bytea, interval type
17, 1186,
# date, timestamp, timestamptz, bigint, double precision
1700, 1082, 1114, 1184, 20, 701,
# date, timestamp, timestamptz, bigint, double precision
1700, 1082, 1114, 1184, 20, 701,
# real, time without time zone
700, 1083
# real, time without time zone
700, 1083
)
# OIDs of array data types which need to typecast to array of string.
@ -45,17 +45,17 @@ TO_STRING_DATATYPES = (
# data type. e.g: uuid, bit, varbit, etc.
TO_ARRAY_OF_STRING_DATATYPES = (
# To cast bytea[] type
1001,
# To cast bytea[] type
1001,
# bigint[]
1016,
# bigint[]
1016,
# double precision[], real[]
1022, 1021,
# double precision[], real[]
1022, 1021,
# bit[], varbit[]
1561, 1563,
# bit[], varbit[]
1561, 1563,
)
# OID of record array data type
@ -96,7 +96,7 @@ PSYCOPG_SUPPORTED_JSON_TYPES = (114, 3802)
PSYCOPG_SUPPORTED_JSON_ARRAY_TYPES = (199, 3807)
ALL_JSON_TYPES = PSYCOPG_SUPPORTED_JSON_TYPES +\
PSYCOPG_SUPPORTED_JSON_ARRAY_TYPES
PSYCOPG_SUPPORTED_JSON_ARRAY_TYPES
# INET[], CIDR[]
@ -150,10 +150,11 @@ def register_global_typecasters():
# define type caster to convert pg array types of above types into
# array of string type
pg_array_types_to_array_of_string_type = psycopg2.extensions.new_array_type(
TO_ARRAY_OF_STRING_DATATYPES,
'TYPECAST_TO_ARRAY_OF_STRING', pg_types_to_string_type
)
pg_array_types_to_array_of_string_type = \
psycopg2.extensions.new_array_type(
TO_ARRAY_OF_STRING_DATATYPES,
'TYPECAST_TO_ARRAY_OF_STRING', pg_types_to_string_type
)
# This registers a type caster to convert various pg types into string type
psycopg2.extensions.register_type(pg_types_to_string_type)
@ -212,10 +213,11 @@ def register_binary_typecasters(connection):
(
# To cast bytea type
17,
),
),
'BYTEA_PLACEHOLDER',
# Only show placeholder if data actually exists.
lambda value, cursor: 'binary data' if value is not None else None),
lambda value, cursor: 'binary data'
if value is not None else None),
connection
)
@ -224,10 +226,11 @@ def register_binary_typecasters(connection):
(
# To cast bytea[] type
1001,
),
),
'BYTEA_ARRAY_PLACEHOLDER',
# Only show placeholder if data actually exists.
lambda value, cursor: 'binary data[]' if value is not None else None),
lambda value, cursor: 'binary data[]'
if value is not None else None),
connection
)
@ -244,7 +247,3 @@ def register_array_to_string_typecasters(connection):
_STRING),
connection
)
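The registration calls above use psycopg2's typecasting API; a minimal self-contained sketch of the same pattern, casting interval values (OID 1186) to plain strings:

    import psycopg2.extensions

    # Keep the server's textual form instead of a timedelta object.
    INTERVAL_TO_STRING = psycopg2.extensions.new_type(
        (1186,), 'INTERVAL_TO_STRING', lambda value, cursor: value)
    psycopg2.extensions.register_type(INTERVAL_TO_STRING)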

View File

@ -13,10 +13,10 @@ from flask_babel import gettext
def _decorate_cls_name(module_name):
l = len(__package__) + 1
length = len(__package__) + 1
if len(module_name) > l and module_name.startswith(__package__):
return module_name[l:]
if len(module_name) > length and module_name.startswith(__package__):
return module_name[length:]
return module_name

View File

@ -30,6 +30,6 @@ def safe_str(x):
if not IS_PY2:
x = x.decode('utf-8')
except:
except Exception:
pass
return cgi.escape(x)

View File

@ -32,8 +32,10 @@ class JavascriptBundler:
self.jsState = JsState.NEW
except OSError:
webdir_path()
generatedJavascriptDir = os.path.join(webdir_path(), 'pgadmin', 'static', 'js', 'generated')
if os.path.exists(generatedJavascriptDir) and os.listdir(generatedJavascriptDir):
generatedJavascriptDir = os.path.join(
webdir_path(), 'pgadmin', 'static', 'js', 'generated')
if os.path.exists(generatedJavascriptDir) and \
os.listdir(generatedJavascriptDir):
self.jsState = JsState.OLD
else:
self.jsState = JsState.NONE

View File

@ -19,9 +19,10 @@ from pgadmin.utils.route import BaseTestGenerator
from pgadmin.utils.javascript.javascript_bundler import JavascriptBundler
from pgadmin.utils.javascript.javascript_bundler import JsState
class JavascriptBundlerTestCase(BaseTestGenerator):
"""This tests that the javascript bundler tool causes the application to bundle,
and can be invoked before and after app start correctly"""
"""This tests that the javascript bundler tool causes the application to
bundle,and can be invoked before and after app start correctly"""
scenarios = [('scenario name: JavascriptBundlerTestCase', dict())]
@ -59,10 +60,12 @@ class JavascriptBundlerTestCase(BaseTestGenerator):
self.assertEqual(len(self.mockSubprocessCall.method_calls), 0)
self.mockSubprocessCall.return_value = 0
self.mockOs.listdir.return_value = [u'history.js', u'reactComponents.js']
self.mockOs.listdir.return_value = [
u'history.js', u'reactComponents.js']
javascript_bundler.bundle()
self.mockSubprocessCall.assert_called_once_with(['yarn', 'run', 'bundle:dev'])
self.mockSubprocessCall.assert_called_once_with(
['yarn', 'run', 'bundle:dev'])
self.__assertState(javascript_bundler, JsState.NEW)
@ -78,7 +81,8 @@ class JavascriptBundlerTestCase(BaseTestGenerator):
def _bundling_fails_and_there_is_no_existing_bundle(self):
javascript_bundler = JavascriptBundler()
self.mockSubprocessCall.side_effect = OSError("mock exception behavior")
self.mockSubprocessCall.side_effect = OSError(
"mock exception behavior")
self.mockOs.path.exists.return_value = True
self.mockOs.listdir.return_value = []
@ -88,7 +92,8 @@ class JavascriptBundlerTestCase(BaseTestGenerator):
def _bundling_fails_and_there_is_no_existing_bundle_directory(self):
javascript_bundler = JavascriptBundler()
self.mockSubprocessCall.side_effect = OSError("mock exception behavior")
self.mockSubprocessCall.side_effect = OSError(
"mock exception behavior")
self.mockOs.path.exists.return_value = False
self.mockOs.listdir.side_effect = OSError("mock exception behavior")
@ -98,12 +103,15 @@ class JavascriptBundlerTestCase(BaseTestGenerator):
def _bundling_fails_but_there_was_existing_bundle(self):
javascript_bundler = JavascriptBundler()
self.mockSubprocessCall.side_effect = OSError("mock exception behavior")
self.mockSubprocessCall.side_effect = OSError(
"mock exception behavior")
self.mockOs.path.exists.return_value = True
self.mockOs.listdir.return_value = [u'history.js', u'reactComponents.js']
self.mockOs.listdir.return_value = [
u'history.js', u'reactComponents.js']
javascript_bundler.bundle()
self.mockSubprocessCall.assert_called_once_with(['yarn', 'run', 'bundle:dev'])
self.mockSubprocessCall.assert_called_once_with(
['yarn', 'run', 'bundle:dev'])
self.__assertState(javascript_bundler, JsState.OLD)

View File

@ -16,9 +16,11 @@ class MenuItem(object):
class Panel(object):
def __init__(self, name, title, content='', width=500, height=600, isIframe=True,
showTitle=True, isCloseable=True, isPrivate=False, priority=None,
icon=None, data=None, events=None, limit=None, canHide=False):
def __init__(
self, name, title, content='', width=500, height=600, isIframe=True,
showTitle=True, isCloseable=True, isPrivate=False, priority=None,
icon=None, data=None, events=None, limit=None, canHide=False
):
self.name = name
self.title = title
self.content = content

View File

@ -14,7 +14,6 @@ import os
from flask_security import current_user, login_required
@login_required
def get_storage_directory():
import config
@ -38,8 +37,8 @@ def get_storage_directory():
username = 'pga_user_' + username
storage_dir = os.path.join(
storage_dir.decode('utf-8') if hasattr(storage_dir, 'decode') \
else storage_dir,
storage_dir.decode('utf-8') if hasattr(storage_dir, 'decode')
else storage_dir,
username
)
@ -66,11 +65,13 @@ def init_app(app):
if storage_dir and not os.path.isdir(storage_dir):
if os.path.exists(storage_dir):
raise Exception(
'The path specified for the storage directory is not a directory.'
'The path specified for the storage directory is not a '
'directory.'
)
os.makedirs(storage_dir, int('700', 8))
if storage_dir and not os.access(storage_dir, os.W_OK | os.R_OK):
raise Exception(
'The user does not have permission to read and write to the specified storage directory.'
'The user does not have permission to read and write to the '
'specified storage directory.'
)

View File

@ -31,8 +31,8 @@ class _Preference(object):
"""
def __init__(
self, cid, name, label, _type, default, help_str=None, min_val=None,
max_val=None, options=None, select2=None, fields=None
self, cid, name, label, _type, default, help_str=None,
min_val=None, max_val=None, options=None, select2=None, fields=None
):
"""
__init__
@ -111,7 +111,7 @@ class _Preference(object):
# The data stored in the configuration will be in string format, we
# need to convert them in proper format.
if self._type == 'boolean' or self._type == 'switch' or \
self._type == 'node':
self._type == 'node':
return res.value == 'True'
if self._type == 'integer':
try:
@ -162,7 +162,7 @@ class _Preference(object):
# We can't store the values in the given format, we need to convert
# them in string first. We also need to validate the value type.
if self._type == 'boolean' or self._type == 'switch' or \
self._type == 'node':
self._type == 'node':
if type(value) != bool:
return False, gettext("Invalid value for a boolean option.")
elif self._type == 'integer':
@ -518,17 +518,19 @@ class Preferences(object):
if _user_id is None:
return None
cat = PrefCategoryTbl.query.filter_by(mid=module.id).filter_by(name=_category).first()
cat = PrefCategoryTbl.query.filter_by(
mid=module.id).filter_by(name=_category).first()
if cat is None:
return None
pref = PrefTable.query.filter_by(name=_preference).filter_by(cid=cat.id).first()
pref = PrefTable.query.filter_by(
name=_preference).filter_by(cid=cat.id).first()
if pref is None:
return None
user_pref = UserPrefTable.query.filter_by(
user_pref = UserPrefTable.query.filter_by(
pid=pref.id
).filter_by(uid=_user_id).first()

View File

@ -33,8 +33,9 @@ class TestsGeneratorRegistry(ABCMeta):
call this function explicitly. This will be automatically executed,
whenever we create a class and inherit from BaseTestGenerator -
it will register it as an available module in TestsGeneratorRegistry.
By setting the __metaclass__ for BaseTestGenerator to TestsGeneratorRegistry
it will create new instance of this TestsGeneratorRegistry per class.
By setting the __metaclass__ for BaseTestGenerator to
TestsGeneratorRegistry, it will create a new instance of this
TestsGeneratorRegistry per class.
* load_generators():
- This function will load all the modules from __init__()
@ -66,7 +67,9 @@ class TestsGeneratorRegistry(ABCMeta):
for module_name in all_modules:
try:
if "tests." in str(module_name) and not any(
str(module_name).startswith('pgadmin.' + str(exclude_pkg)) for exclude_pkg in exclude_pkgs
str(module_name).startswith(
'pgadmin.' + str(exclude_pkg)
) for exclude_pkg in exclude_pkgs
):
import_module(module_name)
except ImportError:

View File

@ -27,12 +27,12 @@ from uuid import uuid4
try:
from cPickle import dump, load
except:
except ImportError:
from pickle import dump, load
try:
from collections import OrderedDict
except:
except ImportError:
from ordereddict import OrderedDict
from flask.sessions import SessionInterface, SessionMixin
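The narrowed handlers above follow the standard fallback-import idiom; a bare `except:` would also have swallowed exceptions such as SystemExit or KeyboardInterrupt:

    try:
        from cPickle import dump, load   # Python 2's C pickler
    except ImportError:
        from pickle import dump, load    # Python 3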
@ -48,7 +48,8 @@ def _calc_hmac(body, secret):
class ManagedSession(CallbackDict, SessionMixin):
def __init__(self, initial=None, sid=None, new=False, randval=None, hmac_digest=None):
def __init__(self, initial=None, sid=None, new=False, randval=None,
hmac_digest=None):
def on_update(self):
self.modified = True
@ -71,7 +72,8 @@ class ManagedSession(CallbackDict, SessionMixin):
population += string.digits
self.randval = ''.join(random.sample(population, 20))
self.hmac_digest = _calc_hmac('%s:%s' % (self.sid, self.randval), secret)
self.hmac_digest = _calc_hmac(
'%s:%s' % (self.sid, self.randval), secret)
class SessionManager(object):
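For reference, a hypothetical sketch of the HMAC signing used when a session signs itself; the digest and encoding details are assumptions, not pgAdmin's exact implementation:

    import hashlib
    import hmac
    from base64 import b64encode

    def calc_hmac(body, secret):
        # Sign "<sid>:<randval>" with the server-side secret (assumed SHA1).
        return b64encode(hmac.new(secret.encode('utf-8'),
                                  body.encode('utf-8'),
                                  hashlib.sha1).digest()).decode('ascii')

    print(calc_hmac('abc123:r4nd0m', 's3cret'))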
@ -148,7 +150,7 @@ class CachingSessionManager(SessionManager):
if session.sid in self._cache:
try:
del self._cache[session.sid]
except:
except Exception:
pass
self._cache[session.sid] = session
self._normalize()
@ -198,7 +200,7 @@ class FileBackedSessionManager(SessionManager):
try:
with open(fname, 'rb') as f:
randval, hmac_digest, data = load(f)
except:
except Exception:
pass
if not data:
@ -221,8 +223,9 @@ class FileBackedSessionManager(SessionManager):
if not session.hmac_digest:
session.sign(self.secret)
elif not session.force_write:
if session.last_write is not None \
and (current_time - float(session.last_write)) < self.disk_write_delay:
if session.last_write is not None and \
(current_time - float(session.last_write)) < \
self.disk_write_delay:
return
session.last_write = current_time
@ -249,7 +252,7 @@ class ManagedSessionInterface(SessionInterface):
def open_session(self, app, request):
cookie_val = request.cookies.get(app.session_cookie_name)
if not cookie_val or not '!' in cookie_val:
if not cookie_val or '!' not in cookie_val:
# Don't bother creating a cookie for static resources
for sp in self.skip_paths:
if request.path.startswith(sp):
@ -282,9 +285,11 @@ class ManagedSessionInterface(SessionInterface):
session.modified = False
cookie_exp = self.get_expiration_time(app, session)
response.set_cookie(app.session_cookie_name,
'%s!%s' % (session.sid, session.hmac_digest),
expires=cookie_exp, httponly=True, domain=domain)
response.set_cookie(
app.session_cookie_name,
'%s!%s' % (session.sid, session.hmac_digest),
expires=cookie_exp, httponly=True, domain=domain
)
def create_session_interface(app, skip_paths=[]):
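The `not '!' in cookie_val` to `'!' not in cookie_val` change above is pycodestyle's E713; both compile to the same membership test:

    cookie_val = 'sessionid!digest'
    assert (not '!' in cookie_val) == ('!' not in cookie_val)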

View File

@ -94,7 +94,8 @@ class SQLAutoComplete(object):
self.search_path = []
# Fetch the search path
if self.conn.connected():
query = render_template("/".join([self.sql_path, 'schema.sql']), search_path=True)
query = render_template(
"/".join([self.sql_path, 'schema.sql']), search_path=True)
status, res = self.conn.execute_dict(query)
if status:
for record in res['rows']:
@ -117,7 +118,7 @@ class SQLAutoComplete(object):
def escape_name(self, name):
if name and ((not self.name_pattern.match(name)) or
(name.upper() in self.reserved_words)):
(name.upper() in self.reserved_words)):
name = '"%s"' % name
return name
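A self-contained sketch of the escaping above; the regex and reserved-word set are assumptions standing in for the completer's attributes:

    import re

    name_pattern = re.compile(r'^[_a-z][_a-z0-9\$]*$')  # assumed pattern
    reserved_words = {'SELECT', 'TABLE'}

    def escape_name(name):
        # Quote identifiers that are not plain lowercase names or are reserved.
        if name and ((not name_pattern.match(name)) or
                     (name.upper() in reserved_words)):
            name = '"%s"' % name
        return name

    print(escape_name('total'))      # total
    print(escape_name('My Column'))  # "My Column"
    print(escape_name('table'))      # "table"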
@ -219,7 +220,8 @@ class SQLAutoComplete(object):
# with 1 to prioritize shorter strings (ie "user" > "users").
# We also use the unescape_name to make sure quoted names have
# the same priority as unquoted names.
lexical_priority = tuple(-ord(c) for c in self.unescape_name(item)) + (1,)
lexical_priority = tuple(
-ord(c) for c in self.unescape_name(item)) + (1,)
priority = sort_key, priority_func(item), lexical_priority
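The lexical-priority trick described in the comment above can be checked in isolation:

    def lexical_priority(item):
        # Negated code points plus a trailing (1,) rank 'user' above 'users'.
        return tuple(-ord(c) for c in item) + (1,)

    print(sorted(['users', 'user'], key=lexical_priority, reverse=True))
    # ['user', 'users']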
@ -252,7 +254,7 @@ class SQLAutoComplete(object):
for m in matches:
# Escape name only if meta type is not a keyword and datatype.
if m.completion.display_meta != 'keyword' and \
m.completion.display_meta != 'datatype':
m.completion.display_meta != 'datatype':
name = self.escape_name(m.completion.display)
else:
name = m.completion.display
@ -272,20 +274,24 @@ class SQLAutoComplete(object):
in Counter(scoped_cols).items()
if count > 1 and col != '*']
return self.find_matches(word_before_cursor, scoped_cols, mode='strict', meta='column')
return self.find_matches(
word_before_cursor, scoped_cols, mode='strict', meta='column'
)
def get_function_matches(self, suggestion, word_before_cursor):
if suggestion.filter == 'is_set_returning':
# Only suggest set-returning functions
funcs = self.populate_functions(suggestion.schema)
else:
funcs = self.populate_schema_objects(suggestion.schema, 'functions')
funcs = self.populate_schema_objects(
suggestion.schema, 'functions')
# Function overloading means we may have multiple functions of the same
# name at this point, so keep unique names only
funcs = set(funcs)
funcs = self.find_matches(word_before_cursor, funcs, mode='strict', meta='function')
funcs = self.find_matches(
word_before_cursor, funcs, mode='strict', meta='function')
return funcs
@ -303,7 +309,9 @@ class SQLAutoComplete(object):
if not word_before_cursor.startswith('pg_'):
schema_names = [s for s in schema_names if not s.startswith('pg_')]
return self.find_matches(word_before_cursor, schema_names, mode='strict', meta='schema')
return self.find_matches(
word_before_cursor, schema_names, mode='strict', meta='schema'
)
def get_table_matches(self, suggestion, word_before_cursor):
tables = self.populate_schema_objects(suggestion.schema, 'tables')
@ -314,7 +322,9 @@ class SQLAutoComplete(object):
not word_before_cursor.startswith('pg_')):
tables = [t for t in tables if not t.startswith('pg_')]
return self.find_matches(word_before_cursor, tables, mode='strict', meta='table')
return self.find_matches(
word_before_cursor, tables, mode='strict', meta='table'
)
def get_view_matches(self, suggestion, word_before_cursor):
views = self.populate_schema_objects(suggestion.schema, 'views')
@ -323,7 +333,9 @@ class SQLAutoComplete(object):
not word_before_cursor.startswith('pg_')):
views = [v for v in views if not v.startswith('pg_')]
return self.find_matches(word_before_cursor, views, mode='strict', meta='view')
return self.find_matches(
word_before_cursor, views, mode='strict', meta='view'
)
def get_alias_matches(self, suggestion, word_before_cursor):
aliases = suggestion.aliases
@ -350,7 +362,8 @@ class SQLAutoComplete(object):
def get_datatype_matches(self, suggestion, word_before_cursor):
# suggest custom datatypes
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
matches = self.find_matches(word_before_cursor, types, mode='strict', meta='datatype')
matches = self.find_matches(
word_before_cursor, types, mode='strict', meta='datatype')
return matches
@ -366,7 +379,9 @@ class SQLAutoComplete(object):
if self.text_before_cursor[-1:].isspace():
return ''
else:
return self.text_before_cursor[self.find_start_of_previous_word(word=word):]
return self.text_before_cursor[self.find_start_of_previous_word(
word=word
):]
def find_start_of_previous_word(self, count=1, word=False):
"""
@ -418,19 +433,23 @@ class SQLAutoComplete(object):
relname = self.escape_name(tbl.name)
if tbl.is_function:
query = render_template("/".join([self.sql_path, 'functions.sql']),
schema_name=schema,
func_name=relname)
query = render_template(
"/".join([self.sql_path, 'functions.sql']),
schema_name=schema,
func_name=relname
)
if self.conn.connected():
status, res = self.conn.execute_dict(query)
func = None
if status:
for row in res['rows']:
func = FunctionMetadata(row['schema_name'], row['func_name'],
row['arg_list'], row['return_type'],
row['is_aggregate'], row['is_window'],
row['is_set_returning'])
func = FunctionMetadata(
row['schema_name'], row['func_name'],
row['arg_list'], row['return_type'],
row['is_aggregate'], row['is_window'],
row['is_set_returning']
)
if func:
columns.extend(func.fieldnames())
else:
@ -438,77 +457,98 @@ class SQLAutoComplete(object):
# tables and views cannot share the same name, we can check
# one at a time
query = render_template("/".join([self.sql_path, 'columns.sql']),
object_name='table',
schema_name=schema,
rel_name=relname)
query = render_template(
"/".join([self.sql_path, 'columns.sql']),
object_name='table',
schema_name=schema,
rel_name=relname
)
if self.conn.connected():
status, res = self.conn.execute_dict(query)
if status:
if len(res['rows']) > 0:
# Table exists, so don't bother checking for a view
# Table exists, so don't bother checking for a
# view
for record in res['rows']:
columns.append(record['column_name'])
else:
query = render_template("/".join([self.sql_path, 'columns.sql']),
object_name='view',
schema_name=schema,
rel_name=relname)
query = render_template(
"/".join([self.sql_path, 'columns.sql']),
object_name='view',
schema_name=schema,
rel_name=relname
)
if self.conn.connected():
status, res = self.conn.execute_dict(query)
if status:
for record in res['rows']:
columns.append(record['column_name'])
columns.append(
record['column_name'])
else:
# Schema not specified, so traverse the search path looking for
# a table or view that matches. Note that in order to get proper
# shadowing behavior, we need to check both views and tables for
# each schema before checking the next schema
# a table or view that matches. Note that in order to get
# proper shadowing behavior, we need to check both views and
# tables for each schema before checking the next schema
for schema in self.search_path:
relname = self.escape_name(tbl.name)
if tbl.is_function:
query = render_template("/".join([self.sql_path, 'functions.sql']),
schema_name=schema,
func_name=relname)
query = render_template(
"/".join([self.sql_path, 'functions.sql']),
schema_name=schema,
func_name=relname
)
if self.conn.connected():
status, res = self.conn.execute_dict(query)
func = None
if status:
for row in res['rows']:
func = FunctionMetadata(row['schema_name'], row['func_name'],
row['arg_list'], row['return_type'],
row['is_aggregate'], row['is_window'],
row['is_set_returning'])
func = FunctionMetadata(
row['schema_name'], row['func_name'],
row['arg_list'], row['return_type'],
row['is_aggregate'], row['is_window'],
row['is_set_returning']
)
if func:
columns.extend(func.fieldnames())
else:
query = render_template("/".join([self.sql_path, 'columns.sql']),
object_name='table',
schema_name=schema,
rel_name=relname)
query = render_template(
"/".join([self.sql_path, 'columns.sql']),
object_name='table',
schema_name=schema,
rel_name=relname
)
if self.conn.connected():
status, res = self.conn.execute_dict(query)
if status:
if len(res['rows']) > 0:
# Table exists, so don't bother checking for a view
# Table exists, so don't bother checking
# for a view
for record in res['rows']:
columns.append(record['column_name'])
else:
query = render_template("/".join([self.sql_path, 'columns.sql']),
object_name='view',
schema_name=schema,
rel_name=relname)
query = render_template(
"/".join(
[self.sql_path, 'columns.sql']
),
object_name='view',
schema_name=schema,
rel_name=relname
)
if self.conn.connected():
status, res = self.conn.execute_dict(query)
status, res = self.conn.execute_dict(
query
)
if status:
for record in res['rows']:
columns.append(record['column_name'])
columns.append(
record['column_name']
)
return columns
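The search-path traversal above is what produces Postgres-like shadowing. A minimal standalone sketch of that rule, using a hypothetical in-memory catalog in place of the real columns.sql/functions.sql queries (all names here are illustrative):

def resolve_columns(search_path, relname, catalog):
    # Check tables, then views, within each schema before moving to
    # the next schema, so an earlier schema shadows a later one.
    for schema in search_path:
        for kind in ('tables', 'views'):
            cols = catalog.get(schema, {}).get(kind, {}).get(relname)
            if cols is not None:
                return cols
    return []

catalog = {
    'app': {'tables': {'users': ['id', 'email']}, 'views': {}},
    'public': {'tables': {'users': ['id', 'name']}, 'views': {}},
}
print(resolve_columns(['app', 'public'], 'users', catalog))
# -> ['id', 'email']  ("app" shadows "public" for the name "users")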
@ -600,21 +640,23 @@ class SQLAutoComplete(object):
Takes the full_text that is typed so far and also the text before the
cursor to suggest completion type and scope.
Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
A scope for a column category will be a list of tables.
Returns a tuple with a type of entity ('table', 'column', etc.) and a
scope. A scope for a column category will be a list of tables.
Args:
full_text: Contains complete query
text_before_cursor: Contains text before the cursor
"""
word_before_cursor = last_word(text_before_cursor, include='many_punctuations')
word_before_cursor = last_word(
text_before_cursor, include='many_punctuations')
identifier = None
def strip_named_query(txt):
"""
This will strip "save named query" command in the beginning of the line:
This will strip "save named query" command in the beginning of
the line:
'\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
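The body of strip_named_query falls outside this hunk, so the exact pattern is not shown; a minimal sketch of the behaviour the docstring describes, with a hypothetical regex:

import re

# Hypothetical pattern: a leading "\ns <name> " prefix; the module's
# actual regex is not visible in this hunk.
named_query_prefix = re.compile(r'^\s*\\ns\s+\S+\s+')

def strip_named_query_sketch(txt):
    return named_query_prefix.sub('', txt)

print(strip_named_query_sketch('\\ns zzz SELECT * FROM abc'))
# -> 'SELECT * FROM abc'
print(strip_named_query_sketch('  \\ns zzz SELECT * FROM abc'))
# -> 'SELECT * FROM abc'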
@ -630,11 +672,12 @@ class SQLAutoComplete(object):
full_text = strip_named_query(full_text)
text_before_cursor = strip_named_query(text_before_cursor)
# If we've partially typed a word then word_before_cursor won't be an empty
# string. In that case we want to remove the partially typed string before
# sending it to the sqlparser. Otherwise the last token will always be the
# partially typed string which renders the smart completion useless because
# it will always return the list of keywords as completion.
# If we've partially typed a word then word_before_cursor won't be an
# empty string. In that case we want to remove the partially typed
# string before sending it to the sqlparser. Otherwise the last token
# will always be the partially typed string which renders the smart
# completion useless because it will always return the list of
# keywords as completion.
if word_before_cursor:
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
parsed = sqlparse.parse(text_before_cursor)
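The branch that actually slices the fragment off falls outside this hunk; a small sketch of the idea the comment describes, assuming a whitespace-only stand-in for last_word() (the real helper also understands punctuation modes):

def last_word_sketch(text):
    if not text or text[-1:].isspace():
        return ''
    return text.split()[-1]

text_before_cursor = 'SELECT * FROM users WHERE na'
word = last_word_sketch(text_before_cursor)
# Parse everything except the fragment "na", so the parser sees a
# clean token stream ending at WHERE.
to_parse = text_before_cursor[:-len(word)] if word else text_before_cursor
print(repr(to_parse))  # -> 'SELECT * FROM users WHERE '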
@ -649,8 +692,8 @@ class SQLAutoComplete(object):
statement = None
if len(parsed) > 1:
# Multiple statements being edited -- isolate the current one by
# cumulatively summing statement lengths to find the one that bounds the
# current position
# cumulatively summing statement lengths to find the one that
# bounds the current position
current_pos = len(text_before_cursor)
stmt_start, stmt_end = 0, 0
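A worked example of the cumulative-length scan, assuming sqlparse is installed (sqlparse statements round-trip their original text via str(), so the lengths sum back to the full input):

import sqlparse

full_text = 'SELECT 1; SELECT * FROM users; SELECT 3'
current_pos = len('SELECT 1; SELECT * FR')  # cursor inside statement 2

stmt_start = stmt_end = 0
current = None
for stmt in sqlparse.parse(full_text):
    stmt_len = len(str(stmt))
    # Advance the window [stmt_start, stmt_end) over the raw text.
    stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
    if stmt_end >= current_pos:
        current = stmt
        break
print(str(current).strip())  # -> 'SELECT * FROM users;'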
@ -670,12 +713,16 @@ class SQLAutoComplete(object):
# The empty string
statement = None
last_token = statement and statement.token_prev(len(statement.tokens)) or ''
last_token = statement and statement.token_prev(
len(statement.tokens)
) or ''
return self.suggest_based_on_last_token(last_token, text_before_cursor,
full_text, identifier)
return self.suggest_based_on_last_token(
last_token, text_before_cursor, full_text, identifier
)
def suggest_based_on_last_token(self, token, text_before_cursor, full_text, identifier):
def suggest_based_on_last_token(self, token, text_before_cursor, full_text,
identifier):
# New version of sqlparse sends tuple, we need to make it
# compatible with our logic
if isinstance(token, tuple) and len(token) > 1:
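Presumably the hidden body of this branch unpacks the token from the (index, token) pair; a sketch of that shim, where taking index 1 is an assumption based on the comment above:

def normalize_prev_token(token):
    # Newer sqlparse returns (index, token) from token_prev(); older
    # versions return the token alone (index 1 is assumed here).
    if isinstance(token, tuple) and len(token) > 1:
        return token[1]
    return token

print(normalize_prev_token((3, 'FROM')))  # -> 'FROM'
print(normalize_prev_token('FROM'))       # -> 'FROM'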
@ -686,33 +733,37 @@ class SQLAutoComplete(object):
elif isinstance(token, Comparison):
# If 'token' is a Comparison type such as
# 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
# token.value on the comparison type will only return the lhs of the
# comparison. In this case a.id. So we need to do token.tokens to get
# both sides of the comparison and pick the last token out of that
# list.
# token.value on the comparison type will only return the lhs of
# the comparison. In this case a.id. So we need to do token.tokens
# to get both sides of the comparison and pick the last token out
# of that list.
token_v = token.tokens[-1].value.lower()
elif isinstance(token, Where):
# sqlparse groups all tokens from the where clause into a single token
# list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
# sqlparse groups all tokens from the where clause into a single
# token list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to
# handle suggestions in complicated where clauses correctly
prev_keyword, text_before_cursor = find_prev_keyword(
text_before_cursor
)
return self.suggest_based_on_last_token(
prev_keyword, text_before_cursor, full_text, identifier)
prev_keyword, text_before_cursor, full_text, identifier
)
elif isinstance(token, Identifier):
# If the previous token is an identifier, we can suggest datatypes if
# we're in a parenthesized column/field list, e.g.:
# If the previous token is an identifier, we can suggest datatypes
# if we're in a parenthesized column/field list, e.g.:
# CREATE TABLE foo (Identifier <CURSOR>
# CREATE FUNCTION foo (Identifier <CURSOR>
# If we're not in a parenthesized list, the most likely scenario is the
# user is about to specify an alias, e.g.:
# If we're not in a parenthesized list, the most likely scenario
# is the user is about to specify an alias, e.g.:
# SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR>
prev_keyword, _ = find_prev_keyword(text_before_cursor)
if prev_keyword and prev_keyword.value == '(':
# Suggest datatypes
return self.suggest_based_on_last_token(
'type', text_before_cursor, full_text, identifier)
'type', text_before_cursor, full_text, identifier
)
else:
return Keyword(),
else:
@ -732,11 +783,13 @@ class SQLAutoComplete(object):
# 3 - Subquery expression like "WHERE EXISTS ("
# Suggest keywords, in order to do a subquery
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
# Suggest columns/functions AND keywords. (If we wanted to be
# really fancy, we could suggest only array-typed columns)
# Suggest columns/functions AND keywords. (If we wanted
# to be really fancy, we could suggest only array-typed
# columns)
column_suggestions = self.suggest_based_on_last_token(
'where', text_before_cursor, full_text, identifier)
'where', text_before_cursor, full_text, identifier
)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
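A loose, string-based sketch of the four cases above; the real code inspects sqlparse tokens rather than string suffixes, and the returned labels here are illustrative stand-ins for the suggestion objects:

def classify_where_paren(text_before_cursor):
    t = text_before_cursor.lower().rstrip()
    if t.endswith(('exists (', 'exists(')):
        return ('keyword',)              # case 3: subquery expression
    if t.endswith(('any(', 'any (', 'some(', 'some (',
                   'all(', 'all (')):
        return ('column', 'keyword')     # case 4: subquery OR array
    return ('column',)                   # cases 1-2: value or function

print(classify_where_paren('SELECT * FROM t WHERE id = ANY('))
# -> ('column', 'keyword')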
@ -754,7 +807,8 @@ class SQLAutoComplete(object):
# Get the token before the parens
prev_tok = p.token_prev(len(p.tokens) - 1)
if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
if prev_tok and prev_tok.value and \
prev_tok.value.lower() == 'using':
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = extract_tables(full_text)
@ -762,8 +816,8 @@ class SQLAutoComplete(object):
return Column(tables=tables, drop_unique=True),
elif p.token_first().value.lower() == 'select':
# If the lparen is preceeded by a space chances are we're about to
# do a sub-select.
# If the lparen is preceded by a space, chances are we're
# about to do a sub-select.
if last_word(text_before_cursor,
'all_punctuations').startswith('('):
return Keyword(),
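In the USING branch a few lines up, drop_unique=True amounts to suggesting only the columns shared by the joined tables. A toy sketch of that filtering over a hypothetical column catalog:

def using_clause_candidates(columns_by_table):
    # Keep only columns that appear in more than one joined table.
    seen, shared = set(), []
    for cols in columns_by_table:
        for col in cols:
            if col in seen and col not in shared:
                shared.append(col)
            seen.add(col)
    return shared

print(using_clause_candidates([['id', 'name', 'org_id'],
                               ['id', 'org_id', 'created']]))
# -> ['id', 'org_id']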
@ -788,7 +842,8 @@ class SQLAutoComplete(object):
Keyword(),)
elif (token_v.endswith('join') and token.is_keyword) or \
(token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate')):
(token_v in ('copy', 'from', 'update', 'into',
'describe', 'truncate')):
schema = (identifier and identifier.get_parent_name()) or None
@ -805,14 +860,18 @@ class SQLAutoComplete(object):
suggest.append(View(schema=schema))
# Suggest set-returning functions in the FROM clause
if token_v == 'from' or (token_v.endswith('join') and token.is_keyword):
suggest.append(Function(schema=schema, filter='is_set_returning'))
if token_v == 'from' or \
(token_v.endswith('join') and token.is_keyword):
suggest.append(
Function(schema=schema, filter='is_set_returning')
)
return tuple(suggest)
elif token_v in ('table', 'view', 'function'):
# E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablename>'
rel_type = {'table': Table, 'view': View, 'function': Function}[token_v]
rel_type = {'table': Table, 'view': View,
'function': Function}[token_v]
schema = (identifier and identifier.get_parent_name()) or None
if schema:
return rel_type(schema=schema),
@ -843,7 +902,8 @@ class SQLAutoComplete(object):
# DROP SCHEMA schema_name
return Schema(),
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
prev_keyword, text_before_cursor = find_prev_keyword(
text_before_cursor)
if prev_keyword:
return self.suggest_based_on_last_token(
prev_keyword, text_before_cursor, full_text, identifier)
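The effect of this rewind is that connector tokens defer to the governing keyword's rules. A toy trace, with a hypothetical rule table standing in for the dispatch above:

rules = {'where': 'columns', 'select': 'columns', 'from': 'tables'}

def suggest_sketch(last_token, prev_keyword):
    if last_token in (',', '=', 'and', 'or'):
        # e.g. 'WHERE a = 1 AND ' behaves as if we had just seen WHERE
        return rules.get(prev_keyword, 'keywords')
    return rules.get(last_token, 'keywords')

print(suggest_sketch('and', 'where'))  # -> 'columns'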

View File

@ -1,6 +1,7 @@
"""
Using Completion class from
https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/completion.py
https://github.com/jonathanslenders/python-prompt-toolkit/
blob/master/prompt_toolkit/completion.py
"""
from __future__ import unicode_literals
@ -50,7 +51,9 @@ class Completion(object):
self.display_meta == other.display_meta)
def __hash__(self):
return hash((self.text, self.start_position, self.display, self.display_meta))
return hash(
(self.text, self.start_position, self.display, self.display_meta)
)
@property
def display_meta(self):

View File

@ -27,10 +27,10 @@ class Counter(dict):
from an input iterable. Or, initialize the count from another mapping
of elements to their counts.
>>> c = Counter() # a new, empty counter
>>> c = Counter('gallahad') # a new counter from an iterable
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
>>> c = Counter(a=4, b=2) # a new counter from keyword args
>>> c = Counter() # a new, empty counter
>>> c = Counter('gallahad') # a new counter from an iterable
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
>>> c = Counter(a=4, b=2) # a new counter from keyword args
'''
self.update(iterable, **kwds)
@ -70,7 +70,8 @@ class Counter(dict):
@classmethod
def fromkeys(cls, iterable, v=None):
raise NotImplementedError(
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.'
)
def update(self, iterable=None, **kwds):
'''Like dict.update() but add counts instead of replacing them.
@ -92,7 +93,8 @@ class Counter(dict):
for elem, count in iterable.iteritems():
self[elem] = self_get(elem, 0) + count
else:
dict.update(self, iterable) # fast path when counter is empty
# fast path when counter is empty
dict.update(self, iterable)
else:
self_get = self.get
for elem in iterable:
@ -105,7 +107,8 @@ class Counter(dict):
return Counter(self)
def __delitem__(self, elem):
'Like dict.__delitem__() but does not raise KeyError for missing values.'
"""Like dict.__delitem__() but does not raise KeyError for missing
values."""
if elem in self:
dict.__delitem__(self, elem)

View File

@ -7,8 +7,8 @@ table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
class FunctionMetadata(object):
def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate,
is_window, is_set_returning):
def __init__(self, schema_name, func_name, arg_list, return_type,
is_aggregate, is_window, is_set_returning):
"""Class for describing a postgresql function"""
self.schema_name = schema_name
@ -20,8 +20,8 @@ class FunctionMetadata(object):
self.is_set_returning = is_set_returning
def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
@ -32,11 +32,13 @@ class FunctionMetadata(object):
self.is_set_returning))
def __repr__(self):
return (('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
% (self.__class__.__name__, self.schema_name, self.func_name,
self.arg_list, self.return_type, self.is_aggregate,
self.is_window, self.is_set_returning))
return (
('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
% (self.__class__.__name__, self.schema_name, self.func_name,
self.arg_list, self.return_type, self.is_aggregate,
self.is_window, self.is_set_returning)
)
def fieldnames(self):
"""Returns a list of output field names"""
@ -130,7 +132,8 @@ def parse_typed_field_list(tokens):
else:
field[parse_state].append(tok)
# Final argument won't be followed by a comma, so make sure it gets yielded
# Final argument won't be followed by a comma, so make sure it gets
# yielded
if field.type:
yield field

View File

@ -65,8 +65,9 @@ def last_word(text, include='alphanum_underscore'):
return ''
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
TableReference = namedtuple(
'TableReference', ['schema', 'name', 'alias', 'is_function']
)
# This code is borrowed from a sqlparse example script.
@ -74,9 +75,9 @@ TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
def is_subselect(parsed):
if not parsed.is_group():
return False
sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
for item in parsed.tokens:
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
'UPDATE', 'CREATE', 'DELETE'):
if item.ttype is DML and item.value.upper() in sql_type:
return True
return False
@ -95,13 +96,13 @@ def extract_from_part(parsed, stop_at_punctuation=True):
elif stop_at_punctuation and item.ttype is Punctuation:
raise StopIteration
# An incomplete nested select won't be recognized correctly as a
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
# the second FROM to trigger this elif condition resulting in a
# StopIteration. So we need to ignore the keyword if the keyword
# FROM.
# sub-select. e.g.: 'SELECT * FROM (SELECT id FROM user'. This
# causes the second FROM to trigger this elif condition, resulting
# in a StopIteration. So we need to ignore the keyword if it is
# FROM.
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
# condition. So we need to ignore the keyword JOIN and its
# variants INNER JOIN, FULL OUTER JOIN, etc.
elif item.ttype is Keyword and (
not item.value.upper() == 'FROM') and (
not item.value.upper().endswith('JOIN')):
@ -118,7 +119,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
elif isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
if (identifier.ttype is Keyword and
identifier.value.upper() == 'FROM'):
identifier.value.upper() == 'FROM'):
tbl_prefix_seen = True
break
@ -139,23 +140,28 @@ def extract_table_identifiers(token_stream, allow_functions=True):
except AttributeError:
continue
if real_name:
yield TableReference(schema_name, real_name,
identifier.get_alias(), is_function)
yield TableReference(
schema_name, real_name, identifier.get_alias(),
is_function
)
elif isinstance(item, Identifier):
real_name = item.get_real_name()
schema_name = item.get_parent_name()
is_function = allow_functions and _identifier_is_function(item)
if real_name:
yield TableReference(schema_name, real_name, item.get_alias(),
is_function)
yield TableReference(
schema_name, real_name, item.get_alias(), is_function
)
else:
name = item.get_name()
yield TableReference(None, name, item.get_alias() or name,
is_function)
yield TableReference(
None, name, item.get_alias() or name, is_function
)
elif isinstance(item, Function):
yield TableReference(None, item.get_real_name(), item.get_alias(),
allow_functions)
yield TableReference(
None, item.get_real_name(), item.get_alias(), allow_functions
)
# extract_tables is inspired from examples in the sqlparse lib.
@ -181,8 +187,9 @@ def extract_tables(sql):
# "insert into foo (bar, baz)" as a function call to foo with arguments
# (bar, baz). So don't allow any identifiers in insert statements
# to have is_function=True
identifiers = extract_table_identifiers(stream,
allow_functions=not insert_stmt)
identifiers = extract_table_identifiers(
stream, allow_functions=not insert_stmt
)
return tuple(identifiers)
@ -202,11 +209,11 @@ def find_prev_keyword(sql):
for t in reversed(flattened):
if t.value == '(' or (t.is_keyword and (
t.value.upper() not in logical_operators)):
t.value.upper() not in logical_operators)):
# Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index thows an error
# Minimal example:
# We can't use parsed.token_index(t) because t may be a child
# token inside a TokenList, in which case token_index throws an
# error. Minimal example:
# p = sqlparse.parse('select * from foo where bar')
# t = list(p.flatten())[-3] # The "Where" token
# p.token_index(t) # Throws ValueError: not in list
@ -242,13 +249,13 @@ def _parsed_is_open_quote(parsed):
if tok.match(Token.Error, "'"):
# An unmatched single quote
return True
elif (tok.ttype in Token.Name.Builtin
and dollar_quote_regex.match(tok.value)):
elif (tok.ttype in Token.Name.Builtin and
dollar_quote_regex.match(tok.value)):
# Find the matching closing dollar quote sign
for (j, tok2) in enumerate(tokens[i + 1:], i + 1):
if tok2.match(Token.Name.Builtin, tok.value):
# Found the matching closing quote - continue our scan for
# open quotes thereafter
# Found the matching closing quote - continue our scan
# for open quotes thereafter
i = j
break
else:
@ -285,4 +292,4 @@ def parse_partial_identifier(word):
if __name__ == '__main__':
sql = 'select * from (select t. from tabl t'
print (extract_tables(sql))
print(extract_tables(sql))
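The dollar-quote scan in _parsed_is_open_quote above can be approximated with plain string matching; a sketch (the real code walks sqlparse tokens and also handles unmatched single quotes):

import re

def has_open_dollar_quote(sql):
    markers = re.findall(r'\$[A-Za-z_]*\$', sql)
    i = 0
    while i < len(markers):
        opener = markers[i]
        try:
            i = markers.index(opener, i + 1) + 1  # skip past the closer
        except ValueError:
            return True  # opener is never closed
    return False

print(has_open_dollar_quote('SELECT $fn$ BEGIN'))              # True
print(has_open_dollar_quote('SELECT $fn$ BEGIN SELECT $fn$'))  # False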

View File

@ -19,14 +19,39 @@ from pgadmin.utils.route import BaseTestGenerator
class TestVersionedTemplateLoader(BaseTestGenerator):
scenarios = [
("Render a template when called", dict(scenario=1)),
("Render a version 9.1 template when it is present", dict(scenario=2)),
("Render a version 9.2 template when request for a higher version", dict(scenario=3)),
("Render default version when version 9.0 was requested and only 9.1 and 9.2 are present", dict(scenario=4)),
("Raise error when version is smaller than available templates", dict(scenario=5)),
("Render a version GPDB 5.0 template when it is present", dict(scenario=6)),
("Render a version GPDB 5.0 template when it is in default", dict(scenario=7)),
("Raise error when version is gpdb but template does not exist", dict(scenario=8))
(
"Render a template when called",
dict(scenario=1)
),
(
"Render a version 9.1 template when it is present",
dict(scenario=2)
),
(
"Render a version 9.2 template when request for a higher version",
dict(scenario=3)
),
(
"Render default version when version 9.0 was requested and only "
"9.1 and 9.2 are present",
dict(scenario=4)
),
(
"Raise error when version is smaller than available templates",
dict(scenario=5)
),
(
"Render a version GPDB 5.0 template when it is present",
dict(scenario=6)
),
(
"Render a version GPDB 5.0 template when it is in default",
dict(scenario=7)
),
(
"Raise error when version is gpdb but template does not exist",
dict(scenario=8)
)
]
def setUp(self):
@ -36,86 +61,150 @@ class TestVersionedTemplateLoader(BaseTestGenerator):
if self.scenario == 1:
self.test_get_source_returns_a_template()
if self.scenario == 2:
self.test_get_source_when_the_version_is_9_1_returns_9_1_template()
# test_get_source_when_the_version_is_9_1_returns_9_1_template
self.test_get_source_when_the_version_is_9_1()
if self.scenario == 3:
self.test_get_source_when_the_version_is_9_3_and_there_are_templates_for_9_2_and_9_1_returns_9_2_template()
# test_get_source_when_the_version_is_9_3_and_there_are_templates_
# for_9_2_and_9_1_returns_9_2_template
self.test_get_source_when_the_version_is_9_3()
if self.scenario == 4:
self.test_get_source_when_version_is_9_0_and_there_are_templates_for_9_1_and_9_2_returns_default_template()
# test_get_source_when_the_version_is_9_0_and_there_are_templates_
# for_9_1_and_9_2_returns_default_template
self.test_get_source_when_the_version_is_9_0()
if self.scenario == 5:
self.test_raise_not_found_exception_when_postgres_version_less_than_all_available_sql_templates()
# test_raise_not_found_exception_when_postgres_version_less_than_
# all_available_sql_templates
self.test_raise_not_found_exception()
if self.scenario == 6:
self.test_get_source_when_the_version_is_gpdb_5_0_returns_gpdb_5_0_template()
# test_get_source_when_the_version_is_gpdb_5_0_returns_gpdb_5_0_
# template
self.test_get_source_when_the_version_is_gpdb_5_0()
if self.scenario == 7:
self.test_get_source_when_the_version_is_gpdb_5_0_returns_default_template()
# test_get_source_when_the_version_is_gpdb_5_0_returns_default_
# template
self.test_get_source_when_the_version_is_gpdb_5_0_returns_default()
if self.scenario == 8:
self.test_raise_not_found_exception_when_the_version_is_gpdb_template_not_exist()
# test_raise_not_found_exception_when_the_version_is_gpdb_template
# _not_exist
self.test_raise_not_found_exception_when_the_version_is_gpdb()
def test_get_source_returns_a_template(self):
expected_content = "Some SQL" \
"\nsome more stuff on a new line\n"
# For cross platform we join the SQL path (This solves the slashes issue)
sql_path = os.path.join("some_feature", "sql", "9.1_plus", "some_action.sql")
content, filename, up_to_dateness = self.loader.get_source(None, "some_feature/sql/9.1_plus/some_action.sql")
self.assertEqual(expected_content, str(content).replace("\r",""))
# For cross-platform compatibility we join the SQL path
# (This solves the slashes issue)
sql_path = os.path.join(
"some_feature", "sql", "9.1_plus", "some_action.sql"
)
content, filename, up_to_dateness = self.loader.get_source(
None, "some_feature/sql/9.1_plus/some_action.sql"
)
self.assertEqual(
expected_content, str(content).replace("\r", "")
)
self.assertIn(sql_path, filename)
def test_get_source_when_the_version_is_9_1_returns_9_1_template(self):
def test_get_source_when_the_version_is_9_1(self):
"""Render a version 9.1 template when it is present"""
expected_content = "Some SQL" \
"\nsome more stuff on a new line\n"
# For cross platform we join the SQL path (This solves the slashes issue)
sql_path = os.path.join("some_feature", "sql", "9.1_plus", "some_action.sql")
content, filename, up_to_dateness = self.loader.get_source(None, "some_feature/sql/#90100#/some_action.sql")
# For cross-platform compatibility we join the SQL path
# (This solves the slashes issue)
sql_path = os.path.join(
"some_feature", "sql", "9.1_plus", "some_action.sql"
)
content, filename, up_to_dateness = self.loader.get_source(
None, "some_feature/sql/#90100#/some_action.sql"
)
self.assertEqual(expected_content, str(content).replace("\r",""))
self.assertEqual(
expected_content, str(content).replace("\r", "")
)
self.assertIn(sql_path, filename)
def test_get_source_when_the_version_is_9_3_and_there_are_templates_for_9_2_and_9_1_returns_9_2_template(self):
# For cross platform we join the SQL path (This solves the slashes issue)
sql_path = os.path.join("some_feature", "sql", "9.2_plus", "some_action.sql")
content, filename, up_to_dateness = self.loader.get_source(None, "some_feature/sql/#90300#/some_action.sql")
def test_get_source_when_the_version_is_9_3(self):
"""Render a version 9.2 template when request for a higher version"""
# For cross-platform compatibility we join the SQL path
# (This solves the slashes issue)
sql_path = os.path.join(
"some_feature", "sql", "9.2_plus", "some_action.sql"
)
content, filename, up_to_dateness = self.loader.get_source(
None, "some_feature/sql/#90300#/some_action.sql"
)
self.assertEqual("Some 9.2 SQL", str(content).replace("\r",""))
self.assertEqual(
"Some 9.2 SQL", str(content).replace("\r", "")
)
self.assertIn(sql_path, filename)
def test_get_source_when_version_is_9_0_and_there_are_templates_for_9_1_and_9_2_returns_default_template(self):
# For cross platform we join the SQL path (This solves the slashes issue)
sql_path = os.path.join("some_feature", "sql", "default", "some_action_with_default.sql")
def test_get_source_when_the_version_is_9_0(self):
"""Render default version when version 9.0 was requested and only
9.1 and 9.2 are present"""
# For cross-platform compatibility we join the SQL path
# (This solves the slashes issue)
sql_path = os.path.join("some_feature", "sql",
"default", "some_action_with_default.sql")
content, filename, up_to_dateness = self.loader.get_source(
None,
"some_feature/sql/#90000#/some_action_with_default.sql")
self.assertEqual("Some default SQL", str(content).replace("\r",""))
self.assertEqual("Some default SQL", str(content).replace("\r", ""))
self.assertIn(sql_path, filename)
def test_raise_not_found_exception_when_postgres_version_less_than_all_available_sql_templates(self):
def test_raise_not_found_exception(self):
"""Raise error when version is smaller than available templates"""
try:
self.loader.get_source(None, "some_feature/sql/#10100#/some_action.sql")
self.loader.get_source(
None, "some_feature/sql/#10100#/some_action.sql"
)
self.fail("No exception raised")
except TemplateNotFound:
return
def test_get_source_when_the_version_is_gpdb_5_0_returns_gpdb_5_0_template(self):
def test_get_source_when_the_version_is_gpdb_5_0(self):
"""Render a version GPDB 5.0 template when it is present"""
expected_content = "Some default SQL for GPDB\n"
# For cross platform we join the SQL path (This solves the slashes issue)
sql_path = os.path.join("some_feature", "sql", "gpdb_5.0_plus", "some_action_with_gpdb_5_0.sql")
content, filename, up_to_dateness = self.loader.get_source(None, "some_feature/sql/#gpdb#80323#/some_action_with_gpdb_5_0.sql")
# For cross-platform compatibility we join the SQL path
# (This solves the slashes issue)
sql_path = os.path.join(
"some_feature", "sql", "gpdb_5.0_plus",
"some_action_with_gpdb_5_0.sql"
)
content, filename, up_to_dateness = self.loader.get_source(
None,
"some_feature/sql/#gpdb#80323#/some_action_with_gpdb_5_0.sql"
)
self.assertEqual(expected_content, str(content).replace("\r", ""))
self.assertEqual(
expected_content, str(content).replace("\r", "")
)
self.assertIn(sql_path, filename)
def test_get_source_when_the_version_is_gpdb_5_0_returns_default_template(self):
def test_get_source_when_the_version_is_gpdb_5_0_returns_default(self):
"""Render a version GPDB 5.0 template when it is in default"""
expected_content = "Some default SQL"
# For cross platform we join the SQL path (This solves the slashes issue)
sql_path = os.path.join("some_feature", "sql", "default", "some_action_with_default.sql")
content, filename, up_to_dateness = self.loader.get_source(None, "some_feature/sql/#gpdb#80323#/some_action_with_default.sql")
# For cross-platform compatibility we join the SQL path
# (This solves the slashes issue)
sql_path = os.path.join(
"some_feature", "sql", "default", "some_action_with_default.sql"
)
content, filename, up_to_dateness = self.loader.get_source(
None, "some_feature/sql/#gpdb#80323#/some_action_with_default.sql"
)
self.assertEqual(expected_content, str(content).replace("\r", ""))
self.assertEqual(
expected_content, str(content).replace("\r", "")
)
self.assertIn(sql_path, filename)
def test_raise_not_found_exception_when_the_version_is_gpdb_template_not_exist(self):
def test_raise_not_found_exception_when_the_version_is_gpdb(self):
""""Raise error when version is gpdb but template does not exist"""
try:
self.loader.get_source(None, "some_feature/sql/#gpdb#50100#/some_action.sql")
self.loader.get_source(
None, "some_feature/sql/#gpdb#50100#/some_action.sql"
)
self.fail("No exception raised")
except TemplateNotFound:
return
@ -124,4 +213,6 @@ class TestVersionedTemplateLoader(BaseTestGenerator):
class FakeApp(Flask):
def __init__(self):
super(FakeApp, self).__init__("")
self.jinja_loader = FileSystemLoader(os.path.dirname(os.path.realpath(__file__)) + "/templates")
self.jinja_loader = FileSystemLoader(
os.path.dirname(os.path.realpath(__file__)) + "/templates"
)

View File

@ -35,22 +35,29 @@ class VersionedTemplateLoader(DispatchingJinjaLoader):
server_versions = postgres_versions
if len(template_path_parts) == 1:
return super(VersionedTemplateLoader, self).get_source(environment, template)
return super(VersionedTemplateLoader, self).get_source(
environment, template
)
if len(template_path_parts) == 4:
path_start, server_type, specified_version_number, file_name = template_path_parts
path_start, server_type, specified_version_number, file_name = \
template_path_parts
if server_type == 'gpdb':
server_versions = gpdb_versions
else:
path_start, specified_version_number, file_name = template_path_parts
path_start, specified_version_number, file_name = \
template_path_parts
for server_version in server_versions:
if server_version['number'] > int(specified_version_number):
continue
template_path = path_start + '/' + server_version['name'] + '/' + file_name
template_path = path_start + '/' + \
server_version['name'] + '/' + file_name
try:
return super(VersionedTemplateLoader, self).get_source(environment, template_path)
return super(VersionedTemplateLoader, self).get_source(
environment, template_path
)
except TemplateNotFound:
continue
raise TemplateNotFound(template)
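A sketch of the resolution order this loop implements, assuming server_versions is listed newest-first with a 'number' and a directory 'name', consistent with the tests in the previous file:

postgres_versions = [
    {'number': 90200, 'name': '9.2_plus'},
    {'number': 90100, 'name': '9.1_plus'},
    {'number': 0, 'name': 'default'},
]

def resolve(path_start, specified_version_number, file_name, available):
    for server_version in postgres_versions:
        if server_version['number'] > int(specified_version_number):
            continue  # this template set is too new for the server
        candidate = '/'.join(
            [path_start, server_version['name'], file_name])
        if candidate in available:  # stands in for get_source()
            return candidate
    raise LookupError(file_name)    # TemplateNotFound in the real code

available = {'some_feature/sql/9.1_plus/some_action.sql',
             'some_feature/sql/default/some_action.sql'}
print(resolve('some_feature/sql', '90300', 'some_action.sql', available))
# -> 'some_feature/sql/9.1_plus/some_action.sql'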