Turn the LDAPError handler into a context manager

This has the advantage that the traceback is left intact if an error
other than LDAPError is raised.

Part of the work for: https://fedorahosted.org/freeipa/ticket/2660
This commit is contained in:
Petr Viktorin 2013-01-21 06:35:38 -05:00 committed by Martin Kosek
parent d11c337541
commit 1960945e28
2 changed files with 45 additions and 72 deletions

View File

@ -27,6 +27,7 @@ import time
import shutil import shutil
from decimal import Decimal from decimal import Decimal
from copy import deepcopy from copy import deepcopy
import contextlib
import ldap import ldap
import ldap as _ldap import ldap as _ldap
@ -776,24 +777,23 @@ class LDAPConnection(object):
""" """
return None return None
def handle_errors(self, e, arg_desc=None): @contextlib.contextmanager
"""Universal LDAPError handler def error_handler(self, arg_desc=None):
"""Context manager that handles LDAPErrors
:param e: The error to be raised
:param url: The URL of the server
""" """
if not isinstance(e, _ldap.TIMEOUT):
desc = e.args[0]['desc'].strip()
info = e.args[0].get('info', '').strip()
if arg_desc is not None:
info = "%s arguments: %s" % (info, arg_desc)
else:
desc = ''
info = ''
try: try:
# re-raise the error so we can handle it try:
raise e yield
except _ldap.TIMEOUT:
desc = ''
info = ''
raise
except ldap.LDAPError, e:
desc = e.args[0]['desc'].strip()
info = e.args[0].get('info', '').strip()
if arg_desc is not None:
info = "%s arguments: %s" % (info, arg_desc)
raise
except _ldap.NO_SUCH_OBJECT: except _ldap.NO_SUCH_OBJECT:
raise errors.NotFound(reason=arg_desc or 'no such entry') raise errors.NotFound(reason=arg_desc or 'no such entry')
except _ldap.ALREADY_EXISTS: except _ldap.ALREADY_EXISTS:
@ -1103,24 +1103,23 @@ class LDAPConnection(object):
attrs_list = list(set(attrs_list)) attrs_list = list(set(attrs_list))
# pass arguments to python-ldap # pass arguments to python-ldap
try: with self.error_handler():
id = self.conn.search_ext( try:
base_dn, scope, filter, attrs_list, timeout=time_limit, id = self.conn.search_ext(
sizelimit=size_limit base_dn, scope, filter, attrs_list, timeout=time_limit,
) sizelimit=size_limit
while True: )
(objtype, res_list) = self.conn.result(id, 0) while True:
if not res_list: (objtype, res_list) = self.conn.result(id, 0)
break if not res_list:
if (objtype == _ldap.RES_SEARCH_ENTRY or break
(search_refs and if (objtype == _ldap.RES_SEARCH_ENTRY or
objtype == _ldap.RES_SEARCH_REFERENCE)): (search_refs and
res.append(res_list[0]) objtype == _ldap.RES_SEARCH_REFERENCE)):
except (_ldap.ADMINLIMIT_EXCEEDED, _ldap.TIMELIMIT_EXCEEDED, res.append(res_list[0])
_ldap.SIZELIMIT_EXCEEDED), e: except (_ldap.ADMINLIMIT_EXCEEDED, _ldap.TIMELIMIT_EXCEEDED,
truncated = True _ldap.SIZELIMIT_EXCEEDED), e:
except _ldap.LDAPError, e: truncated = True
self.handle_errors(e)
if not res and not truncated: if not res and not truncated:
raise errors.NotFound(reason='no such entry') raise errors.NotFound(reason='no such entry')
@ -1396,10 +1395,8 @@ class LDAPConnection(object):
# be just "if v": # be just "if v":
if v is not None and v != []) if v is not None and v != [])
try: with self.error_handler():
self.conn.add_s(dn, list(attrs.iteritems())) self.conn.add_s(dn, attrs.items())
except _ldap.LDAPError, e:
self.handle_errors(e)
def update_entry_rdn(self, dn, new_rdn, del_old=True): def update_entry_rdn(self, dn, new_rdn, del_old=True):
""" """
@ -1415,11 +1412,9 @@ class LDAPConnection(object):
dn = self.normalize_dn(dn) dn = self.normalize_dn(dn)
if dn[0] == new_rdn: if dn[0] == new_rdn:
raise errors.EmptyModlist() raise errors.EmptyModlist()
try: with self.error_handler():
self.conn.rename_s(dn, new_rdn, delold=int(del_old)) self.conn.rename_s(dn, new_rdn, delold=int(del_old))
time.sleep(.3) # Give memberOf plugin a chance to work time.sleep(.3) # Give memberOf plugin a chance to work
except _ldap.LDAPError, e:
self.handle_errors(e)
def _generate_modlist(self, dn, entry_attrs, normalize): def _generate_modlist(self, dn, entry_attrs, normalize):
assert isinstance(dn, DN) assert isinstance(dn, DN)
@ -1500,10 +1495,8 @@ class LDAPConnection(object):
raise errors.EmptyModlist() raise errors.EmptyModlist()
# pass arguments to python-ldap # pass arguments to python-ldap
try: with self.error_handler():
self.conn.modify_s(dn, modlist) self.conn.modify_s(dn, modlist)
except _ldap.LDAPError, e:
self.handle_errors(e)
def delete_entry(self, entry_or_dn, normalize=None): def delete_entry(self, entry_or_dn, normalize=None):
"""Delete an entry given either the DN or the entry itself""" """Delete an entry given either the DN or the entry itself"""
@ -1515,10 +1508,8 @@ class LDAPConnection(object):
assert normalize is None assert normalize is None
dn = entry_or_dn.dn dn = entry_or_dn.dn
try: with self.error_handler():
self.conn.delete_s(dn) self.conn.delete_s(dn)
except _ldap.LDAPError, e:
self.handle_errors(e)
class IPAdmin(LDAPConnection): class IPAdmin(LDAPConnection):
@ -1581,21 +1572,16 @@ class IPAdmin(LDAPConnection):
This is executed after the connection is bound to fill in some useful This is executed after the connection is bound to fill in some useful
values. values.
""" """
try: with self.error_handler():
ent = self.getEntry(DN(('cn', 'config'), ('cn', 'ldbm database'), ('cn', 'plugins'), ('cn', 'config')), ent = self.getEntry(DN(('cn', 'config'), ('cn', 'ldbm database'), ('cn', 'plugins'), ('cn', 'config')),
ldap.SCOPE_BASE, '(objectclass=*)', ldap.SCOPE_BASE, '(objectclass=*)',
[ 'nsslapd-directory' ]) [ 'nsslapd-directory' ])
self.dbdir = os.path.dirname(ent.getValue('nsslapd-directory')) self.dbdir = os.path.dirname(ent.getValue('nsslapd-directory'))
except ldap.LDAPError, e:
self.__handle_errors(e)
def __str__(self): def __str__(self):
return self.host + ":" + str(self.port) return self.host + ":" + str(self.port)
def __handle_errors(self, e, **kw):
return self.handle_errors(e, **kw)
def __wait_for_connection(self, timeout): def __wait_for_connection(self, timeout):
lurl = ldapurl.LDAPUrl(self.ldap_uri) lurl = ldapurl.LDAPUrl(self.ldap_uri)
if lurl.urlscheme == 'ldapi': if lurl.urlscheme == 'ldapi':
@ -1671,10 +1657,8 @@ class IPAdmin(LDAPConnection):
if len(modlist) == 0: if len(modlist) == 0:
raise errors.EmptyModlist raise errors.EmptyModlist
try: with self.error_handler():
self.modify_s(dn, modlist) self.modify_s(dn, modlist)
except ldap.LDAPError, e:
self.__handle_errors(e)
return True return True
def generateModList(self, old_entry, new_entry): def generateModList(self, old_entry, new_entry):
@ -1752,10 +1736,8 @@ class IPAdmin(LDAPConnection):
modlist.append((operation, "nsAccountlock", "TRUE")) modlist.append((operation, "nsAccountlock", "TRUE"))
try: with self.error_handler():
self.modify_s(dn, modlist) self.modify_s(dn, modlist)
except ldap.LDAPError, e:
self.__handle_errors(e)
return True return True
def deleteEntry(self, dn): def deleteEntry(self, dn):

View File

@ -128,7 +128,7 @@ class ldap2(LDAPConnection, CrudBackend):
if debug_level: if debug_level:
_ldap.set_option(_ldap.OPT_DEBUG_LEVEL, debug_level) _ldap.set_option(_ldap.OPT_DEBUG_LEVEL, debug_level)
try: with self.error_handler():
force_updates = api.env.context in ('installer', 'updates') force_updates = api.env.context in ('installer', 'updates')
conn = IPASimpleLDAPObject( conn = IPASimpleLDAPObject(
self.ldap_uri, force_schema_updates=force_updates) self.ldap_uri, force_schema_updates=force_updates)
@ -167,9 +167,6 @@ class ldap2(LDAPConnection, CrudBackend):
else: else:
conn.simple_bind_s(bind_dn, bind_pw) conn.simple_bind_s(bind_dn, bind_pw)
except _ldap.LDAPError, e:
self.handle_errors(e)
return conn return conn
def destroy_connection(self): def destroy_connection(self):
@ -346,18 +343,14 @@ class ldap2(LDAPConnection, CrudBackend):
# The python-ldap passwd command doesn't verify the old password # The python-ldap passwd command doesn't verify the old password
# so we'll do a simple bind to validate it. # so we'll do a simple bind to validate it.
if old_pass != '': if old_pass != '':
try: with self.error_handler():
conn = IPASimpleLDAPObject( conn = IPASimpleLDAPObject(
self.ldap_uri, force_schema_updates=False) self.ldap_uri, force_schema_updates=False)
conn.simple_bind_s(dn, old_pass) conn.simple_bind_s(dn, old_pass)
conn.unbind() conn.unbind()
except _ldap.LDAPError, e:
self.handle_errors(e)
try: with self.error_handler():
self.conn.passwd_s(dn, old_pass, new_pass) self.conn.passwd_s(dn, old_pass, new_pass)
except _ldap.LDAPError, e:
self.handle_errors(e)
def add_entry_to_group(self, dn, group_dn, member_attr='member', allow_same=False): def add_entry_to_group(self, dn, group_dn, member_attr='member', allow_same=False):
""" """
@ -473,10 +466,8 @@ class ldap2(LDAPConnection, CrudBackend):
mod = [(_ldap.MOD_REPLACE, 'krbprincipalkey', None), mod = [(_ldap.MOD_REPLACE, 'krbprincipalkey', None),
(_ldap.MOD_REPLACE, 'krblastpwdchange', None)] (_ldap.MOD_REPLACE, 'krblastpwdchange', None)]
try: with self.error_handler():
self.conn.modify_s(dn, mod) self.conn.modify_s(dn, mod)
except _ldap.LDAPError, e:
self.handle_errors(e)
# CrudBackend methods # CrudBackend methods