mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2025-01-24 07:16:52 -06:00
342 lines
11 KiB
Python
342 lines
11 KiB
Python
##########################################################################
|
|
#
|
|
# pgAdmin 4 - PostgreSQL Tools
|
|
#
|
|
# Copyright (C) 2013 - 2024, The pgAdmin Development Team
|
|
# This software is released under the PostgreSQL Licence
|
|
#
|
|
##########################################################################
|
|
|
|
"""A blueprint module implementing the Authentication."""
|
|
|
|
import config
|
|
import copy
|
|
import functools
|
|
from threading import Lock
|
|
|
|
from flask import current_app, flash, Response, request, url_for, \
|
|
session, redirect, render_template
|
|
from flask_babel import gettext
|
|
from flask_security.views import _security, _ctx
|
|
from flask_security.utils import logout_user, config_value
|
|
|
|
from flask_login import current_user
|
|
from flask_socketio import disconnect, ConnectionRefusedError
|
|
|
|
from pgadmin.model import db, User
|
|
from pgadmin.utils.constants import KERBEROS, INTERNAL, OAUTH2, LDAP,\
|
|
MessageType
|
|
import pgadmin.utils as pga_utils
|
|
from pgadmin.authenticate.registry import AuthSourceRegistry
|
|
|
|
# Name under which the authentication blueprint/module is registered.
MODULE_NAME = 'authenticate'

# Module-level default. ``_login`` binds its own *local* ``auth_obj`` and
# never assigns this global. NOTE(review): may be referenced by other
# modules; confirm before removing.
auth_obj = None

# Format template producing "<url>?next=<url>" redirect targets.
_URL_WITH_NEXT_PARAM = "{0}?next={1}"
|
|
|
|
|
|
class AuthLocker:
    """Implementing lock while authentication.

    Context manager around a single ``threading.Lock`` stored as a class
    attribute, so every ``AuthLocker`` instance in the process shares the
    same mutex. Entering blocks until the mutex is free; leaving releases
    it. Exceptions raised inside the ``with`` body are not suppressed.
    """

    lock = Lock()

    def __enter__(self):
        # Wait for any other in-flight authentication to finish.
        self.lock.acquire()
        return self

    def __exit__(self, type, value, traceback):
        # Nothing to release when the shared mutex is not held.
        if not self.lock.locked():
            return
        self.lock.release()
|
|
|
|
|
|
def get_logout_url() -> str:
    """
    Returns the logout url based on the current authentication method.

    In server mode, users whose session records Kerberos or OAuth2 as the
    current authentication source are sent to that source's logout
    endpoint. Everyone else uses ``security.logout``: desktop mode, internal
    or LDAP users, and sessions with no recorded source (``_login`` resets
    ``auth_source_manager`` to ``None`` at the start of every attempt). The
    URL always carries ``?next=`` pointing at ``browser.index``.

    Returns:
        str: logout url
    """
    logout_endpoint = 'security.logout'

    if config.SERVER_MODE:
        # The entry may be missing or None; treat both as "no source".
        auth_manager = session.get('auth_source_manager') or {}
        logout_endpoint = {
            KERBEROS: 'kerberos.logout',
            OAUTH2: 'oauth2.logout',
        }.get(auth_manager.get('current_source'), logout_endpoint)

    return _URL_WITH_NEXT_PARAM.format(
        url_for(logout_endpoint), url_for('browser.index'))
|
|
|
|
|
|
def socket_login_required(f):
    """Decorator that only lets authenticated users reach a socket handler.

    If ``current_user`` is authenticated, the wrapped handler is called
    with the original arguments and its result is returned. Otherwise the
    socket is disconnected and ``ConnectionRefusedError`` is raised.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        if current_user.is_authenticated:
            return f(*args, **kwargs)

        # Drop the connection before refusing it.
        disconnect()
        raise ConnectionRefusedError("Unauthorised !")

    return wrapped
|
|
|
|
|
|
class AuthenticateModule(pga_utils.PgAdminModule):
    """pgAdmin module wrapper for the authentication blueprint."""

    def get_exposed_url_endpoints(self):
        """Return the endpoint names this module exposes.

        Returns:
            list: a new list containing only ``'authenticate.login'``.
        """
        exposed = ['authenticate.login']
        return exposed
|
|
|
|
|
|
# Blueprint instance for this module, registered as MODULE_NAME with an
# empty static URL path.
blueprint = AuthenticateModule(
    MODULE_NAME, __name__, static_url_path=''
)
|
|
|
|
|
|
@blueprint.route('/login', endpoint='login', methods=['GET', 'POST'])
def login():
    """
    Entry point for all the authentication sources.

    The whole flow runs inside ``AuthLocker``, so only one request per
    process executes ``_login`` at a time. The user input is validated and
    authenticated there, and its response is returned unchanged.
    """
    locker = AuthLocker()
    with locker:
        response = _login()
    return response
|
|
|
|
|
|
def _login():
    """
    Internal authentication process locked by a mutex.

    Steps:
      1. Build an ``AuthSourceManager`` for the submitted form. When the
         OAuth2 button was used, an empty form is passed and the manager is
         stored in the session.
      2. For an existing INTERNAL user, refresh the ``locked`` flag from the
         failed-attempt counter, and refuse the login if the account is
         locked.
      3. Validate the form. On failure, count the attempt when lockout is
         enabled, flash the errors and redirect.
      4. Authenticate and log in through the sources. On success, store the
         manager state in the session, reset the counter and redirect.

    Returns:
        A Flask response: a redirect, a response produced by an auth source
        (for example an OAuth2 redirect), or the re-rendered login page.
    """
    form = _security.forms.get('login_form').cls(request.form)
    if OAUTH2 in config.AUTHENTICATION_SOURCES \
            and 'oauth2_button' in request.form:
        # Sending empty form as oauth2 does not require form attribute
        auth_obj = AuthSourceManager({}, copy.deepcopy(
            config.AUTHENTICATION_SOURCES))
        session['auth_obj'] = auth_obj
    else:
        # deepcopy: AuthSourceManager prunes its source list in place and
        # must not modify the global configuration.
        auth_obj = AuthSourceManager(form, copy.deepcopy(
            config.AUTHENTICATION_SOURCES))

    session['auth_source_manager'] = None

    username = form.data['email']
    user = User.query.filter_by(username=username,
                                auth_source=INTERNAL).first()

    if user:
        locked = _login_attempts_exceeded(user)
        user.locked = locked
        db.session.commit()

        if locked:
            flash(gettext('Your account is locked. Please contact the '
                          'Administrator.'),
                  MessageType.WARNING)
            logout_user()
            return redirect(pga_utils.get_safe_post_logout_redirect())

    # Validate the user
    if not auth_obj.validate():
        for field in form.errors:
            attempts_msg = None
            if user and field in config.LOGIN_ATTEMPT_FIELDS:
                if config.MAX_LOGIN_ATTEMPTS > 0:
                    user.login_attempts += 1
                    attempts_msg = _attempts_remaining_message(
                        config.MAX_LOGIN_ATTEMPTS - user.login_attempts)
                db.session.commit()

            for error in form.errors[field]:
                if attempts_msg:
                    # Attach the hint to the field's first error only, with
                    # a space so the two sentences do not run together.
                    error = error + ' ' + attempts_msg
                    attempts_msg = None
                flash(error, MessageType.WARNING)

        return redirect(pga_utils.get_safe_post_logout_redirect())

    # Authenticate the user
    status, msg = auth_obj.authenticate()
    if status:
        # Login the user
        status, msg = auth_obj.login()
        current_auth_obj = auth_obj.as_dict()

        if not status:
            if current_auth_obj['current_source'] == KERBEROS:
                return redirect(_URL_WITH_NEXT_PARAM.format(
                    url_for('authenticate.kerberos_login'),
                    url_for('browser.index')))

            flash(msg, MessageType.ERROR)
            return redirect(pga_utils.get_safe_post_logout_redirect())

        session['auth_source_manager'] = current_auth_obj

        if user:
            user.login_attempts = 0
        db.session.commit()

        if 'auth_obj' in session:
            session.pop('auth_obj')
        return redirect(pga_utils.get_safe_post_login_redirect())

    # A source may hand back a ready-made response (e.g. an OAuth2
    # redirect) instead of an error string.
    if isinstance(msg, Response):
        return msg
    if 'oauth2_button' in request.form and not isinstance(msg, str):
        return msg

    if 'auth_obj' in session:
        session.pop('auth_obj')
    flash(msg, MessageType.ERROR)
    form_class = _security.forms.get('login_form').cls
    form = form_class()

    return _security.render_template(
        config_value('LOGIN_USER_TEMPLATE'),
        login_user_form=form, **_ctx('login'))


def _login_attempts_exceeded(user):
    """Return True when lockout is enabled and *user* has no attempts left."""
    return user.login_attempts >= config.MAX_LOGIN_ATTEMPTS > 0


def _attempts_remaining_message(left_attempts):
    """Return the translated "N more attempt(s) remaining." hint.

    The message template is passed to ``gettext`` untouched and formatted
    afterwards, so it matches the msgid in the translation catalogs.
    The singular form is used only when exactly one attempt is left.
    """
    if left_attempts == 1:
        template = gettext('{0} more attempt remaining.')
    else:
        template = gettext('{0} more attempts remaining.')
    return template.format(left_attempts)
|
|
|
|
|
|
class AuthSourceManager:
    """This class will manage all the authentication sources.

    It holds the login form and the list of enabled authentication source
    identifiers. It tries those sources in order and records which one
    succeeded. ``as_dict`` exposes that state for storage in the session.
    """

    def __init__(self, form, sources):
        """
        Args:
            form: the submitted login form (``_login`` passes ``{}`` for
                OAuth2).
            sources: list of authentication source identifiers. It is
                modified in place by ``update_auth_sources``, so callers
                should pass a copy.
        """
        self.form = form
        self.auth_sources = sources
        self.source = None
        self.source_friendly_name = INTERNAL
        self.current_source = INTERNAL
        self.update_auth_sources()

    def as_dict(self):
        """
        Returns the dictionary object representing this object.

        Keys: ``source_friendly_name``, ``auth_sources``,
        ``current_source``.
        """
        return {
            'source_friendly_name': self.source_friendly_name,
            'auth_sources': self.auth_sources,
            'current_source': self.current_source,
        }

    def update_auth_sources(self):
        """Prune ``auth_sources`` according to the login button used.

        For each of Kerberos and OAuth2 that is enabled:

        - if the request came from the internal login button, that source
          is removed;
        - otherwise the form-based INTERNAL and LDAP sources are removed.
        """
        for auth_src in [KERBEROS, OAUTH2]:
            if auth_src in self.auth_sources:
                if 'internal_button' in request.form:
                    self.auth_sources.remove(auth_src)
                else:
                    if INTERNAL in self.auth_sources:
                        self.auth_sources.remove(INTERNAL)
                    if LDAP in self.auth_sources:
                        self.auth_sources.remove(LDAP)

    def set_current_source(self, source):
        self.current_source = source

    @property
    def get_current_source(self):
        return self.current_source

    def set_source(self, source):
        self.source = source

    @property
    def get_source(self):
        return self.source

    def set_source_friendly_name(self, name):
        self.source_friendly_name = name

    @property
    def get_source_friendly_name(self):
        return self.source_friendly_name

    def validate(self):
        """Validate through all the sources.

        Returns True as soon as one source accepts the form. Otherwise
        flashes the error message from the last source tried (if it has
        one) and returns False.
        """
        err_msg = None
        for src in self.auth_sources:
            source = get_auth_sources(src)
            status, err_msg = source.validate(self.form)
            if status:
                return True
        if err_msg:
            flash(err_msg, MessageType.WARNING)
        return False

    def authenticate(self):
        """Authenticate through all the sources.

        Sources are tried in order, and the first success stops the search.
        ``self.source`` is left pointing at the last source tried.

        Returns:
            tuple: ``(status, msg)`` from the successful source, otherwise
            from the last source tried (``(False, None)`` when there are
            no sources).
        """
        status = False
        msg = None
        for src in self.auth_sources:
            source = get_auth_sources(src)
            self.set_source(source)
            current_app.logger.debug(
                "Authentication initiated via source: %s",
                source.get_source_name())

            status, msg = source.authenticate(self.form)

            if status:
                self.set_current_source(source.get_source_name())
                # A source may report the resolved username in a mapping.
                # A plain string msg must not be treated as one: ``in``
                # would be a substring test and indexing would fail.
                if msg is not None and not isinstance(msg, str) \
                        and 'username' in msg:
                    self.form._fields['email'].data = msg['username']
                return status, msg

            current_app.logger.debug(
                "Authentication initiated via source: %s is failed.",
                source.get_source_name())

        return status, msg

    def login(self):
        """Log the user in through ``self.source``.

        ``self.source`` is the source recorded by the last ``authenticate``
        call. On success, this records the source's friendly name and sets
        the app's Flask-Login login/logout views to the source's
        ``LOGIN_VIEW``/``LOGOUT_VIEW`` attributes when present.

        NOTE(review): the login/logout views are set on the shared
        ``current_app.login_manager``, not per user.

        Returns:
            tuple: ``(status, msg)`` as returned by the source's ``login``.
        """
        status, msg = self.source.login(self.form)
        if status:
            self.set_source_friendly_name(self.source.get_friendly_name())
            current_app.logger.debug(
                "Authentication and Login successfully done via source : %s",
                self.source.get_source_name())

            # Set the login, logout view as per source if available
            current_app.login_manager.login_view = getattr(
                self.source, 'LOGIN_VIEW', 'security.login')
            current_app.login_manager.logout_view = getattr(
                self.source, 'LOGOUT_VIEW', 'security.logout')

        return status, msg
|
|
|
|
|
|
def get_auth_sources(type):
    """Get the authenticated source object from the registry.

    Results are cached in the dict stored as
    ``current_app._pgadmin_auth_sources``. If that attribute is missing or
    is not a dict, a fresh cache is started. A cached entry is returned
    directly. Otherwise the source is fetched from ``AuthSourceRegistry``;
    a non-None result is cached (and the cache re-attached to the app)
    before being returned. ``None`` is returned uncached.
    """
    cache = getattr(current_app, '_pgadmin_auth_sources', None)
    if not isinstance(cache, dict):
        cache = {}

    if type in cache:
        return cache[type]

    found = AuthSourceRegistry.get(type)
    if found is None:
        return None

    cache[type] = found
    current_app._pgadmin_auth_sources = cache
    return found
|
|
|
|
|
|
def init_app(app):
    """Attach an empty auth-source cache to *app* and load source modules.

    The cache dict is stored as ``app._pgadmin_auth_sources`` before
    ``AuthSourceRegistry.load_modules(app)`` is called, and that same dict
    is returned.
    """
    source_cache = {}
    app._pgadmin_auth_sources = source_cache
    AuthSourceRegistry.load_modules(app)
    return source_cache
|