mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2025-02-25 18:55:31 -06:00
Add support for SSH tunneled connections. Fixes #1447
This commit is contained in:
@@ -211,13 +211,25 @@ class Connection(BaseConnection):
|
||||
pg_conn = None
|
||||
password = None
|
||||
passfile = None
|
||||
mgr = self.manager
|
||||
manager = self.manager
|
||||
|
||||
encpass = kwargs['password'] if 'password' in kwargs else None
|
||||
passfile = kwargs['passfile'] if 'passfile' in kwargs else None
|
||||
tunnel_password = kwargs['tunnel_password'] if 'tunnel_password' in \
|
||||
kwargs else None
|
||||
|
||||
# Check SSH Tunnel needs to be created
|
||||
if manager.use_ssh_tunnel == 1 and tunnel_password is not None:
|
||||
status, error = manager.create_ssh_tunnel(tunnel_password)
|
||||
if not status:
|
||||
return False, error
|
||||
|
||||
# Check SSH Tunnel is alive or not.
|
||||
if manager.use_ssh_tunnel == 1:
|
||||
manager.check_ssh_tunnel_alive()
|
||||
|
||||
if encpass is None:
|
||||
encpass = self.password or getattr(mgr, 'password', None)
|
||||
encpass = self.password or getattr(manager, 'password', None)
|
||||
|
||||
# Reset the existing connection password
|
||||
if self.reconnecting is not False:
|
||||
@@ -240,6 +252,7 @@ class Connection(BaseConnection):
|
||||
password = password.decode()
|
||||
|
||||
except Exception as e:
|
||||
manager.stop_ssh_tunnel()
|
||||
current_app.logger.exception(e)
|
||||
return False, \
|
||||
_(
|
||||
@@ -251,16 +264,16 @@ class Connection(BaseConnection):
|
||||
# we will check for pgpass file availability from connection manager
|
||||
# if it's present then we will use it
|
||||
if not password and not encpass and not passfile:
|
||||
passfile = mgr.passfile if mgr.passfile else None
|
||||
passfile = manager.passfile if manager.passfile else None
|
||||
|
||||
try:
|
||||
if hasattr(str, 'decode'):
|
||||
database = self.db.encode('utf-8')
|
||||
user = mgr.user.encode('utf-8')
|
||||
user = manager.user.encode('utf-8')
|
||||
conn_id = self.conn_id.encode('utf-8')
|
||||
else:
|
||||
database = self.db
|
||||
user = mgr.user
|
||||
user = manager.user
|
||||
conn_id = self.conn_id
|
||||
|
||||
import os
|
||||
@@ -268,21 +281,24 @@ class Connection(BaseConnection):
|
||||
config.APP_NAME, conn_id)
|
||||
|
||||
pg_conn = psycopg2.connect(
|
||||
host=mgr.host,
|
||||
hostaddr=mgr.hostaddr,
|
||||
port=mgr.port,
|
||||
host=manager.local_bind_host if manager.use_ssh_tunnel
|
||||
else manager.host,
|
||||
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
|
||||
else manager.hostaddr,
|
||||
port=manager.local_bind_port if manager.use_ssh_tunnel
|
||||
else manager.port,
|
||||
database=database,
|
||||
user=user,
|
||||
password=password,
|
||||
async=self.async,
|
||||
passfile=get_complete_file_path(passfile),
|
||||
sslmode=mgr.ssl_mode,
|
||||
sslcert=get_complete_file_path(mgr.sslcert),
|
||||
sslkey=get_complete_file_path(mgr.sslkey),
|
||||
sslrootcert=get_complete_file_path(mgr.sslrootcert),
|
||||
sslcrl=get_complete_file_path(mgr.sslcrl),
|
||||
sslcompression=True if mgr.sslcompression else False,
|
||||
service=mgr.service
|
||||
sslmode=manager.ssl_mode,
|
||||
sslcert=get_complete_file_path(manager.sslcert),
|
||||
sslkey=get_complete_file_path(manager.sslkey),
|
||||
sslrootcert=get_complete_file_path(manager.sslrootcert),
|
||||
sslcrl=get_complete_file_path(manager.sslcrl),
|
||||
sslcompression=True if manager.sslcompression else False,
|
||||
service=manager.service
|
||||
)
|
||||
|
||||
# If connection is asynchronous then we will have to wait
|
||||
@@ -291,6 +307,7 @@ class Connection(BaseConnection):
|
||||
self._wait(pg_conn)
|
||||
|
||||
except psycopg2.Error as e:
|
||||
manager.stop_ssh_tunnel()
|
||||
if e.pgerror:
|
||||
msg = e.pgerror
|
||||
elif e.diag.message_detail:
|
||||
@@ -317,6 +334,7 @@ class Connection(BaseConnection):
|
||||
try:
|
||||
status, msg = self._initialize(conn_id, **kwargs)
|
||||
except Exception as e:
|
||||
manager.stop_ssh_tunnel()
|
||||
current_app.logger.exception(e)
|
||||
self.conn = None
|
||||
if not self.reconnecting:
|
||||
@@ -324,7 +342,7 @@ class Connection(BaseConnection):
|
||||
raise e
|
||||
|
||||
if status:
|
||||
mgr._update_password(encpass)
|
||||
manager._update_password(encpass)
|
||||
else:
|
||||
if not self.reconnecting:
|
||||
self.wasConnected = False
|
||||
@@ -342,7 +360,7 @@ class Connection(BaseConnection):
|
||||
|
||||
status, cur = self.__cursor()
|
||||
formatted_exception_msg = self._formatted_exception_msg
|
||||
mgr = self.manager
|
||||
manager = self.manager
|
||||
|
||||
def _execute(cur, query, params=None):
|
||||
try:
|
||||
@@ -381,8 +399,8 @@ class Connection(BaseConnection):
|
||||
|
||||
return False, status
|
||||
|
||||
if mgr.role:
|
||||
status = _execute(cur, u"SET ROLE TO %s", [mgr.role])
|
||||
if manager.role:
|
||||
status = _execute(cur, u"SET ROLE TO %s", [manager.role])
|
||||
|
||||
if status is not None:
|
||||
self.conn.close()
|
||||
@@ -401,7 +419,7 @@ class Connection(BaseConnection):
|
||||
"Failed to setup the role with error message:\n{0}"
|
||||
).format(status)
|
||||
|
||||
if mgr.ver is None:
|
||||
if manager.ver is None:
|
||||
status = _execute(cur, "SELECT version()")
|
||||
|
||||
if status is not None:
|
||||
@@ -421,8 +439,8 @@ class Connection(BaseConnection):
|
||||
|
||||
if cur.rowcount > 0:
|
||||
row = cur.fetchmany(1)[0]
|
||||
mgr.ver = row['version']
|
||||
mgr.sversion = self.conn.server_version
|
||||
manager.ver = row['version']
|
||||
manager.sversion = self.conn.server_version
|
||||
|
||||
status = _execute(cur, """
|
||||
SELECT
|
||||
@@ -434,14 +452,14 @@ FROM
|
||||
WHERE db.datname = current_database()""")
|
||||
|
||||
if status is None:
|
||||
mgr.db_info = mgr.db_info or dict()
|
||||
manager.db_info = manager.db_info or dict()
|
||||
if cur.rowcount > 0:
|
||||
res = cur.fetchmany(1)[0]
|
||||
mgr.db_info[res['did']] = res.copy()
|
||||
manager.db_info[res['did']] = res.copy()
|
||||
|
||||
# We do not have database oid for the maintenance database.
|
||||
if len(mgr.db_info) == 1:
|
||||
mgr.did = res['did']
|
||||
if len(manager.db_info) == 1:
|
||||
manager.did = res['did']
|
||||
|
||||
status = _execute(cur, """
|
||||
SELECT
|
||||
@@ -453,33 +471,39 @@ WHERE
|
||||
rolname = current_user""")
|
||||
|
||||
if status is None:
|
||||
mgr.user_info = dict()
|
||||
manager.user_info = dict()
|
||||
if cur.rowcount > 0:
|
||||
mgr.user_info = cur.fetchmany(1)[0]
|
||||
manager.user_info = cur.fetchmany(1)[0]
|
||||
|
||||
if 'password' in kwargs:
|
||||
mgr.password = kwargs['password']
|
||||
manager.password = kwargs['password']
|
||||
|
||||
server_types = None
|
||||
if 'server_types' in kwargs and isinstance(
|
||||
kwargs['server_types'], list):
|
||||
server_types = mgr.server_types = kwargs['server_types']
|
||||
server_types = manager.server_types = kwargs['server_types']
|
||||
|
||||
if server_types is None:
|
||||
from pgadmin.browser.server_groups.servers.types import ServerType
|
||||
server_types = ServerType.types()
|
||||
|
||||
for st in server_types:
|
||||
if st.instanceOf(mgr.ver):
|
||||
mgr.server_type = st.stype
|
||||
mgr.server_cls = st
|
||||
if st.instanceOf(manager.ver):
|
||||
manager.server_type = st.stype
|
||||
manager.server_cls = st
|
||||
break
|
||||
|
||||
mgr.update_session()
|
||||
manager.update_session()
|
||||
|
||||
return True, None
|
||||
|
||||
def __cursor(self, server_cursor=False):
|
||||
|
||||
# Check SSH Tunnel is alive or not. If used by the database
|
||||
# server for the connection.
|
||||
if self.manager.use_ssh_tunnel == 1:
|
||||
self.manager.check_ssh_tunnel_alive()
|
||||
|
||||
if self.wasConnected is False:
|
||||
raise ConnectionLost(
|
||||
self.manager.sid,
|
||||
@@ -1166,9 +1190,9 @@ WHERE
|
||||
if self.conn.closed:
|
||||
self.conn = None
|
||||
pg_conn = None
|
||||
mgr = self.manager
|
||||
manager = self.manager
|
||||
|
||||
password = getattr(mgr, 'password', None)
|
||||
password = getattr(manager, 'password', None)
|
||||
|
||||
if password:
|
||||
# Fetch Logged in User Details.
|
||||
@@ -1181,20 +1205,23 @@ WHERE
|
||||
|
||||
try:
|
||||
pg_conn = psycopg2.connect(
|
||||
host=mgr.host,
|
||||
hostaddr=mgr.hostaddr,
|
||||
port=mgr.port,
|
||||
host=manager.local_bind_host if manager.use_ssh_tunnel
|
||||
else manager.host,
|
||||
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
|
||||
else manager.hostaddr,
|
||||
port=manager.local_bind_port if manager.use_ssh_tunnel
|
||||
else manager.port,
|
||||
database=self.db,
|
||||
user=mgr.user,
|
||||
user=manager.user,
|
||||
password=password,
|
||||
passfile=get_complete_file_path(mgr.passfile),
|
||||
sslmode=mgr.ssl_mode,
|
||||
sslcert=get_complete_file_path(mgr.sslcert),
|
||||
sslkey=get_complete_file_path(mgr.sslkey),
|
||||
sslrootcert=get_complete_file_path(mgr.sslrootcert),
|
||||
sslcrl=get_complete_file_path(mgr.sslcrl),
|
||||
sslcompression=True if mgr.sslcompression else False,
|
||||
service=mgr.service
|
||||
passfile=get_complete_file_path(manager.passfile),
|
||||
sslmode=manager.ssl_mode,
|
||||
sslcert=get_complete_file_path(manager.sslcert),
|
||||
sslkey=get_complete_file_path(manager.sslkey),
|
||||
sslrootcert=get_complete_file_path(manager.sslrootcert),
|
||||
sslcrl=get_complete_file_path(manager.sslcrl),
|
||||
sslcompression=True if manager.sslcompression else False,
|
||||
service=manager.service
|
||||
)
|
||||
|
||||
except psycopg2.Error as e:
|
||||
@@ -1456,9 +1483,13 @@ Failed to reset the connection to the server due to following error:
|
||||
|
||||
try:
|
||||
pg_conn = psycopg2.connect(
|
||||
host=self.manager.host,
|
||||
hostaddr=self.manager.hostaddr,
|
||||
port=self.manager.port,
|
||||
host=self.manager.local_bind_host if
|
||||
self.manager.use_ssh_tunnel else self.manager.host,
|
||||
hostaddr=self.manager.local_bind_host if
|
||||
self.manager.use_ssh_tunnel else
|
||||
self.manager.hostaddr,
|
||||
port=self.manager.local_bind_port if
|
||||
self.manager.use_ssh_tunnel else self.manager.port,
|
||||
database=self.db,
|
||||
user=self.manager.user,
|
||||
password=password,
|
||||
|
||||
@@ -12,14 +12,19 @@ Implementation of ServerManager
|
||||
"""
|
||||
import os
|
||||
import datetime
|
||||
import config
|
||||
from flask import current_app, session
|
||||
from flask_security import current_user
|
||||
from flask_babelex import gettext
|
||||
|
||||
from pgadmin.utils import get_complete_file_path
|
||||
from pgadmin.utils.crypto import decrypt
|
||||
from .connection import Connection
|
||||
from pgadmin.model import Server
|
||||
from pgadmin.utils.exception import ConnectionLost
|
||||
from pgadmin.model import Server, User
|
||||
from pgadmin.utils.exception import ConnectionLost, SSHTunnelConnectionLost
|
||||
|
||||
if config.SUPPORT_SSH_TUNNEL:
|
||||
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
|
||||
|
||||
|
||||
class ServerManager(object):
|
||||
@@ -32,6 +37,9 @@ class ServerManager(object):
|
||||
|
||||
def __init__(self, server):
|
||||
self.connections = dict()
|
||||
self.local_bind_host = '127.0.0.1'
|
||||
self.local_bind_port = None
|
||||
self.tunnel_object = None
|
||||
|
||||
self.update(server)
|
||||
|
||||
@@ -66,6 +74,20 @@ class ServerManager(object):
|
||||
self.sslcrl = server.sslcrl
|
||||
self.sslcompression = True if server.sslcompression else False
|
||||
self.service = server.service
|
||||
if config.SUPPORT_SSH_TUNNEL:
|
||||
self.use_ssh_tunnel = server.use_ssh_tunnel
|
||||
self.tunnel_host = server.tunnel_host
|
||||
self.tunnel_port = server.tunnel_port
|
||||
self.tunnel_username = server.tunnel_username
|
||||
self.tunnel_authentication = server.tunnel_authentication
|
||||
self.tunnel_identity_file = server.tunnel_identity_file
|
||||
else:
|
||||
self.use_ssh_tunnel = 0
|
||||
self.tunnel_host = None
|
||||
self.tunnel_port = 22
|
||||
self.tunnel_username = None
|
||||
self.tunnel_authentication = None
|
||||
self.tunnel_identity_file = None
|
||||
|
||||
for con in self.connections:
|
||||
self.connections[con]._release()
|
||||
@@ -167,7 +189,11 @@ WHERE db.oid = {0}""".format(did))
|
||||
))
|
||||
|
||||
if database is None:
|
||||
raise ConnectionLost(self.sid, None, None)
|
||||
# Check SSH Tunnel is alive or not.
|
||||
if self.use_ssh_tunnel == 1:
|
||||
self.check_ssh_tunnel_alive()
|
||||
else:
|
||||
raise ConnectionLost(self.sid, None, None)
|
||||
|
||||
my_id = (u'CONN:{0}'.format(conn_id)) if conn_id is not None else \
|
||||
(u'DB:{0}'.format(database))
|
||||
@@ -247,6 +273,9 @@ WHERE db.oid = {0}""".format(did))
|
||||
self.connections.pop(conn_info['conn_id'])
|
||||
|
||||
def release(self, database=None, conn_id=None, did=None):
|
||||
# Stop the SSH tunnel if created.
|
||||
self.stop_ssh_tunnel()
|
||||
|
||||
if did is not None:
|
||||
if did in self.db_info and 'datname' in self.db_info[did]:
|
||||
database = self.db_info[did]['datname']
|
||||
@@ -332,3 +361,73 @@ WHERE db.oid = {0}""".format(did))
|
||||
self.password, current_user.password
|
||||
).decode()
|
||||
os.environ[str(env)] = password
|
||||
|
||||
def create_ssh_tunnel(self, tunnel_password):
|
||||
"""
|
||||
This method is used to create ssh tunnel and update the IP Address and
|
||||
IP Address and port to localhost and the local bind port return by the
|
||||
SSHTunnelForwarder class.
|
||||
:return: True if tunnel is successfully created else error message.
|
||||
"""
|
||||
# Fetch Logged in User Details.
|
||||
user = User.query.filter_by(id=current_user.id).first()
|
||||
if user is None:
|
||||
return False, gettext("Unauthorized request.")
|
||||
|
||||
try:
|
||||
tunnel_password = decrypt(tunnel_password, user.password)
|
||||
# Handling of non ascii password (Python2)
|
||||
if hasattr(str, 'decode'):
|
||||
tunnel_password = \
|
||||
tunnel_password.decode('utf-8').encode('utf-8')
|
||||
# password is in bytes, for python3 we need it in string
|
||||
elif isinstance(tunnel_password, bytes):
|
||||
tunnel_password = tunnel_password.decode()
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.exception(e)
|
||||
return False, "Failed to decrypt the SSH tunnel " \
|
||||
"password.\nError: {0}".format(str(e))
|
||||
|
||||
try:
|
||||
# If authentication method is 1 then it uses identity file
|
||||
# and password
|
||||
if self.tunnel_authentication == 1:
|
||||
self.tunnel_object = SSHTunnelForwarder(
|
||||
self.tunnel_host,
|
||||
ssh_username=self.tunnel_username,
|
||||
ssh_pkey=get_complete_file_path(self.tunnel_identity_file),
|
||||
ssh_private_key_password=tunnel_password,
|
||||
remote_bind_address=(self.host, self.port)
|
||||
)
|
||||
else:
|
||||
self.tunnel_object = SSHTunnelForwarder(
|
||||
self.tunnel_host,
|
||||
ssh_username=self.tunnel_username,
|
||||
ssh_password=tunnel_password,
|
||||
remote_bind_address=(self.host, self.port)
|
||||
)
|
||||
|
||||
self.tunnel_object.start()
|
||||
except BaseSSHTunnelForwarderError as e:
|
||||
current_app.logger.exception(e)
|
||||
return False, "Failed to create the SSH tunnel." \
|
||||
"\nError: {0}".format(str(e))
|
||||
|
||||
# Update the port to communicate locally
|
||||
self.local_bind_port = self.tunnel_object.local_bind_port
|
||||
|
||||
return True, None
|
||||
|
||||
def check_ssh_tunnel_alive(self):
|
||||
# Check SSH Tunnel is alive or not. if it is not then
|
||||
# raise the ConnectionLost exception.
|
||||
if self.tunnel_object is None or not self.tunnel_object.is_active:
|
||||
raise SSHTunnelConnectionLost(self.tunnel_host)
|
||||
|
||||
def stop_ssh_tunnel(self):
|
||||
# Stop the SSH tunnel if created.
|
||||
if self.tunnel_object and self.tunnel_object.is_active:
|
||||
self.tunnel_object.stop()
|
||||
self.local_bind_port = None
|
||||
self.tunnel_object = None
|
||||
|
||||
@@ -48,3 +48,35 @@ class ConnectionLost(HTTPException):
|
||||
def __repr__(self):
|
||||
return "Connection (id #{2}) lost for the server (#{0}) on " \
|
||||
"database ({1})".format(self.sid, self.db, self.conn_id)
|
||||
|
||||
|
||||
class SSHTunnelConnectionLost(HTTPException):
|
||||
"""
|
||||
Exception when connection to SSH tunnel is lost
|
||||
"""
|
||||
|
||||
def __init__(self, _tunnel_host):
|
||||
self.tunnel_host = _tunnel_host
|
||||
HTTPException.__init__(self)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return HTTP_STATUS_CODES.get(503, 'Service Unavailable')
|
||||
|
||||
def get_response(self, environ=None):
|
||||
return service_unavailable(
|
||||
_("Connection to the SSH Tunnel for host '{0}' has been lost. "
|
||||
"Reconnect to the database server.").format(self.tunnel_host),
|
||||
info="SSH_TUNNEL_CONNECTION_LOST",
|
||||
data={
|
||||
'tunnel_host': self.tunnel_host
|
||||
}
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return "Connection to the SSH Tunnel for host '{0}' has been lost. " \
|
||||
"Reconnect to the database server".format(self.tunnel_host)
|
||||
|
||||
def __repr__(self):
|
||||
return "Connection to the SSH Tunnel for host '{0}' has been lost. " \
|
||||
"Reconnect to the database server".format(self.tunnel_host)
|
||||
|
||||
Reference in New Issue
Block a user