Add support for SSH tunneled connections. Fixes #1447

This commit is contained in:
Akshay Joshi
2018-05-04 11:27:27 +01:00
committed by Dave Page
parent 455c45ea85
commit b7fb01ab04
26 changed files with 697 additions and 129 deletions

View File

@@ -378,3 +378,13 @@ try:
from config_local import *
except ImportError:
pass
##########################################################################
# SSH Tunneling supports only for Python 2.7 and 3.4+
##########################################################################
SUPPORT_SSH_TUNNEL = False
if (
(sys.version_info[0] == 2 and sys.version_info[1] >= 7) or
(sys.version_info[0] == 3 and sys.version_info[1] >= 4)
):
SUPPORT_SSH_TUNNEL = True

View File

@@ -0,0 +1,41 @@
"""Added columns for SSH tunneling
Revision ID: a68b374fe373
Revises: 50aad68f99c2
Create Date: 2018-04-05 13:59:57.588355
"""
from alembic import op
import sqlalchemy as sa
from pgadmin.model import db
# revision identifiers, used by Alembic.
revision = 'a68b374fe373'
down_revision = '50aad68f99c2'
branch_labels = None
depends_on = None
def upgrade():
db.engine.execute(
'ALTER TABLE server ADD COLUMN use_ssh_tunnel INTEGER DEFAULT 0'
)
db.engine.execute(
'ALTER TABLE server ADD COLUMN tunnel_host TEXT'
)
db.engine.execute(
'ALTER TABLE server ADD COLUMN tunnel_port TEXT'
)
db.engine.execute(
'ALTER TABLE server ADD COLUMN tunnel_username TEXT'
)
db.engine.execute(
'ALTER TABLE server ADD COLUMN tunnel_authentication INTEGER DEFAULT 0'
)
db.engine.execute(
'ALTER TABLE server ADD COLUMN tunnel_identity_file TEXT'
)
def downgrade():
pass

View File

@@ -617,7 +617,8 @@ def utils():
editor_insert_pair_brackets=insert_pair_brackets,
editor_indent_with_tabs=editor_indent_with_tabs,
app_name=config.APP_NAME,
pg_libpq_version=pg_libpq_version
pg_libpq_version=pg_libpq_version,
support_ssh_tunnel=config.SUPPORT_SSH_TUNNEL
),
200, {'Content-Type': 'application/x-javascript'})

View File

@@ -479,7 +479,13 @@ class ServerNode(PGChildNodeView):
'sslcompression': 'sslcompression',
'bgcolor': 'bgcolor',
'fgcolor': 'fgcolor',
'service': 'service'
'service': 'service',
'use_ssh_tunnel': 'use_ssh_tunnel',
'tunnel_host': 'tunnel_host',
'tunnel_port': 'tunnel_port',
'tunnel_username': 'tunnel_username',
'tunnel_authentication': 'tunnel_authentication',
'tunnel_identity_file': 'tunnel_identity_file',
}
disp_lbl = {
@@ -665,7 +671,19 @@ class ServerNode(PGChildNodeView):
'sslcrl': server.sslcrl if is_ssl else None,
'sslcompression': True if is_ssl and server.sslcompression
else False,
'service': server.service if server.service else None
'service': server.service if server.service else None,
'use_ssh_tunnel': server.use_ssh_tunnel
if server.use_ssh_tunnel else 0,
'tunnel_host': server.tunnel_host if server.tunnel_host
else None,
'tunnel_port': server.tunnel_port if server.tunnel_port
else 22,
'tunnel_username': server.tunnel_username
if server.tunnel_username else None,
'tunnel_identity_file': server.tunnel_identity_file
if server.tunnel_identity_file else None,
'tunnel_authentication': server.tunnel_authentication
if server.tunnel_authentication else 0
}
)
@@ -736,7 +754,13 @@ class ServerNode(PGChildNodeView):
sslcompression=1 if is_ssl and data['sslcompression'] else 0,
bgcolor=data.get('bgcolor', None),
fgcolor=data.get('fgcolor', None),
service=data.get('service', None)
service=data.get('service', None),
use_ssh_tunnel=data.get('use_ssh_tunnel', 0),
tunnel_host=data.get('tunnel_host', None),
tunnel_port=data.get('tunnel_port', 22),
tunnel_username=data.get('tunnel_username', None),
tunnel_authentication=data.get('tunnel_authentication', 0),
tunnel_identity_file=data.get('tunnel_identity_file', None)
)
db.session.add(server)
db.session.commit()
@@ -754,6 +778,7 @@ class ServerNode(PGChildNodeView):
have_password = False
password = None
passfile = None
tunnel_password = None
if 'password' in data and data["password"] != '':
# login with password
have_password = True
@@ -764,9 +789,15 @@ class ServerNode(PGChildNodeView):
setattr(server, 'passfile', passfile)
db.session.commit()
if 'tunnel_password' in data and data["tunnel_password"] != '':
tunnel_password = data['tunnel_password']
tunnel_password = \
encrypt(tunnel_password, current_user.password)
status, errmsg = conn.connect(
password=password,
passfile=passfile,
tunnel_password=tunnel_password,
server_types=ServerType.types()
)
if hasattr(str, 'decode') and errmsg is not None:
@@ -877,10 +908,11 @@ class ServerNode(PGChildNodeView):
res = conn.connected()
if res:
from pgadmin.utils.exception import ConnectionLost
from pgadmin.utils.exception import ConnectionLost, \
SSHTunnelConnectionLost
try:
conn.execute_scalar('SELECT 1')
except ConnectionLost:
except (ConnectionLost, SSHTunnelConnectionLost):
res = False
return make_json_response(data={'connected': res})
@@ -924,28 +956,37 @@ class ServerNode(PGChildNodeView):
password = None
passfile = None
tunnel_password = None
save_password = False
# Connect the Server
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid)
conn = manager.connection()
# If server using SSH Tunnel
if server.use_ssh_tunnel:
if 'tunnel_password' not in data:
return self.get_response_for_password(server, 428)
else:
tunnel_password = data['tunnel_password'] if 'tunnel_password'\
in data else None
# Encrypt the password before saving with user's login
# password key.
try:
tunnel_password = encrypt(tunnel_password, user.password) \
if tunnel_password is not None else \
server.tunnel_password
except Exception as e:
current_app.logger.exception(e)
return internal_server_error(errormsg=e.message)
if 'password' not in data:
conn_passwd = getattr(conn, 'password', None)
if conn_passwd is None and server.password is None and \
server.passfile is None and server.service is None:
# Return the password template in case password is not
# provided, or password has not been saved earlier.
return make_json_response(
success=0,
status=428,
result=render_template(
'servers/password.html',
server_label=server.name,
username=server.username,
_=gettext
)
)
return self.get_response_for_password(server, 428)
elif server.passfile and server.passfile != '':
passfile = server.passfile
else:
@@ -969,22 +1010,13 @@ class ServerNode(PGChildNodeView):
status, errmsg = conn.connect(
password=password,
passfile=passfile,
tunnel_password=tunnel_password,
server_types=ServerType.types()
)
except Exception as e:
current_app.logger.exception(e)
return make_json_response(
success=0,
status=401,
result=render_template(
'servers/password.html',
server_label=server.name,
username=server.username,
errmsg=getattr(e, 'message', str(e)),
_=gettext
)
)
return self.get_response_for_password(
server, 401, getattr(e, 'message', str(e)))
if not status:
if hasattr(str, 'decode'):
@@ -995,17 +1027,7 @@ class ServerNode(PGChildNodeView):
.format(server.id, server.name, errmsg)
)
return make_json_response(
success=0,
status=401,
result=render_template(
'servers/password.html',
server_label=server.name,
username=server.username,
errmsg=errmsg,
_=gettext
)
)
return self.get_response_for_password(server, 401, errmsg)
else:
if save_password and config.ALLOW_SAVE_PASSWORD:
try:
@@ -1376,5 +1398,34 @@ class ServerNode(PGChildNodeView):
)
return internal_server_error(errormsg=str(e))
def get_response_for_password(self, server, status, errmsg=None):
if server.use_ssh_tunnel:
return make_json_response(
success=0,
status=status,
result=render_template(
'servers/tunnel_password.html',
server_label=server.name,
username=server.username,
tunnel_username=server.tunnel_username,
tunnel_host=server.tunnel_host,
tunnel_identity_file=server.tunnel_identity_file,
errmsg=errmsg,
_=gettext
)
)
else:
return make_json_response(
success=0,
status=status,
result=render_template(
'servers/password.html',
server_label=server.name,
username=server.username,
errmsg=errmsg,
_=gettext
)
)
ServerNode.register_node_view(blueprint)

View File

@@ -669,6 +669,13 @@ define('pgadmin.node.server', [
sslrootcert: undefined,
sslcrl: undefined,
service: undefined,
use_ssh_tunnel: 0,
tunnel_host: undefined,
tunnel_port: 22,
tunnel_username: undefined,
tunnel_identity_file: undefined,
tunnel_password: undefined,
tunnel_authentication: 0,
},
// Default values!
initialize: function(attrs, args) {
@@ -695,8 +702,7 @@ define('pgadmin.node.server', [
},{
id: 'connected', label: gettext('Connected?'), type: 'switch',
mode: ['properties'], group: gettext('Connection'), 'options': {
'onText': gettext('True'), 'offText': gettext('False'), 'onColor': 'success',
'offColor': 'danger', 'size': 'small',
'onText': gettext('True'), 'offText': gettext('False'), 'size': 'small',
},
},{
id: 'version', label: gettext('Version'), type: 'text', group: null,
@@ -729,17 +735,35 @@ define('pgadmin.node.server', [
},{
id: 'password', label: gettext('Password'), type: 'password',
group: gettext('Connection'), control: 'input', mode: ['create'], deps: ['connect_now'],
visible: function(m) {
return m.get('connect_now') && m.isNew();
visible: function(model) {
return model.get('connect_now') && model.isNew();
},
},{
id: 'save_password', controlLabel: gettext('Save password?'),
type: 'checkbox', group: gettext('Connection'), mode: ['create'],
deps: ['connect_now'], visible: function(m) {
return m.get('connect_now') && m.isNew();
deps: ['connect_now', 'use_ssh_tunnel'], visible: function(model) {
return model.get('connect_now') && model.isNew();
},
disabled: function() {
return !current_user.allow_save_password;
disabled: function(model) {
if (!current_user.allow_save_password)
return true;
if (model.get('use_ssh_tunnel')) {
if (model.get('save_password')) {
Alertify.alert(
gettext('Stored Password'),
gettext('Database passwords cannot be stored when using SSH tunnelling. The \'Save password\' option has been turned off.')
);
}
setTimeout(function() {
model.set('save_password', false);
}, 10);
return true;
}
return false;
},
},{
id: 'role', label: gettext('Role'), type: 'text', group: gettext('Connection'),
@@ -782,51 +806,114 @@ define('pgadmin.node.server', [
},{
id: 'sslcompression', label: gettext('SSL compression?'), type: 'switch',
mode: ['edit', 'create'], group: gettext('SSL'),
'options': { 'onText': gettext('True'), 'offText': gettext('False'),
'onColor': 'success', 'offColor': 'danger', 'size': 'small'},
'options': {'size': 'small'},
deps: ['sslmode'], disabled: 'isSSL',
},{
id: 'sslcert', label: gettext('Client certificate'), type: 'text',
group: gettext('SSL'), mode: ['properties'],
deps: ['sslmode'],
visible: function(m) {
var sslcert = m.get('sslcert');
visible: function(model) {
var sslcert = model.get('sslcert');
return !_.isUndefined(sslcert) && !_.isNull(sslcert);
},
},{
id: 'sslkey', label: gettext('Client certificate key'), type: 'text',
group: gettext('SSL'), mode: ['properties'],
deps: ['sslmode'],
visible: function(m) {
var sslkey = m.get('sslkey');
visible: function(model) {
var sslkey = model.get('sslkey');
return !_.isUndefined(sslkey) && !_.isNull(sslkey);
},
},{
id: 'sslrootcert', label: gettext('Root certificate'), type: 'text',
group: gettext('SSL'), mode: ['properties'],
deps: ['sslmode'],
visible: function(m) {
var sslrootcert = m.get('sslrootcert');
visible: function(model) {
var sslrootcert = model.get('sslrootcert');
return !_.isUndefined(sslrootcert) && !_.isNull(sslrootcert);
},
},{
id: 'sslcrl', label: gettext('Certificate revocation list'), type: 'text',
group: gettext('SSL'), mode: ['properties'],
deps: ['sslmode'],
visible: function(m) {
var sslcrl = m.get('sslcrl');
visible: function(model) {
var sslcrl = model.get('sslcrl');
return !_.isUndefined(sslcrl) && !_.isNull(sslcrl);
},
},{
id: 'sslcompression', label: gettext('SSL compression?'), type: 'switch',
mode: ['properties'], group: gettext('SSL'),
'options': { 'onText': gettext('True'), 'offText': gettext('False'),
'onColor': 'success', 'offColor': 'danger', 'size': 'small'},
deps: ['sslmode'], visible: function(m) {
var sslmode = m.get('sslmode');
'options': {'size': 'small'},
deps: ['sslmode'], visible: function(model) {
var sslmode = model.get('sslmode');
return _.indexOf(SSL_MODES, sslmode) != -1;
},
},{
id: 'use_ssh_tunnel', label: gettext('Use SSH tunneling'), type: 'switch',
mode: ['properties', 'edit', 'create'], group: gettext('SSH Tunnel'),
'options': {'size': 'small'},
disabled: function(model) {
if (!pgAdmin.Browser.utils.support_ssh_tunnel) {
setTimeout(function() {
model.set('use_ssh_tunnel', 0);
}, 10);
return true;
}
return model.get('connected');
},
},{
id: 'tunnel_host', label: gettext('Tunnel host'), type: 'text', group: gettext('SSH Tunnel'),
mode: ['properties', 'edit', 'create'], deps: ['use_ssh_tunnel'],
disabled: function(model) {
return !model.get('use_ssh_tunnel');
},
},{
id: 'tunnel_port', label: gettext('Tunnel port'), type: 'int', group: gettext('SSH Tunnel'),
mode: ['properties', 'edit', 'create'], deps: ['use_ssh_tunnel'], max: 65535,
disabled: function(model) {
return !model.get('use_ssh_tunnel');
},
},{
id: 'tunnel_username', label: gettext('Username'), type: 'text', group: gettext('SSH Tunnel'),
mode: ['properties', 'edit', 'create'], deps: ['use_ssh_tunnel'],
disabled: function(model) {
return !model.get('use_ssh_tunnel');
},
},{
id: 'tunnel_authentication', label: gettext('Authentication'), type: 'switch',
mode: ['properties', 'edit', 'create'], group: gettext('SSH Tunnel'),
'options': {'onText': gettext('Identity file'),
'offText': gettext('Password'), 'size': 'small'},
deps: ['use_ssh_tunnel'],
disabled: function(model) {
return !model.get('use_ssh_tunnel');
},
}, {
id: 'tunnel_identity_file', label: gettext('Identity file'), type: 'text',
group: gettext('SSH Tunnel'), mode: ['edit', 'create'],
control: Backform.FileControl, dialog_type: 'select_file', supp_types: ['*'],
deps: ['tunnel_authentication', 'use_ssh_tunnel'],
disabled: function(model) {
if (!model.get('tunnel_authentication') || !model.get('use_ssh_tunnel')) {
setTimeout(function() {
model.set('tunnel_identity_file', '');
}, 10);
}
return !model.get('tunnel_authentication');
},
},{
id: 'tunnel_identity_file', label: gettext('Identity file'), type: 'text',
group: gettext('SSH Tunnel'), mode: ['properties'],
},{
id: 'tunnel_password', label: gettext('Password'), type: 'password',
group: gettext('SSH Tunnel'), control: 'input', mode: ['create'],
deps: ['use_ssh_tunnel'],
disabled: function(model) {
return !model.get('use_ssh_tunnel');
},
}, {
id: 'hostaddr', label: gettext('Host address'), type: 'text', group: gettext('Advanced'),
mode: ['properties', 'edit', 'create'], disabled: 'isConnected',
},{
@@ -841,8 +928,8 @@ define('pgadmin.node.server', [
},{
id: 'passfile', label: gettext('Password file'), type: 'text',
group: gettext('Advanced'), mode: ['properties'],
visible: function(m) {
var passfile = m.get('passfile');
visible: function(model) {
var passfile = model.get('passfile');
return !_.isUndefined(passfile) && !_.isNull(passfile);
},
},{

View File

@@ -0,0 +1,28 @@
<form name="frmPassword" id="frmPassword" style="height: 100%; width: 100%" onsubmit="return false;">
<div>{% if errmsg %}
<div class="highlight has-error">
<div class='control-label'>{{ errmsg }}</div>
</div>
{% endif %}
{% if tunnel_identity_file %}
<div><b>{{ _('SSH Tunnel password for the identity file \'{0}\' to connect the server "{1}"').format(tunnel_identity_file, tunnel_host) }}</b></div>
{% else %}
<div><b>{{ _('SSH Tunnel password for the user \'{0}\' to connect the server "{1}"').format(tunnel_username, tunnel_host) }}</b></div>
{% endif %}
<div style="padding: 5px; height: 1px;"></div>
<div style="width: 100%">
<span style="width: 97%;display: inline-block;">
<input style="width:100%" id="tunnel_password" class="form-control" name="tunnel_password" type="password">
</span>
</div>
<div style="padding: 5px; height: 1px;"></div>
<div><b>{{ _('Database server password for the user \'{0}\' to connect the server "{1}"').format(username, server_label) }}</b></div>
<div style="padding: 5px; height: 1px;"></div>
<div style="width: 100%">
<span style="width: 97%;display: inline-block;">
<input style="width:100%" id="password" class="form-control" name="password" type="password">
</span>
</div>
<div style="padding: 5px; height: 1px;"></div>
</div>
</form>

View File

@@ -0,0 +1,62 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2018, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
import json
from pgadmin.utils.route import BaseTestGenerator
from regression.python_test_utils import test_utils as utils
class ServersWithSSHTunnelAddTestCase(BaseTestGenerator):
""" This class will add the servers under default server group. """
scenarios = [
(
'Add server using SSH tunnel with password', dict(
url='/browser/server/obj/',
with_password=True
)
),
(
'Add server using SSH tunnel with identity file', dict(
url='/browser/server/obj/',
with_password=False
)
),
]
def setUp(self):
pass
def runTest(self):
""" This function will add the server under default server group."""
url = "{0}{1}/".format(self.url, utils.SERVER_GROUP)
# Add service name in the config
self.server['use_ssh_tunnel'] = 1
self.server['tunnel_host'] = '127.0.0.1'
self.server['tunnel_port'] = 22
self.server['tunnel_username'] = 'user'
if self.with_password:
self.server['tunnel_authentication'] = 0
else:
self.server['tunnel_authentication'] = 1
self.server['tunnel_identity_file'] = 'pkey_rsa'
response = self.tester.post(
url,
data=json.dumps(self.server),
content_type='html/json'
)
self.assertEquals(response.status_code, 200)
response_data = json.loads(response.data.decode('utf-8'))
self.server_id = response_data['node']['_id']
def tearDown(self):
"""This function delete the server from SQLite """
utils.delete_server_with_api(self.tester, self.server_id)

View File

@@ -26,6 +26,7 @@ define('pgadmin.browser.utils',
is_indent_with_tabs: '{{ editor_indent_with_tabs }}' == 'True',
app_name: '{{ app_name }}',
pg_libpq_version: {{pg_libpq_version|e}},
support_ssh_tunnel: '{{ support_ssh_tunnel }}' == 'True',
counter: {total: 0, loaded: 0},
registerScripts: function (ctx) {

View File

@@ -145,6 +145,24 @@ class Server(db.Model):
bgcolor = db.Column(db.Text(10), nullable=True)
fgcolor = db.Column(db.Text(10), nullable=True)
service = db.Column(db.Text(), nullable=True)
use_ssh_tunnel = db.Column(
db.Integer(),
db.CheckConstraint('use_ssh_tunnel >= 0 AND use_ssh_tunnel <= 1'),
nullable=False
)
tunnel_host = db.Column(db.String(128), nullable=True)
tunnel_port = db.Column(
db.Integer(),
db.CheckConstraint('port <= 65534'),
nullable=True)
tunnel_username = db.Column(db.String(64), nullable=True)
tunnel_authentication = db.Column(
db.Integer(),
db.CheckConstraint('tunnel_authentication >= 0 AND '
'tunnel_authentication <= 1'),
nullable=False
)
tunnel_identity_file = db.Column(db.String(64), nullable=True)
class ModulePreference(db.Model):

View File

@@ -43,6 +43,15 @@ export class ModelValidation {
this.checkForEmpty('username', gettext('Username must be specified.'));
this.checkForEmpty('port', gettext('Port must be specified.'));
if (this.model.get('use_ssh_tunnel')) {
this.checkForEmpty('tunnel_host', gettext('SSH Tunnel host must be specified.'));
this.checkForEmpty('tunnel_port', gettext('SSH Tunnel port must be specified.'));
this.checkForEmpty('tunnel_username', gettext('SSH Tunnel username must be specified.'));
if (this.model.get('tunnel_authentication')) {
this.checkForEmpty('tunnel_identity_file', gettext('SSH Tunnel identity file must be specified.'));
}
}
this.model.errorModel.set(this.err);
if (_.size(this.err)) {

View File

@@ -27,7 +27,7 @@ from config import PG_DEFAULT_DRIVER
from pgadmin.utils.preferences import Preferences
from pgadmin.model import Server
from pgadmin.utils.driver import get_driver
from pgadmin.utils.exception import ConnectionLost
from pgadmin.utils.exception import ConnectionLost, SSHTunnelConnectionLost
from pgadmin.tools.sqleditor.utils.query_tool_preferences import \
get_query_tool_keyboard_shortcuts, get_text_representation_of_shortcut
@@ -135,7 +135,7 @@ def initialize_datagrid(cmd_type, obj_type, sgid, sid, did, obj_id):
auto_reconnect=False,
use_binary_placeholder=True,
array_to_string=True)
except ConnectionLost as e:
except (ConnectionLost, SSHTunnelConnectionLost) as e:
raise
except Exception as e:
app.logger.error(e)
@@ -363,7 +363,7 @@ def initialize_query_tool(sgid, sid, did=None):
array_to_string=True)
if connect:
conn.connect()
except ConnectionLost as e:
except (ConnectionLost, SSHTunnelConnectionLost) as e:
raise
except Exception as e:
app.logger.error(e)

View File

@@ -34,7 +34,7 @@ from pgadmin.utils.ajax import make_json_response, bad_request, \
success_return, internal_server_error, unauthorized
from pgadmin.utils.driver import get_driver
from pgadmin.utils.menu import MenuItem
from pgadmin.utils.exception import ConnectionLost
from pgadmin.utils.exception import ConnectionLost, SSHTunnelConnectionLost
from pgadmin.utils.sqlautocomplete.autocomplete import SQLAutoComplete
from pgadmin.tools.sqleditor.utils.query_tool_preferences import \
RegisterQueryToolPreferences
@@ -166,7 +166,7 @@ def check_transaction_status(trans_id):
use_binary_placeholder=True,
array_to_string=True
)
except ConnectionLost as e:
except (ConnectionLost, SSHTunnelConnectionLost) as e:
raise
except Exception as e:
current_app.logger.error(e)
@@ -212,7 +212,7 @@ def start_view_data(trans_id):
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(
trans_obj.sid)
default_conn = manager.connection(did=trans_obj.did)
except ConnectionLost as e:
except (ConnectionLost, SSHTunnelConnectionLost) as e:
raise
except Exception as e:
current_app.logger.error(e)
@@ -261,7 +261,7 @@ def start_view_data(trans_id):
# Execute sql asynchronously
try:
status, result = conn.execute_async(sql)
except ConnectionLost as e:
except (ConnectionLost, SSHTunnelConnectionLost) as e:
raise
else:
status = False

View File

@@ -25,7 +25,7 @@ from pgadmin.tools.sqleditor.utils.update_session_grid_transaction import \
update_session_grid_transaction
from pgadmin.utils.ajax import make_json_response, internal_server_error
from pgadmin.utils.driver import get_driver
from pgadmin.utils.exception import ConnectionLost
from pgadmin.utils.exception import ConnectionLost, SSHTunnelConnectionLost
class StartRunningQuery:
@@ -61,7 +61,7 @@ class StartRunningQuery:
auto_reconnect=False,
use_binary_placeholder=True,
array_to_string=True)
except ConnectionLost:
except (ConnectionLost, SSHTunnelConnectionLost):
raise
except Exception as e:
self.logger.error(e)
@@ -127,7 +127,7 @@ class StartRunningQuery:
# and formatted_error is True.
try:
status, result = conn.execute_async(sql)
except ConnectionLost:
except (ConnectionLost, SSHTunnelConnectionLost):
raise
# If the transaction aborted for some reason and

View File

@@ -12,7 +12,7 @@ from flask import Response
import simplejson as json
from pgadmin.tools.sqleditor.utils.start_running_query import StartRunningQuery
from pgadmin.utils.exception import ConnectionLost
from pgadmin.utils.exception import ConnectionLost, SSHTunnelConnectionLost
from pgadmin.utils.route import BaseTestGenerator
if sys.version_info < (3, 3):
@@ -176,6 +176,35 @@ class StartRunningQueryTest(BaseTestGenerator):
is_rollback_required=False,
apply_explain_plan_wrapper_if_needed_return_value='some sql',
expect_make_json_response_to_have_been_called_with=None,
expect_internal_server_error_called_with=None,
expected_logger_error=None,
expect_execute_void_called_with='some sql',
)),
('When SSHTunnelConnectionLost happens while retrieving the '
'database connection, '
'it returns an error',
dict(
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict(gridData={'123': dict(command_obj='')})
),
pickle_load_return=MagicMock(
conn_id=1,
update_fetched_row_cnt=MagicMock()
),
get_driver_exception=False,
get_connection_lost_exception=False,
manager_connection_exception=SSHTunnelConnectionLost('1.1.1.1'),
is_connected_to_server=False,
connection_connect_return=None,
execute_async_return_value=None,
is_begin_required=False,
is_rollback_required=False,
apply_explain_plan_wrapper_if_needed_return_value='some sql',
expect_make_json_response_to_have_been_called_with=None,
expect_internal_server_error_called_with=None,
expected_logger_error=None,

View File

@@ -211,13 +211,25 @@ class Connection(BaseConnection):
pg_conn = None
password = None
passfile = None
mgr = self.manager
manager = self.manager
encpass = kwargs['password'] if 'password' in kwargs else None
passfile = kwargs['passfile'] if 'passfile' in kwargs else None
tunnel_password = kwargs['tunnel_password'] if 'tunnel_password' in \
kwargs else None
# Check SSH Tunnel needs to be created
if manager.use_ssh_tunnel == 1 and tunnel_password is not None:
status, error = manager.create_ssh_tunnel(tunnel_password)
if not status:
return False, error
# Check SSH Tunnel is alive or not.
if manager.use_ssh_tunnel == 1:
manager.check_ssh_tunnel_alive()
if encpass is None:
encpass = self.password or getattr(mgr, 'password', None)
encpass = self.password or getattr(manager, 'password', None)
# Reset the existing connection password
if self.reconnecting is not False:
@@ -240,6 +252,7 @@ class Connection(BaseConnection):
password = password.decode()
except Exception as e:
manager.stop_ssh_tunnel()
current_app.logger.exception(e)
return False, \
_(
@@ -251,16 +264,16 @@ class Connection(BaseConnection):
# we will check for pgpass file availability from connection manager
# if it's present then we will use it
if not password and not encpass and not passfile:
passfile = mgr.passfile if mgr.passfile else None
passfile = manager.passfile if manager.passfile else None
try:
if hasattr(str, 'decode'):
database = self.db.encode('utf-8')
user = mgr.user.encode('utf-8')
user = manager.user.encode('utf-8')
conn_id = self.conn_id.encode('utf-8')
else:
database = self.db
user = mgr.user
user = manager.user
conn_id = self.conn_id
import os
@@ -268,21 +281,24 @@ class Connection(BaseConnection):
config.APP_NAME, conn_id)
pg_conn = psycopg2.connect(
host=mgr.host,
hostaddr=mgr.hostaddr,
port=mgr.port,
host=manager.local_bind_host if manager.use_ssh_tunnel
else manager.host,
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
else manager.hostaddr,
port=manager.local_bind_port if manager.use_ssh_tunnel
else manager.port,
database=database,
user=user,
password=password,
async=self.async,
passfile=get_complete_file_path(passfile),
sslmode=mgr.ssl_mode,
sslcert=get_complete_file_path(mgr.sslcert),
sslkey=get_complete_file_path(mgr.sslkey),
sslrootcert=get_complete_file_path(mgr.sslrootcert),
sslcrl=get_complete_file_path(mgr.sslcrl),
sslcompression=True if mgr.sslcompression else False,
service=mgr.service
sslmode=manager.ssl_mode,
sslcert=get_complete_file_path(manager.sslcert),
sslkey=get_complete_file_path(manager.sslkey),
sslrootcert=get_complete_file_path(manager.sslrootcert),
sslcrl=get_complete_file_path(manager.sslcrl),
sslcompression=True if manager.sslcompression else False,
service=manager.service
)
# If connection is asynchronous then we will have to wait
@@ -291,6 +307,7 @@ class Connection(BaseConnection):
self._wait(pg_conn)
except psycopg2.Error as e:
manager.stop_ssh_tunnel()
if e.pgerror:
msg = e.pgerror
elif e.diag.message_detail:
@@ -317,6 +334,7 @@ class Connection(BaseConnection):
try:
status, msg = self._initialize(conn_id, **kwargs)
except Exception as e:
manager.stop_ssh_tunnel()
current_app.logger.exception(e)
self.conn = None
if not self.reconnecting:
@@ -324,7 +342,7 @@ class Connection(BaseConnection):
raise e
if status:
mgr._update_password(encpass)
manager._update_password(encpass)
else:
if not self.reconnecting:
self.wasConnected = False
@@ -342,7 +360,7 @@ class Connection(BaseConnection):
status, cur = self.__cursor()
formatted_exception_msg = self._formatted_exception_msg
mgr = self.manager
manager = self.manager
def _execute(cur, query, params=None):
try:
@@ -381,8 +399,8 @@ class Connection(BaseConnection):
return False, status
if mgr.role:
status = _execute(cur, u"SET ROLE TO %s", [mgr.role])
if manager.role:
status = _execute(cur, u"SET ROLE TO %s", [manager.role])
if status is not None:
self.conn.close()
@@ -401,7 +419,7 @@ class Connection(BaseConnection):
"Failed to setup the role with error message:\n{0}"
).format(status)
if mgr.ver is None:
if manager.ver is None:
status = _execute(cur, "SELECT version()")
if status is not None:
@@ -421,8 +439,8 @@ class Connection(BaseConnection):
if cur.rowcount > 0:
row = cur.fetchmany(1)[0]
mgr.ver = row['version']
mgr.sversion = self.conn.server_version
manager.ver = row['version']
manager.sversion = self.conn.server_version
status = _execute(cur, """
SELECT
@@ -434,14 +452,14 @@ FROM
WHERE db.datname = current_database()""")
if status is None:
mgr.db_info = mgr.db_info or dict()
manager.db_info = manager.db_info or dict()
if cur.rowcount > 0:
res = cur.fetchmany(1)[0]
mgr.db_info[res['did']] = res.copy()
manager.db_info[res['did']] = res.copy()
# We do not have database oid for the maintenance database.
if len(mgr.db_info) == 1:
mgr.did = res['did']
if len(manager.db_info) == 1:
manager.did = res['did']
status = _execute(cur, """
SELECT
@@ -453,33 +471,39 @@ WHERE
rolname = current_user""")
if status is None:
mgr.user_info = dict()
manager.user_info = dict()
if cur.rowcount > 0:
mgr.user_info = cur.fetchmany(1)[0]
manager.user_info = cur.fetchmany(1)[0]
if 'password' in kwargs:
mgr.password = kwargs['password']
manager.password = kwargs['password']
server_types = None
if 'server_types' in kwargs and isinstance(
kwargs['server_types'], list):
server_types = mgr.server_types = kwargs['server_types']
server_types = manager.server_types = kwargs['server_types']
if server_types is None:
from pgadmin.browser.server_groups.servers.types import ServerType
server_types = ServerType.types()
for st in server_types:
if st.instanceOf(mgr.ver):
mgr.server_type = st.stype
mgr.server_cls = st
if st.instanceOf(manager.ver):
manager.server_type = st.stype
manager.server_cls = st
break
mgr.update_session()
manager.update_session()
return True, None
def __cursor(self, server_cursor=False):
# Check SSH Tunnel is alive or not. If used by the database
# server for the connection.
if self.manager.use_ssh_tunnel == 1:
self.manager.check_ssh_tunnel_alive()
if self.wasConnected is False:
raise ConnectionLost(
self.manager.sid,
@@ -1166,9 +1190,9 @@ WHERE
if self.conn.closed:
self.conn = None
pg_conn = None
mgr = self.manager
manager = self.manager
password = getattr(mgr, 'password', None)
password = getattr(manager, 'password', None)
if password:
# Fetch Logged in User Details.
@@ -1181,20 +1205,23 @@ WHERE
try:
pg_conn = psycopg2.connect(
host=mgr.host,
hostaddr=mgr.hostaddr,
port=mgr.port,
host=manager.local_bind_host if manager.use_ssh_tunnel
else manager.host,
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
else manager.hostaddr,
port=manager.local_bind_port if manager.use_ssh_tunnel
else manager.port,
database=self.db,
user=mgr.user,
user=manager.user,
password=password,
passfile=get_complete_file_path(mgr.passfile),
sslmode=mgr.ssl_mode,
sslcert=get_complete_file_path(mgr.sslcert),
sslkey=get_complete_file_path(mgr.sslkey),
sslrootcert=get_complete_file_path(mgr.sslrootcert),
sslcrl=get_complete_file_path(mgr.sslcrl),
sslcompression=True if mgr.sslcompression else False,
service=mgr.service
passfile=get_complete_file_path(manager.passfile),
sslmode=manager.ssl_mode,
sslcert=get_complete_file_path(manager.sslcert),
sslkey=get_complete_file_path(manager.sslkey),
sslrootcert=get_complete_file_path(manager.sslrootcert),
sslcrl=get_complete_file_path(manager.sslcrl),
sslcompression=True if manager.sslcompression else False,
service=manager.service
)
except psycopg2.Error as e:
@@ -1456,9 +1483,13 @@ Failed to reset the connection to the server due to following error:
try:
pg_conn = psycopg2.connect(
host=self.manager.host,
hostaddr=self.manager.hostaddr,
port=self.manager.port,
host=self.manager.local_bind_host if
self.manager.use_ssh_tunnel else self.manager.host,
hostaddr=self.manager.local_bind_host if
self.manager.use_ssh_tunnel else
self.manager.hostaddr,
port=self.manager.local_bind_port if
self.manager.use_ssh_tunnel else self.manager.port,
database=self.db,
user=self.manager.user,
password=password,

View File

@@ -12,14 +12,19 @@ Implementation of ServerManager
"""
import os
import datetime
import config
from flask import current_app, session
from flask_security import current_user
from flask_babelex import gettext
from pgadmin.utils import get_complete_file_path
from pgadmin.utils.crypto import decrypt
from .connection import Connection
from pgadmin.model import Server
from pgadmin.utils.exception import ConnectionLost
from pgadmin.model import Server, User
from pgadmin.utils.exception import ConnectionLost, SSHTunnelConnectionLost
if config.SUPPORT_SSH_TUNNEL:
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
class ServerManager(object):
@@ -32,6 +37,9 @@ class ServerManager(object):
def __init__(self, server):
self.connections = dict()
self.local_bind_host = '127.0.0.1'
self.local_bind_port = None
self.tunnel_object = None
self.update(server)
@@ -66,6 +74,20 @@ class ServerManager(object):
self.sslcrl = server.sslcrl
self.sslcompression = True if server.sslcompression else False
self.service = server.service
if config.SUPPORT_SSH_TUNNEL:
self.use_ssh_tunnel = server.use_ssh_tunnel
self.tunnel_host = server.tunnel_host
self.tunnel_port = server.tunnel_port
self.tunnel_username = server.tunnel_username
self.tunnel_authentication = server.tunnel_authentication
self.tunnel_identity_file = server.tunnel_identity_file
else:
self.use_ssh_tunnel = 0
self.tunnel_host = None
self.tunnel_port = 22
self.tunnel_username = None
self.tunnel_authentication = None
self.tunnel_identity_file = None
for con in self.connections:
self.connections[con]._release()
@@ -167,7 +189,11 @@ WHERE db.oid = {0}""".format(did))
))
if database is None:
raise ConnectionLost(self.sid, None, None)
# Check SSH Tunnel is alive or not.
if self.use_ssh_tunnel == 1:
self.check_ssh_tunnel_alive()
else:
raise ConnectionLost(self.sid, None, None)
my_id = (u'CONN:{0}'.format(conn_id)) if conn_id is not None else \
(u'DB:{0}'.format(database))
@@ -247,6 +273,9 @@ WHERE db.oid = {0}""".format(did))
self.connections.pop(conn_info['conn_id'])
def release(self, database=None, conn_id=None, did=None):
# Stop the SSH tunnel if created.
self.stop_ssh_tunnel()
if did is not None:
if did in self.db_info and 'datname' in self.db_info[did]:
database = self.db_info[did]['datname']
@@ -332,3 +361,73 @@ WHERE db.oid = {0}""".format(did))
self.password, current_user.password
).decode()
os.environ[str(env)] = password
def create_ssh_tunnel(self, tunnel_password):
"""
This method is used to create ssh tunnel and update the IP Address and
IP Address and port to localhost and the local bind port return by the
SSHTunnelForwarder class.
:return: True if tunnel is successfully created else error message.
"""
# Fetch Logged in User Details.
user = User.query.filter_by(id=current_user.id).first()
if user is None:
return False, gettext("Unauthorized request.")
try:
tunnel_password = decrypt(tunnel_password, user.password)
# Handling of non ascii password (Python2)
if hasattr(str, 'decode'):
tunnel_password = \
tunnel_password.decode('utf-8').encode('utf-8')
# password is in bytes, for python3 we need it in string
elif isinstance(tunnel_password, bytes):
tunnel_password = tunnel_password.decode()
except Exception as e:
current_app.logger.exception(e)
return False, "Failed to decrypt the SSH tunnel " \
"password.\nError: {0}".format(str(e))
try:
# If authentication method is 1 then it uses identity file
# and password
if self.tunnel_authentication == 1:
self.tunnel_object = SSHTunnelForwarder(
self.tunnel_host,
ssh_username=self.tunnel_username,
ssh_pkey=get_complete_file_path(self.tunnel_identity_file),
ssh_private_key_password=tunnel_password,
remote_bind_address=(self.host, self.port)
)
else:
self.tunnel_object = SSHTunnelForwarder(
self.tunnel_host,
ssh_username=self.tunnel_username,
ssh_password=tunnel_password,
remote_bind_address=(self.host, self.port)
)
self.tunnel_object.start()
except BaseSSHTunnelForwarderError as e:
current_app.logger.exception(e)
return False, "Failed to create the SSH tunnel." \
"\nError: {0}".format(str(e))
# Update the port to communicate locally
self.local_bind_port = self.tunnel_object.local_bind_port
return True, None
def check_ssh_tunnel_alive(self):
# Check SSH Tunnel is alive or not. if it is not then
# raise the ConnectionLost exception.
if self.tunnel_object is None or not self.tunnel_object.is_active:
raise SSHTunnelConnectionLost(self.tunnel_host)
def stop_ssh_tunnel(self):
# Stop the SSH tunnel if created.
if self.tunnel_object and self.tunnel_object.is_active:
self.tunnel_object.stop()
self.local_bind_port = None
self.tunnel_object = None

View File

@@ -48,3 +48,35 @@ class ConnectionLost(HTTPException):
def __repr__(self):
return "Connection (id #{2}) lost for the server (#{0}) on " \
"database ({1})".format(self.sid, self.db, self.conn_id)
class SSHTunnelConnectionLost(HTTPException):
"""
Exception when connection to SSH tunnel is lost
"""
def __init__(self, _tunnel_host):
self.tunnel_host = _tunnel_host
HTTPException.__init__(self)
@property
def name(self):
return HTTP_STATUS_CODES.get(503, 'Service Unavailable')
def get_response(self, environ=None):
return service_unavailable(
_("Connection to the SSH Tunnel for host '{0}' has been lost. "
"Reconnect to the database server.").format(self.tunnel_host),
info="SSH_TUNNEL_CONNECTION_LOST",
data={
'tunnel_host': self.tunnel_host
}
)
def __str__(self):
return "Connection to the SSH Tunnel for host '{0}' has been lost. " \
"Reconnect to the database server".format(self.tunnel_host)
def __repr__(self):
return "Connection to the SSH Tunnel for host '{0}' has been lost. " \
"Reconnect to the database server".format(self.tunnel_host)

View File

@@ -67,7 +67,52 @@ describe('Server#ModelValidation', () => {
});
});
describe('SSH Tunnel parameters', () => {
beforeEach(() => {
model.isNew.and.returnValue(true);
model.allValues['name'] = 'some name';
model.allValues['username'] = 'some username';
model.allValues['port'] = 12345;
model.allValues['host'] = 'some host';
model.allValues['db'] = 'some db';
model.allValues['hostaddr'] = '1.1.1.1';
model.allValues['use_ssh_tunnel'] = 1;
});
it('sets the "SSH Tunnel host must be specified." error', () => {
model.allValues['tunnel_port'] = 22;
model.allValues['tunnel_username'] = 'user1';
expect(modelValidation.validate()).toBe('SSH Tunnel host must be specified.');
expect(model.errorModel.set).toHaveBeenCalledWith({
tunnel_host:'SSH Tunnel host must be specified.',
});
});
it('sets the "SSH Tunnel port must be specified." error', () => {
model.allValues['tunnel_host'] = 'host';
model.allValues['tunnel_username'] = 'user1';
expect(modelValidation.validate()).toBe('SSH Tunnel port must be specified.');
expect(model.errorModel.set).toHaveBeenCalledWith({
tunnel_port:'SSH Tunnel port must be specified.',
});
});
it('sets the "SSH Tunnel username be specified." error', () => {
model.allValues['tunnel_host'] = 'host';
model.allValues['tunnel_port'] = 22;
expect(modelValidation.validate()).toBe('SSH Tunnel username must be specified.');
expect(model.errorModel.set).toHaveBeenCalledWith({
tunnel_username:'SSH Tunnel username must be specified.',
});
});
it('sets the "SSH Tunnel identity file be specified." error', () => {
model.allValues['tunnel_host'] = 'host';
model.allValues['tunnel_port'] = 22;
model.allValues['tunnel_username'] = 'user1';
model.allValues['tunnel_authentication'] = 1;
expect(modelValidation.validate()).toBe('SSH Tunnel identity file must be specified.');
expect(model.errorModel.set).toHaveBeenCalledWith({
tunnel_identity_file:'SSH Tunnel identity file must be specified.',
});
});
});
});
describe('When no parameters are valid', () => {