##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2018, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################

"""
Implementation of Driver class
It is a wrapper around the actual psycopg2 driver, and connection
object.

"""
import datetime

from flask import session
from flask_babelex import gettext
import psycopg2
from psycopg2.extensions import adapt

import config
from pgadmin.model import Server, User
from .keywords import ScanKeyword
from ..abstract import BaseDriver
from .connection import Connection
from .server_manager import ServerManager


class Driver(BaseDriver):
    """
    class Driver(BaseDriver):

    This driver acts as a wrapper around psycopg2 connection driver
    implementation. We will be using psycopg2 for making connection with
    the PostgreSQL/EDB Postgres Advanced Server (EnterpriseDB).

    Methods:
    -------
    * Version()
    - It returns the version string of the psycopg2 driver.

    * libpq_version()
    - It returns the version of the loaded libpq library.

    * get_connection(sid, database, conn_id, auto_reconnect)
    - It returns a Connection class object, which may/may not be connected
      to the database server for this session

    * release_connection(sid, database, conn_id)
    - It releases the connection object for the given conn_id/database for
      this session.

    * connection_manager(sid)
    - It returns the server connection manager for this session.
    """

    # Extra keywords (not looked up via ScanKeyword) and their categories,
    # using the codes that needsQuoting() interprets:
    #   UNRESERVED_KEYWORD      0
    #   COL_NAME_KEYWORD        1
    #   TYPE_FUNC_NAME_KEYWORD  2
    #   RESERVED_KEYWORD        3
    # Kept at class level so the table is built once, not on every lookup.
    _EXTRA_KEYWORDS = {
        'connect': 3,
        'convert': 3,
        'distributed': 0,
        'exec': 3,
        'log': 0,
        'long': 3,
        'minus': 3,
        'nocache': 3,
        'number': 3,
        'package': 3,
        'pls_integer': 3,
        'raw': 3,
        'return': 3,
        'smalldatetime': 3,
        'smallfloat': 3,
        'smallmoney': 3,
        'sysdate': 3,
        # NOTE(review): probably meant 'systimestamp'; kept as-is so the
        # quoting behaviour does not change silently - confirm.
        'systimestap': 3,
        'tinyint': 3,
        'tinytext': 3,
        'varchar2': 3,
    }

    def __init__(self, **kwargs):
        # Maps a Flask session id to a dict holding that session's
        # ServerManager objects (keyed by str(server id)) and a 'pinged'
        # timestamp of its last activity.
        self.managers = dict()

        super(Driver, self).__init__()

    def connection_manager(self, sid=None):
        """
        connection_manager(...)

        Returns the ServerManager object for the current session. It will
        create new ServerManager object (if necessary).

        Parameters:
            sid
            - Server ID (must be an int)

        Returns:
            The ServerManager for the server, or None if no server with the
            given id exists.
        """
        assert (sid is not None and isinstance(sid, int))
        managers = None

        if session.sid not in self.managers:
            self.managers[session.sid] = managers = dict()
            # First time this Driver sees the session: rebuild the managers
            # recorded in the Flask session under '__pgsql_server_managers'.
            if '__pgsql_server_managers' in session:
                session_managers = session['__pgsql_server_managers'].copy()
                session['__pgsql_server_managers'] = dict()

                for server_id in session_managers:
                    s = Server.query.filter_by(id=server_id).first()

                    if not s:
                        continue

                    manager = managers[str(server_id)] = ServerManager(s)
                    manager._restore(session_managers[server_id])
                    manager.update_session()
        else:
            managers = self.managers[session.sid]

        # Record activity so that gc() does not treat this session as idle.
        managers['pinged'] = datetime.datetime.now()

        if str(sid) not in managers:
            s = Server.query.filter_by(id=sid).first()

            if not s:
                return None

            managers[str(sid)] = ServerManager(s)

        return managers[str(sid)]

    def Version(cls):
        """
        Version(...)

        Returns the current version of psycopg2 driver

        Raises:
            Exception if psycopg2 does not expose __version__.
        """
        version = getattr(psycopg2, '__version__', None)

        if version:
            return version

        raise Exception(
            "Driver Version information for psycopg2 is not available!"
        )

    def libpq_version(cls):
        """
        Returns the loaded libpq version

        Raises:
            Exception if psycopg2 does not expose __libpq_version__.
        """
        version = getattr(psycopg2, '__libpq_version__', None)

        if version:
            return version

        raise Exception(
            "libpq version information is not available!"
        )

    def get_connection(
            self, sid, database=None, conn_id=None, auto_reconnect=True
    ):
        """
        get_connection(...)

        Returns the connection object for the certain connection-id/database
        for the specific server, identified by sid. Create a new Connection
        object (if require).

        Parameters:
            sid
            - Server ID
            database
            - Database, on which the connection needs to be made
              If provided none, maintenance_db for the server will be used,
              while generating new Connection object.
            conn_id
            - Identification String for the Connection This will be used by
              certain tools, which will require a dedicated connection for it.
              i.e. Debugger, Query Tool, etc.
            auto_reconnect
            - This parameters define, if we should attempt to reconnect the
              database server automatically, when connection has been lost for
              any reason. Certain tools like Debugger will require a permanent
              connection, and it stops working on disconnection.

        Note:
            If no server with id `sid` exists, connection_manager() returns
            None and this call fails with AttributeError.
        """
        manager = self.connection_manager(sid)

        return manager.connection(database, conn_id, auto_reconnect)

    def release_connection(self, sid, database=None, conn_id=None):
        """
        Release the connection for the given connection-id/database in this
        session.
        """
        return self.connection_manager(sid).release(database, conn_id)

    def delete_manager(self, sid):
        """
        Delete manager for given server id.

        connection_manager() is used to look the manager up, so it may be
        created (and the session's 'pinged' time refreshed) just before it
        is released and removed.
        """
        manager = self.connection_manager(sid)
        if manager is not None:
            manager.release()
        if session.sid in self.managers and \
                str(sid) in self.managers[session.sid]:
            del self.managers[session.sid][str(sid)]

    def gc(self):
        """
        Release the connections for the sessions, which have not pinged the
        server for more than config.MAX_SESSION_IDLE_TIME minutes (60 when
        that setting is unset or 0, and never less than 20).
        """

        # Minimum session idle is 20 minutes
        max_idle_time = max(config.MAX_SESSION_IDLE_TIME or 60, 20)
        session_idle_timeout = datetime.timedelta(minutes=max_idle_time)

        curr_time = datetime.datetime.now()

        for sess in self.managers:
            sess_mgr = self.managers[sess]

            if sess == session.sid:
                # The session running this request is active by definition.
                sess_mgr['pinged'] = curr_time
                continue

            if curr_time - sess_mgr['pinged'] >= session_idle_timeout:
                # sess_mgr maps str(server id) -> ServerManager, plus the
                # 'pinged' timestamp. Iterate its *values*: iterating the
                # dict itself yields only the string keys, so no manager
                # would ever be released. Snapshot into a list first.
                for mgr in [
                    m for m in sess_mgr.values()
                    if isinstance(m, ServerManager)
                ]:
                    mgr.release()

    @staticmethod
    def qtLiteral(value):
        """
        Return `value` rendered by psycopg2's adapt() as a quoted SQL
        literal, decoded to str when psycopg2 returns bytes.
        """
        adapted = adapt(value)

        # Not all adapted objects have encoding
        # e.g.
        # psycopg2.extensions.BOOLEAN
        # psycopg2.extensions.FLOAT
        # psycopg2.extensions.INTEGER
        # etc...
        if hasattr(adapted, 'encoding'):
            adapted.encoding = 'utf8'
        res = adapted.getquoted()

        if isinstance(res, bytes):
            return res.decode('utf-8')
        return res

    @staticmethod
    def ScanKeywordExtraLookup(key):
        """
        Return the keyword category of `key`.

        The extra keyword table is consulted first, then ScanKeyword().
        needsQuoting() treats a None result as "not a keyword".
        """
        category = Driver._EXTRA_KEYWORDS.get(key)
        # Explicit None check: category 0 (UNRESERVED_KEYWORD) is falsy and
        # must not fall through to ScanKeyword().
        if category is not None:
            return category
        return ScanKeyword(key)

    @staticmethod
    def needsQuoting(key, forTypes):
        """
        Decide whether the identifier `key` must be double-quoted.

        Parameters:
            key
            - Identifier, or type name when forTypes is True. An int is
              always reported as needing quotes.
            forTypes
            - True when `key` is a type name. In that case a trailing "[]"
              is ignored, some multi-word built-in type names are left
              unquoted, names already starting/ending with a double quote
              are left alone, and COL_NAME keywords are not quoted.

        Returns:
            True if the identifier must be quoted, False otherwise.

        NOTE(review): an empty string (or "[]" with forTypes) raises
        IndexError; the qtIdent/qtTypeIdent callers skip empty parts.
        """
        value = key
        valNoArray = value

        # check if the string is number or not
        if isinstance(value, int):
            return True
        # certain types should not be quoted even though it contains a space.
        # Evilness.
        elif forTypes and value[-2:] == u"[]":
            valNoArray = value[:-2]

        if forTypes and valNoArray.lower() in [
            u'bit varying',
            u'"char"',
            u'character varying',
            u'double precision',
            u'timestamp without time zone',
            u'timestamp with time zone',
            u'time without time zone',
            u'time with time zone',
            u'"trigger"',
            u'"unknown"'
        ]:
            return False

        # If already quoted?, If yes then do not quote again
        if forTypes and valNoArray:
            if valNoArray.startswith('"') \
                    or valNoArray.endswith('"'):
                return False

        # An identifier starting with a digit must be quoted.
        if u'0' <= valNoArray[0] <= u'9':
            return True

        # Any character other than a-z, 0-9 and '_' (e.g. upper case,
        # spaces, punctuation, non-ASCII) requires quoting.
        for c in valNoArray:
            if (not (u'a' <= c <= u'z') and c != u'_' and
                    not (u'0' <= c <= u'9')):
                return True

        # check string is keyword or not
        category = Driver.ScanKeywordExtraLookup(value)

        if category is None:
            return False

        # UNRESERVED_KEYWORD
        if category == 0:
            return False

        # COL_NAME_KEYWORD
        if forTypes and category == 1:
            return False

        return True

    @staticmethod
    def qtTypeIdent(conn, *args):
        """
        Quote each non-empty part of `args` as a type name (needsQuoting
        with forTypes=True) and join the parts with '.'.

        Parameters:
            conn
            - Currently unused; kept so the logic can later use the server
              version specific keywords.
            args
            - Name parts, e.g. schema name and type name.

        Returns:
            The dotted, quoted name, or None if every part is empty.
        """
        res = None
        value = None

        for val in args:
            if len(val) == 0:
                continue
            if hasattr(str, 'decode') and not isinstance(val, unicode):
                # Handling for python2
                try:
                    val = str(val).encode('utf-8')
                except UnicodeDecodeError:
                    # If already unicode, most likely coming from db
                    val = str(val).decode('utf-8')
            value = val

            if (Driver.needsQuoting(val, True)):
                # Escape embedded double quotes, then wrap in double quotes.
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res

    @staticmethod
    def qtIdent(conn, *args):
        """
        Quote each non-empty part of `args` as an identifier (needsQuoting
        with forTypes=False) and join the parts with '.'.

        Parameters:
            conn
            - Currently unused; kept so the logic can later use the server
              version specific keywords.
            args
            - Name parts. As soon as a list argument is encountered, a list
              with qtIdent() applied to each of its elements is returned and
              all other arguments are ignored.

        Returns:
            The dotted, quoted name (None if every part is empty), or a
            list of quoted names for a list argument.
        """
        res = None
        value = None

        for val in args:
            if isinstance(val, list):
                # Build a real list: on Python 3, map() would return a
                # one-shot iterator, unlike the list it returned on Python 2.
                return [Driver.qtIdent(conn, w) for w in val]
            if hasattr(str, 'decode') and not isinstance(val, unicode):
                # Handling for python2
                try:
                    val = str(val).encode('utf-8')
                except UnicodeDecodeError:
                    # If already unicode, most likely coming from db
                    val = str(val).decode('utf-8')

            if len(val) == 0:
                continue

            value = val

            if (Driver.needsQuoting(val, False)):
                # Escape embedded double quotes, then wrap in double quotes.
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res