mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2024-12-02 13:29:11 -06:00
391 lines
12 KiB
Python
391 lines
12 KiB
Python
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2019, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################

"""
|
|
Implementation of Driver class
|
|
It is a wrapper around the actual psycopg2 driver, and connection
|
|
object.
|
|
|
|
"""
|
|
import datetime
|
|
from flask import session
|
|
from flask_babelex import gettext
|
|
import psycopg2
|
|
from psycopg2.extensions import adapt
|
|
|
|
import config
|
|
from pgadmin.model import Server, User
|
|
from .keywords import ScanKeyword
|
|
from ..abstract import BaseDriver
|
|
from .connection import Connection
|
|
from .server_manager import ServerManager
|
|
|
|
|
|
class Driver(BaseDriver):
    """
    This driver acts as a wrapper around the psycopg2 connection driver
    implementation. We use psycopg2 for making connections with
    PostgreSQL/EDB Postgres Advanced Server (EnterpriseDB).

    Methods:
    -------

    * Version()
      - Returns the version string of the loaded psycopg2 driver.

    * libpq_version()
      - Returns the version of the loaded libpq library.

    * get_connection(sid, database, conn_id, auto_reconnect)
      - Returns a Connection class object, which may/may not be connected
        to the database server for this session.

    * release_connection(sid, database, conn_id)
      - Releases the connection object for the given conn_id/database for
        this session.

    * connection_manager(sid)
      - Returns the server connection manager for this session.

    * delete_manager(sid), gc(), gc_own()
      - Release server managers / connections (see each method).

    * qtLiteral / qtIdent / qtTypeIdent / needsQuoting
      - SQL literal and identifier quoting helpers.
    """

    # Extra words treated as keywords in addition to ScanKeyword().
    # Values use the same categories as ScanKeyword:
    #   UNRESERVED_KEYWORD 0, COL_NAME_KEYWORD 1,
    #   TYPE_FUNC_NAME_KEYWORD 2, RESERVED_KEYWORD 3
    # NOTE(review): these look like words from other SQL dialects
    # (presumably for EPAS compatibility) -- confirm. 'systimestap' also
    # looks like a typo of 'systimestamp'; kept as-is to avoid changing
    # which identifiers get quoted.
    _EXTRA_KEYWORDS = {
        'connect': 3,
        'convert': 3,
        'distributed': 0,
        'exec': 3,
        'log': 0,
        'long': 3,
        'minus': 3,
        'nocache': 3,
        'number': 3,
        'package': 3,
        'pls_integer': 3,
        'raw': 3,
        'return': 3,
        'smalldatetime': 3,
        'smallfloat': 3,
        'smallmoney': 3,
        'sysdate': 3,
        'systimestap': 3,
        'tinyint': 3,
        'tinytext': 3,
        'varchar2': 3
    }

    def __init__(self, **kwargs):
        # Maps a Flask session id to a dict of
        # {str(server_id): ServerManager, 'pinged': datetime}.
        self.managers = dict()

        super(Driver, self).__init__()

    def connection_manager(self, sid=None):
        """
        connection_manager(...)

        Returns the ServerManager object for the given server in the current
        session. It will create a new ServerManager object (if necessary).

        On the first call for a Flask session, any managers stored in
        session['__pgsql_server_managers'] are restored, and that session
        key is reset to an empty dict.

        Parameters:
            sid
            - Server ID (must be an int)

        Returns:
            The ServerManager for sid, or None if no Server row exists
            for sid.
        """
        assert (sid is not None and isinstance(sid, int))

        if session.sid not in self.managers:
            self.managers[session.sid] = managers = dict()
            if '__pgsql_server_managers' in session:
                session_managers = session['__pgsql_server_managers'].copy()
                session['__pgsql_server_managers'] = dict()

                for server_id in session_managers:
                    s = Server.query.filter_by(id=server_id).first()

                    if not s:
                        # No Server row for this id any more; skip it.
                        continue

                    manager = managers[str(server_id)] = ServerManager(s)
                    manager._restore(session_managers[server_id])
                    manager.update_session()
        else:
            managers = self.managers[session.sid]

        # Last-activity timestamp consulted by gc().
        managers['pinged'] = datetime.datetime.now()

        if str(sid) not in managers:
            s = Server.query.filter_by(id=sid).first()

            if not s:
                return None

            managers[str(sid)] = ServerManager(s)

        return managers[str(sid)]

    def Version(cls):
        """
        Version(...)

        Returns the current version of the psycopg2 driver.

        Note: despite the ``cls`` parameter name this is a regular instance
        method.

        Raises:
            Exception if psycopg2 does not expose ``__version__``.
        """
        version = getattr(psycopg2, '__version__', None)

        if version:
            return version

        raise Exception(
            "Driver Version information for psycopg2 is not available!"
        )

    def libpq_version(cls):
        """
        Returns the loaded libpq version (psycopg2.__libpq_version__).

        Raises:
            Exception if psycopg2 does not expose ``__libpq_version__``.
        """
        version = getattr(psycopg2, '__libpq_version__', None)
        if version:
            return version

        raise Exception(
            "libpq version information is not available!"
        )

    def get_connection(
        self, sid, database=None, conn_id=None, auto_reconnect=True
    ):
        """
        get_connection(...)

        Returns the connection object for the given connection-id/database
        for the specific server, identified by sid. Creates a new Connection
        object (if required).

        Parameters:
            sid
            - Server ID
            database
            - Database on which the connection needs to be made.
              If None, the server's maintenance_db will be used when
              generating a new Connection object.
            conn_id
            - Identification string for the connection. This is used by
              tools which require a dedicated connection,
              e.g. Debugger, Query Tool, etc.
            auto_reconnect
            - Whether we should attempt to reconnect to the database server
              automatically when the connection has been lost for any
              reason. Certain tools like the Debugger require a permanent
              connection and stop working on disconnection.
        """
        manager = self.connection_manager(sid)

        return manager.connection(database, conn_id, auto_reconnect)

    def release_connection(self, sid, database=None, conn_id=None):
        """
        Release the connection for the given connection-id/database in this
        session. Returns whatever ServerManager.release() returns.
        """
        return self.connection_manager(sid).release(database, conn_id)

    def delete_manager(self, sid):
        """
        Delete the manager for the given server id.

        Releases the manager's connections (if a manager exists) and removes
        it from this session's manager dict.
        """
        manager = self.connection_manager(sid)
        if manager is not None:
            manager.release()
        if session.sid in self.managers and \
                str(sid) in self.managers[session.sid]:
            del self.managers[session.sid][str(sid)]

    def gc(self):
        """
        Release the connections for the sessions which have not pinged the
        server for more than config.MAX_SESSION_IDLE_TIME minutes (60 when
        unset, never less than 20).

        The current session is never collected; its 'pinged' timestamp is
        refreshed instead. Released managers stay in self.managers.
        """

        # Minimum session idle is 20 minutes
        max_idle_time = max(config.MAX_SESSION_IDLE_TIME or 60, 20)
        session_idle_timeout = datetime.timedelta(minutes=max_idle_time)

        curr_time = datetime.datetime.now()

        for sess in self.managers:
            sess_mgr = self.managers[sess]

            if sess == session.sid:
                sess_mgr['pinged'] = curr_time
                continue
            if curr_time - sess_mgr['pinged'] >= session_idle_timeout:
                # Snapshot the managers first; sess_mgr also holds the
                # non-manager 'pinged' entry, which is filtered out here.
                for mgr in [
                    m for m in sess_mgr.values() if isinstance(m,
                                                               ServerManager)
                ]:
                    mgr.release()

    def gc_own(self):
        """
        Release the connections for the current session.

        This is useful when (e.g. on logout) we want to release all
        connections (except dedicated connections created by utilities
        like backup, restore etc.) of all servers for the current user.
        """

        sess_mgr = self.managers.get(session.sid, None)

        if sess_mgr:
            for mgr in (
                m for m in sess_mgr.values() if isinstance(m, ServerManager)
            ):
                mgr.release()

    @staticmethod
    def qtLiteral(value):
        """
        Return value quoted as an SQL literal (via psycopg2's adapt()),
        always as a text string.
        """
        adapted = adapt(value)

        # Not all adapted objects have encoding
        # e.g.
        # psycopg2.extensions.BOOLEAN
        # psycopg2.extensions.FLOAT
        # psycopg2.extensions.INTEGER
        # etc...
        if hasattr(adapted, 'encoding'):
            adapted.encoding = 'utf8'
        res = adapted.getquoted()

        # getquoted() may return bytes; decode so callers get text.
        if isinstance(res, bytes):
            return res.decode('utf-8')
        return res

    @staticmethod
    def ScanKeywordExtraLookup(key):
        """
        Return the keyword category of key, or None if it is not a keyword.

        Categories:
            UNRESERVED_KEYWORD 0, COL_NAME_KEYWORD 1,
            TYPE_FUNC_NAME_KEYWORD 2, RESERVED_KEYWORD 3

        Driver._EXTRA_KEYWORDS is consulted first, then ScanKeyword().
        """
        # Compare against None explicitly: category 0 (UNRESERVED_KEYWORD)
        # is falsy, and an ``or`` chain would wrongly fall through to
        # ScanKeyword() for entries such as 'log' and 'distributed'.
        category = Driver._EXTRA_KEYWORDS.get(key)
        if category is not None:
            return category
        return ScanKeyword(key)

    @staticmethod
    def needsQuoting(key, forTypes):
        """
        Decide whether an identifier (or, with forTypes, a type name) must
        be double-quoted.

        Parameters:
            key
            - identifier text; an int always needs quoting
            forTypes
            - True when key is a type name. In that case a trailing "[]" is
              ignored for the character checks, the listed multi-word
              built-in type names and already-quoted names are not quoted,
              and COL_NAME keywords are not quoted.

        Returns:
            True if the identifier must be quoted, otherwise False.
        """
        value = key
        valNoArray = value

        # check if the string is number or not
        if isinstance(value, int):
            return True
        # certain types should not be quoted even though it contains a space.
        # Evilness.
        elif forTypes and value[-2:] == u"[]":
            valNoArray = value[:-2]

        if forTypes and valNoArray.lower() in [
            u'bit varying',
            u'"char"',
            u'character varying',
            u'double precision',
            u'timestamp without time zone',
            u'timestamp with time zone',
            u'time without time zone',
            u'time with time zone',
            u'"trigger"',
            u'"unknown"'
        ]:
            return False

        if not valNoArray:
            # Nothing left to inspect (empty input, or a bare "[]" type).
            # valNoArray[0] below would raise IndexError, and an empty name
            # can never be a bare identifier, so quote it.
            return True

        # If already quoted, do not quote again
        if forTypes and (valNoArray.startswith('"') or
                         valNoArray.endswith('"')):
            return False

        # A leading digit always needs quoting.
        if u'0' <= valNoArray[0] <= u'9':
            return True

        # Anything outside [a-z0-9_] (including upper case) needs quoting.
        for c in valNoArray:
            if (not (u'a' <= c <= u'z') and c != u'_' and
                    not (u'0' <= c <= u'9')):
                return True

        # check whether the string is a keyword
        # NOTE(review): the lookup uses value (including any "[]" suffix),
        # not valNoArray -- presumably intentional; confirm.
        category = Driver.ScanKeywordExtraLookup(value)

        if category is None:
            return False

        # UNRESERVED_KEYWORD
        if category == 0:
            return False

        # COL_NAME_KEYWORD
        if forTypes and category == 1:
            return False

        return True

    @staticmethod
    def _to_text(val):
        """
        Normalise one identifier part on Python 2; a no-op on Python 3
        (where str has no ``decode`` attribute).
        """
        if hasattr(str, 'decode') and not isinstance(val, unicode):
            # Handling for python2
            try:
                val = str(val).encode('utf-8')
            except UnicodeDecodeError:
                # If already unicode, most likely coming from db
                val = str(val).decode('utf-8')
        return val

    @staticmethod
    def qtTypeIdent(conn, *args):
        """
        Quote each non-empty part in args as a type identifier (as needed)
        and join the parts with '.'. Returns None if every part is empty.
        """
        # We're not using the conn object at the moment, but - we will
        # modify the logic to use the server version specific keywords later.
        res = None

        for val in args:
            # Empty parts are skipped before any normalisation.
            if len(val) == 0:
                continue
            val = Driver._to_text(val)
            value = val

            if Driver.needsQuoting(val, True):
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res

    @staticmethod
    def qtIdent(conn, *args):
        """
        Quote each non-empty part in args as an identifier (as needed) and
        join the parts with '.'. Returns None if every part is empty.

        If a part is a list, each of its elements is quoted individually
        and the result of map() is returned immediately (a lazy iterator
        on Python 3); any remaining parts are ignored.
        """
        # We're not using the conn object at the moment, but - we will
        # modify the logic to use the server version specific keywords later.
        res = None

        for val in args:
            if isinstance(val, list):
                return map(lambda w: Driver.qtIdent(conn, w), val)

            val = Driver._to_text(val)

            # Empty parts are skipped after normalisation.
            if len(val) == 0:
                continue

            value = val

            if Driver.needsQuoting(val, False):
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res