Add support for SSH tunneled connections. Fixes #1447

This commit is contained in:
Akshay Joshi
2018-05-04 11:27:27 +01:00
committed by Dave Page
parent 455c45ea85
commit b7fb01ab04
26 changed files with 697 additions and 129 deletions

View File

@@ -211,13 +211,25 @@ class Connection(BaseConnection):
pg_conn = None
password = None
passfile = None
mgr = self.manager
manager = self.manager
encpass = kwargs['password'] if 'password' in kwargs else None
passfile = kwargs['passfile'] if 'passfile' in kwargs else None
tunnel_password = kwargs['tunnel_password'] if 'tunnel_password' in \
kwargs else None
# Check SSH Tunnel needs to be created
if manager.use_ssh_tunnel == 1 and tunnel_password is not None:
status, error = manager.create_ssh_tunnel(tunnel_password)
if not status:
return False, error
# Check SSH Tunnel is alive or not.
if manager.use_ssh_tunnel == 1:
manager.check_ssh_tunnel_alive()
if encpass is None:
encpass = self.password or getattr(mgr, 'password', None)
encpass = self.password or getattr(manager, 'password', None)
# Reset the existing connection password
if self.reconnecting is not False:
@@ -240,6 +252,7 @@ class Connection(BaseConnection):
password = password.decode()
except Exception as e:
manager.stop_ssh_tunnel()
current_app.logger.exception(e)
return False, \
_(
@@ -251,16 +264,16 @@ class Connection(BaseConnection):
# we will check for pgpass file availability from connection manager
# if it's present then we will use it
if not password and not encpass and not passfile:
passfile = mgr.passfile if mgr.passfile else None
passfile = manager.passfile if manager.passfile else None
try:
if hasattr(str, 'decode'):
database = self.db.encode('utf-8')
user = mgr.user.encode('utf-8')
user = manager.user.encode('utf-8')
conn_id = self.conn_id.encode('utf-8')
else:
database = self.db
user = mgr.user
user = manager.user
conn_id = self.conn_id
import os
@@ -268,21 +281,24 @@ class Connection(BaseConnection):
config.APP_NAME, conn_id)
pg_conn = psycopg2.connect(
host=mgr.host,
hostaddr=mgr.hostaddr,
port=mgr.port,
host=manager.local_bind_host if manager.use_ssh_tunnel
else manager.host,
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
else manager.hostaddr,
port=manager.local_bind_port if manager.use_ssh_tunnel
else manager.port,
database=database,
user=user,
password=password,
async=self.async,
passfile=get_complete_file_path(passfile),
sslmode=mgr.ssl_mode,
sslcert=get_complete_file_path(mgr.sslcert),
sslkey=get_complete_file_path(mgr.sslkey),
sslrootcert=get_complete_file_path(mgr.sslrootcert),
sslcrl=get_complete_file_path(mgr.sslcrl),
sslcompression=True if mgr.sslcompression else False,
service=mgr.service
sslmode=manager.ssl_mode,
sslcert=get_complete_file_path(manager.sslcert),
sslkey=get_complete_file_path(manager.sslkey),
sslrootcert=get_complete_file_path(manager.sslrootcert),
sslcrl=get_complete_file_path(manager.sslcrl),
sslcompression=True if manager.sslcompression else False,
service=manager.service
)
# If connection is asynchronous then we will have to wait
@@ -291,6 +307,7 @@ class Connection(BaseConnection):
self._wait(pg_conn)
except psycopg2.Error as e:
manager.stop_ssh_tunnel()
if e.pgerror:
msg = e.pgerror
elif e.diag.message_detail:
@@ -317,6 +334,7 @@ class Connection(BaseConnection):
try:
status, msg = self._initialize(conn_id, **kwargs)
except Exception as e:
manager.stop_ssh_tunnel()
current_app.logger.exception(e)
self.conn = None
if not self.reconnecting:
@@ -324,7 +342,7 @@ class Connection(BaseConnection):
raise e
if status:
mgr._update_password(encpass)
manager._update_password(encpass)
else:
if not self.reconnecting:
self.wasConnected = False
@@ -342,7 +360,7 @@ class Connection(BaseConnection):
status, cur = self.__cursor()
formatted_exception_msg = self._formatted_exception_msg
mgr = self.manager
manager = self.manager
def _execute(cur, query, params=None):
try:
@@ -381,8 +399,8 @@ class Connection(BaseConnection):
return False, status
if mgr.role:
status = _execute(cur, u"SET ROLE TO %s", [mgr.role])
if manager.role:
status = _execute(cur, u"SET ROLE TO %s", [manager.role])
if status is not None:
self.conn.close()
@@ -401,7 +419,7 @@ class Connection(BaseConnection):
"Failed to setup the role with error message:\n{0}"
).format(status)
if mgr.ver is None:
if manager.ver is None:
status = _execute(cur, "SELECT version()")
if status is not None:
@@ -421,8 +439,8 @@ class Connection(BaseConnection):
if cur.rowcount > 0:
row = cur.fetchmany(1)[0]
mgr.ver = row['version']
mgr.sversion = self.conn.server_version
manager.ver = row['version']
manager.sversion = self.conn.server_version
status = _execute(cur, """
SELECT
@@ -434,14 +452,14 @@ FROM
WHERE db.datname = current_database()""")
if status is None:
mgr.db_info = mgr.db_info or dict()
manager.db_info = manager.db_info or dict()
if cur.rowcount > 0:
res = cur.fetchmany(1)[0]
mgr.db_info[res['did']] = res.copy()
manager.db_info[res['did']] = res.copy()
# We do not have database oid for the maintenance database.
if len(mgr.db_info) == 1:
mgr.did = res['did']
if len(manager.db_info) == 1:
manager.did = res['did']
status = _execute(cur, """
SELECT
@@ -453,33 +471,39 @@ WHERE
rolname = current_user""")
if status is None:
mgr.user_info = dict()
manager.user_info = dict()
if cur.rowcount > 0:
mgr.user_info = cur.fetchmany(1)[0]
manager.user_info = cur.fetchmany(1)[0]
if 'password' in kwargs:
mgr.password = kwargs['password']
manager.password = kwargs['password']
server_types = None
if 'server_types' in kwargs and isinstance(
kwargs['server_types'], list):
server_types = mgr.server_types = kwargs['server_types']
server_types = manager.server_types = kwargs['server_types']
if server_types is None:
from pgadmin.browser.server_groups.servers.types import ServerType
server_types = ServerType.types()
for st in server_types:
if st.instanceOf(mgr.ver):
mgr.server_type = st.stype
mgr.server_cls = st
if st.instanceOf(manager.ver):
manager.server_type = st.stype
manager.server_cls = st
break
mgr.update_session()
manager.update_session()
return True, None
def __cursor(self, server_cursor=False):
# Check SSH Tunnel is alive or not. If used by the database
# server for the connection.
if self.manager.use_ssh_tunnel == 1:
self.manager.check_ssh_tunnel_alive()
if self.wasConnected is False:
raise ConnectionLost(
self.manager.sid,
@@ -1166,9 +1190,9 @@ WHERE
if self.conn.closed:
self.conn = None
pg_conn = None
mgr = self.manager
manager = self.manager
password = getattr(mgr, 'password', None)
password = getattr(manager, 'password', None)
if password:
# Fetch Logged in User Details.
@@ -1181,20 +1205,23 @@ WHERE
try:
pg_conn = psycopg2.connect(
host=mgr.host,
hostaddr=mgr.hostaddr,
port=mgr.port,
host=manager.local_bind_host if manager.use_ssh_tunnel
else manager.host,
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
else manager.hostaddr,
port=manager.local_bind_port if manager.use_ssh_tunnel
else manager.port,
database=self.db,
user=mgr.user,
user=manager.user,
password=password,
passfile=get_complete_file_path(mgr.passfile),
sslmode=mgr.ssl_mode,
sslcert=get_complete_file_path(mgr.sslcert),
sslkey=get_complete_file_path(mgr.sslkey),
sslrootcert=get_complete_file_path(mgr.sslrootcert),
sslcrl=get_complete_file_path(mgr.sslcrl),
sslcompression=True if mgr.sslcompression else False,
service=mgr.service
passfile=get_complete_file_path(manager.passfile),
sslmode=manager.ssl_mode,
sslcert=get_complete_file_path(manager.sslcert),
sslkey=get_complete_file_path(manager.sslkey),
sslrootcert=get_complete_file_path(manager.sslrootcert),
sslcrl=get_complete_file_path(manager.sslcrl),
sslcompression=True if manager.sslcompression else False,
service=manager.service
)
except psycopg2.Error as e:
@@ -1456,9 +1483,13 @@ Failed to reset the connection to the server due to following error:
try:
pg_conn = psycopg2.connect(
host=self.manager.host,
hostaddr=self.manager.hostaddr,
port=self.manager.port,
host=self.manager.local_bind_host if
self.manager.use_ssh_tunnel else self.manager.host,
hostaddr=self.manager.local_bind_host if
self.manager.use_ssh_tunnel else
self.manager.hostaddr,
port=self.manager.local_bind_port if
self.manager.use_ssh_tunnel else self.manager.port,
database=self.db,
user=self.manager.user,
password=password,

View File

@@ -12,14 +12,19 @@ Implementation of ServerManager
"""
import os
import datetime
import config
from flask import current_app, session
from flask_security import current_user
from flask_babelex import gettext
from pgadmin.utils import get_complete_file_path
from pgadmin.utils.crypto import decrypt
from .connection import Connection
from pgadmin.model import Server
from pgadmin.utils.exception import ConnectionLost
from pgadmin.model import Server, User
from pgadmin.utils.exception import ConnectionLost, SSHTunnelConnectionLost
if config.SUPPORT_SSH_TUNNEL:
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
class ServerManager(object):
@@ -32,6 +37,9 @@ class ServerManager(object):
def __init__(self, server):
self.connections = dict()
self.local_bind_host = '127.0.0.1'
self.local_bind_port = None
self.tunnel_object = None
self.update(server)
@@ -66,6 +74,20 @@ class ServerManager(object):
self.sslcrl = server.sslcrl
self.sslcompression = True if server.sslcompression else False
self.service = server.service
if config.SUPPORT_SSH_TUNNEL:
self.use_ssh_tunnel = server.use_ssh_tunnel
self.tunnel_host = server.tunnel_host
self.tunnel_port = server.tunnel_port
self.tunnel_username = server.tunnel_username
self.tunnel_authentication = server.tunnel_authentication
self.tunnel_identity_file = server.tunnel_identity_file
else:
self.use_ssh_tunnel = 0
self.tunnel_host = None
self.tunnel_port = 22
self.tunnel_username = None
self.tunnel_authentication = None
self.tunnel_identity_file = None
for con in self.connections:
self.connections[con]._release()
@@ -167,7 +189,11 @@ WHERE db.oid = {0}""".format(did))
))
if database is None:
raise ConnectionLost(self.sid, None, None)
# Check SSH Tunnel is alive or not.
if self.use_ssh_tunnel == 1:
self.check_ssh_tunnel_alive()
else:
raise ConnectionLost(self.sid, None, None)
my_id = (u'CONN:{0}'.format(conn_id)) if conn_id is not None else \
(u'DB:{0}'.format(database))
@@ -247,6 +273,9 @@ WHERE db.oid = {0}""".format(did))
self.connections.pop(conn_info['conn_id'])
def release(self, database=None, conn_id=None, did=None):
# Stop the SSH tunnel if created.
self.stop_ssh_tunnel()
if did is not None:
if did in self.db_info and 'datname' in self.db_info[did]:
database = self.db_info[did]['datname']
@@ -332,3 +361,73 @@ WHERE db.oid = {0}""".format(did))
self.password, current_user.password
).decode()
os.environ[str(env)] = password
def create_ssh_tunnel(self, tunnel_password):
"""
This method is used to create ssh tunnel and update the IP Address and
IP Address and port to localhost and the local bind port return by the
SSHTunnelForwarder class.
:return: True if tunnel is successfully created else error message.
"""
# Fetch Logged in User Details.
user = User.query.filter_by(id=current_user.id).first()
if user is None:
return False, gettext("Unauthorized request.")
try:
tunnel_password = decrypt(tunnel_password, user.password)
# Handling of non ascii password (Python2)
if hasattr(str, 'decode'):
tunnel_password = \
tunnel_password.decode('utf-8').encode('utf-8')
# password is in bytes, for python3 we need it in string
elif isinstance(tunnel_password, bytes):
tunnel_password = tunnel_password.decode()
except Exception as e:
current_app.logger.exception(e)
return False, "Failed to decrypt the SSH tunnel " \
"password.\nError: {0}".format(str(e))
try:
# If authentication method is 1 then it uses identity file
# and password
if self.tunnel_authentication == 1:
self.tunnel_object = SSHTunnelForwarder(
self.tunnel_host,
ssh_username=self.tunnel_username,
ssh_pkey=get_complete_file_path(self.tunnel_identity_file),
ssh_private_key_password=tunnel_password,
remote_bind_address=(self.host, self.port)
)
else:
self.tunnel_object = SSHTunnelForwarder(
self.tunnel_host,
ssh_username=self.tunnel_username,
ssh_password=tunnel_password,
remote_bind_address=(self.host, self.port)
)
self.tunnel_object.start()
except BaseSSHTunnelForwarderError as e:
current_app.logger.exception(e)
return False, "Failed to create the SSH tunnel." \
"\nError: {0}".format(str(e))
# Update the port to communicate locally
self.local_bind_port = self.tunnel_object.local_bind_port
return True, None
def check_ssh_tunnel_alive(self):
# Check SSH Tunnel is alive or not. if it is not then
# raise the ConnectionLost exception.
if self.tunnel_object is None or not self.tunnel_object.is_active:
raise SSHTunnelConnectionLost(self.tunnel_host)
def stop_ssh_tunnel(self):
# Stop the SSH tunnel if created.
if self.tunnel_object and self.tunnel_object.is_active:
self.tunnel_object.stop()
self.local_bind_port = None
self.tunnel_object = None

View File

@@ -48,3 +48,35 @@ class ConnectionLost(HTTPException):
def __repr__(self):
return "Connection (id #{2}) lost for the server (#{0}) on " \
"database ({1})".format(self.sid, self.db, self.conn_id)
class SSHTunnelConnectionLost(HTTPException):
"""
Exception when connection to SSH tunnel is lost
"""
def __init__(self, _tunnel_host):
self.tunnel_host = _tunnel_host
HTTPException.__init__(self)
@property
def name(self):
return HTTP_STATUS_CODES.get(503, 'Service Unavailable')
def get_response(self, environ=None):
return service_unavailable(
_("Connection to the SSH Tunnel for host '{0}' has been lost. "
"Reconnect to the database server.").format(self.tunnel_host),
info="SSH_TUNNEL_CONNECTION_LOST",
data={
'tunnel_host': self.tunnel_host
}
)
def __str__(self):
return "Connection to the SSH Tunnel for host '{0}' has been lost. " \
"Reconnect to the database server".format(self.tunnel_host)
def __repr__(self):
return "Connection to the SSH Tunnel for host '{0}' has been lost. " \
"Reconnect to the database server".format(self.tunnel_host)