Refactor exc_callback invocation.

Replace _call_exc_callbacks with a function wrapper, which will automatically
call exception callbacks when an exception is raised from the function. This
removes the need to specify the function and its arguments twice (once in the
function call itself and once in _call_exc_callbacks).

Add some extra checks to existing exception callbacks.
This commit is contained in:
Jan Cholasta 2012-04-19 08:06:32 -04:00 committed by Martin Kosek
parent 856b9627be
commit 3ba9cc8eb4
6 changed files with 179 additions and 182 deletions

View File

@ -744,26 +744,28 @@ class CallbackInterface(Method):
else:
klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback)
def _call_exc_callbacks(self, args, options, exc, call_func, *call_args, **call_kwargs):
rv = None
for i in xrange(len(getattr(self, 'EXC_CALLBACKS', []))):
callback = self.EXC_CALLBACKS[i]
try:
if hasattr(callback, 'im_self'):
rv = callback(
args, options, exc, call_func, *call_args, **call_kwargs
)
else:
rv = callback(
self, args, options, exc, call_func, *call_args,
**call_kwargs
)
except errors.ExecutionError, e:
if (i + 1) < len(self.EXC_CALLBACKS):
exc = e
continue
raise e
return rv
def _exc_wrapper(self, keys, options, call_func):
"""Function wrapper that automatically calls exception callbacks"""
def wrapped(*call_args, **call_kwargs):
# call call_func first
func = call_func
callbacks = list(getattr(self, 'EXC_CALLBACKS', []))
while True:
try:
return func(*call_args, **call_kwargs)
except errors.ExecutionError, e:
if not callbacks:
raise
# call exc_callback in the next loop
callback = callbacks.pop(0)
if hasattr(callback, 'im_self'):
def exc_func(*args, **kwargs):
return callback(keys, options, e, call_func, *args, **kwargs)
else:
def exc_func(*args, **kwargs):
return callback(self, keys, options, e, call_func, *args, **kwargs)
func = exc_func
return wrapped
class BaseLDAPCommand(CallbackInterface, Command):
@ -883,17 +885,11 @@ last, after all sets and adds."""),
if needldapattrs:
try:
(dn, old_entry) = ldap.get_entry(
(dn, old_entry) = self._exc_wrapper(keys, options, ldap.get_entry)(
dn, needldapattrs, normalize=self.obj.normalize_dn
)
except errors.ExecutionError, e:
try:
(dn, old_entry) = self._call_exc_callbacks(
keys, options, e, ldap.get_entry, dn, [],
normalize=self.obj.normalize_dn
)
except errors.NotFound:
self.obj.handle_not_found(*keys)
except errors.NotFound:
self.obj.handle_not_found(*keys)
for attr in needldapattrs:
entry_attrs[attr] = old_entry.get(attr, [])
@ -1019,29 +1015,23 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
_check_limit_object_class(self.api.Backend.ldap2.schema.attribute_types(self.obj.disallow_object_classes), entry_attrs.keys(), allow_only=False)
try:
ldap.add_entry(dn, entry_attrs, normalize=self.obj.normalize_dn)
except errors.ExecutionError, e:
try:
self._call_exc_callbacks(
keys, options, e, ldap.add_entry, dn, entry_attrs,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
parent = self.obj.parent_object
if parent:
raise errors.NotFound(
reason=self.obj.parent_not_found_msg % {
'parent': keys[-2],
'oname': self.api.Object[parent].object_name,
}
)
self._exc_wrapper(keys, options, ldap.add_entry)(dn, entry_attrs, normalize=self.obj.normalize_dn)
except errors.NotFound:
parent = self.obj.parent_object
if parent:
raise errors.NotFound(
reason=self.obj.container_not_found_msg % {
'container': self.obj.container_dn,
reason=self.obj.parent_not_found_msg % {
'parent': keys[-2],
'oname': self.api.Object[parent].object_name,
}
)
except errors.DuplicateEntry:
self.obj.handle_duplicate_entry(*keys)
raise errors.NotFound(
reason=self.obj.container_not_found_msg % {
'container': self.obj.container_dn,
}
)
except errors.DuplicateEntry:
self.obj.handle_duplicate_entry(*keys)
try:
if self.obj.rdn_attribute:
@ -1050,22 +1040,16 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
object_class = self.obj.object_class
else:
object_class = None
(dn, entry_attrs) = ldap.find_entry_by_attr(
(dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.find_entry_by_attr)(
self.obj.primary_key.name, keys[-1], object_class, attrs_list,
self.obj.container_dn
)
else:
(dn, entry_attrs) = ldap.get_entry(
(dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
dn, attrs_list, normalize=self.obj.normalize_dn
)
except errors.ExecutionError, e:
try:
(dn, entry_attrs) = self._call_exc_callbacks(
keys, options, e, ldap.get_entry, dn, attrs_list,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
self.obj.handle_not_found(*keys)
except errors.NotFound:
self.obj.handle_not_found(*keys)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
@ -1181,17 +1165,11 @@ class LDAPRetrieve(LDAPQuery):
dn = callback(self, ldap, dn, attrs_list, *keys, **options)
try:
(dn, entry_attrs) = ldap.get_entry(
(dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
dn, attrs_list, normalize=self.obj.normalize_dn
)
except errors.ExecutionError, e:
try:
(dn, entry_attrs) = self._call_exc_callbacks(
keys, options, e, ldap.get_entry, dn, attrs_list,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
self.obj.handle_not_found(*keys)
except errors.NotFound:
self.obj.handle_not_found(*keys)
if options.get('rights', False) and options.get('all', False):
entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
@ -1297,7 +1275,7 @@ class LDAPUpdate(LDAPQuery, crud.Update):
if self.obj.rdn_is_primary_key and self.obj.primary_key.name in entry_attrs:
# RDN change
ldap.update_entry_rdn(dn,
self._exc_wrapper(keys, options, ldap.update_entry_rdn)(dn,
unicode('%s=%s' % (self.obj.primary_key.name,
entry_attrs[self.obj.primary_key.name])))
rdnkeys = keys[:-1] + (entry_attrs[self.obj.primary_key.name], )
@ -1306,37 +1284,25 @@ class LDAPUpdate(LDAPQuery, crud.Update):
options['rdnupdate'] = True
rdnupdate = True
ldap.update_entry(dn, entry_attrs, normalize=self.obj.normalize_dn)
except errors.ExecutionError, e:
# Exception callbacks will need to test for options['rdnupdate']
# to decide what to do. An EmptyModlist in this context doesn't
# mean an error occurred, just that there were no other updates to
# perform.
try:
self._call_exc_callbacks(
keys, options, e, ldap.update_entry, dn, entry_attrs,
normalize=self.obj.normalize_dn
)
except errors.EmptyModlist, e:
if not rdnupdate:
raise e
except errors.NotFound:
self.obj.handle_not_found(*keys)
self._exc_wrapper(keys, options, ldap.update_entry)(dn, entry_attrs, normalize=self.obj.normalize_dn)
except errors.EmptyModlist, e:
if not rdnupdate:
raise e
except errors.NotFound:
self.obj.handle_not_found(*keys)
try:
(dn, entry_attrs) = ldap.get_entry(
(dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
dn, attrs_list, normalize=self.obj.normalize_dn
)
except errors.ExecutionError, e:
try:
(dn, entry_attrs) = self._call_exc_callbacks(
keys, options, e, ldap.get_entry, dn, attrs_list,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
raise errors.MidairCollision(
format=_('the entry was deleted while being modified')
)
except errors.NotFound:
raise errors.MidairCollision(
format=_('the entry was deleted while being modified')
)
if options.get('rights', False) and options.get('all', False):
entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
@ -1399,15 +1365,9 @@ class LDAPDelete(LDAPMultiQuery):
for (dn_, entry_attrs) in subentries:
delete_subtree(dn_)
try:
ldap.delete_entry(base_dn, normalize=self.obj.normalize_dn)
except errors.ExecutionError, e:
try:
self._call_exc_callbacks(
nkeys, options, e, ldap.delete_entry, base_dn,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
self.obj.handle_not_found(*nkeys)
self._exc_wrapper(nkeys, options, ldap.delete_entry)(base_dn, normalize=self.obj.normalize_dn)
except errors.NotFound:
self.obj.handle_not_found(*nkeys)
delete_subtree(dn)
@ -1560,17 +1520,11 @@ class LDAPAddMember(LDAPModMember):
)
try:
(dn, entry_attrs) = ldap.get_entry(
(dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
dn, attrs_list, normalize=self.obj.normalize_dn
)
except errors.ExecutionError, e:
try:
(dn, entry_attrs) = self._call_exc_callbacks(
keys, options, e, ldap.get_entry, dn, attrs_list,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
self.obj.handle_not_found(*keys)
except errors.NotFound:
self.obj.handle_not_found(*keys)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
@ -1668,17 +1622,11 @@ class LDAPRemoveMember(LDAPModMember):
time.sleep(.3)
try:
(dn, entry_attrs) = ldap.get_entry(
(dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
dn, attrs_list, normalize=self.obj.normalize_dn
)
except errors.ExecutionError, e:
try:
(dn, entry_attrs) = self._call_exc_callbacks(
keys, options, e, ldap.get_entry, dn, attrs_list,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
self.obj.handle_not_found(*keys)
except errors.NotFound:
self.obj.handle_not_found(*keys)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
@ -1884,20 +1832,13 @@ class LDAPSearch(BaseLDAPCommand, crud.Search):
)
try:
(entries, truncated) = ldap.find_entries(
(entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)(
filter, attrs_list, base_dn, scope,
time_limit=options.get('timelimit', None),
size_limit=options.get('sizelimit', None)
)
except errors.ExecutionError, e:
try:
(entries, truncated) = self._call_exc_callbacks(
args, options, e, ldap.find_entries, filter, attrs_list,
base_dn, scope=ldap.SCOPE_ONELEVEL,
normalize=self.obj.normalize_dn
)
except errors.NotFound:
(entries, truncated) = ([], False)
except errors.NotFound:
(entries, truncated) = ([], False)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
@ -2030,21 +1971,15 @@ class LDAPAddReverseMember(LDAPModReverseMember):
try:
options = {'%s' % self.member_attr: keys[-1]}
try:
result = self.api.Command[self.member_command](attr, **options)
result = self._exc_wrapper(keys, options, self.api.Command[self.member_command])(attr, **options)
if result['completed'] == 1:
completed = completed + 1
else:
failed['member'][self.reverse_attr].append((attr, result['failed']['member'][self.member_attr][0][1]))
except errors.ExecutionError, e:
try:
(dn, entry_attrs) = self._call_exc_callbacks(
keys, options, e, self.member_command, dn, attrs_list,
normalize=self.obj.normalize_dn
)
except errors.NotFound, e:
msg = str(e)
(attr, msg) = msg.split(':', 1)
failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
except errors.NotFound, e:
msg = str(e)
(attr, msg) = msg.split(':', 1)
failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
except errors.PublicError, e:
failed['member'][self.reverse_attr].append((attr, unicode(msg)))
@ -2143,21 +2078,15 @@ class LDAPRemoveReverseMember(LDAPModReverseMember):
try:
options = {'%s' % self.member_attr: keys[-1]}
try:
result = self.api.Command[self.member_command](attr, **options)
result = self._exc_wrapper(keys, options, self.api.Command[self.member_command])(attr, **options)
if result['completed'] == 1:
completed = completed + 1
else:
failed['member'][self.reverse_attr].append((attr, result['failed']['member'][self.member_attr][0][1]))
except errors.ExecutionError, e:
try:
(dn, entry_attrs) = self._call_exc_callbacks(
keys, options, e, self.member_command, dn, attrs_list,
normalize=self.obj.normalize_dn
)
except errors.NotFound, e:
msg = str(e)
(attr, msg) = msg.split(':', 1)
failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
except errors.NotFound, e:
msg = str(e)
(attr, msg) = msg.split(':', 1)
failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
except errors.PublicError, e:
failed['member'][self.reverse_attr].append((attr, unicode(msg)))

View File

@ -642,12 +642,12 @@ class entitle_import(LDAPUpdate):
If we are adding the first entry there are no updates so EmptyModlist
will get thrown. Ignore it.
"""
if isinstance(exc, errors.EmptyModlist):
if not getattr(context, 'entitle_import', False):
raise exc
return (call_args, {})
else:
raise exc
if call_func.func_name == 'update_entry':
if isinstance(exc, errors.EmptyModlist):
if not getattr(context, 'entitle_import', False):
raise exc
return (call_args, {})
raise exc
def execute(self, *keys, **options):
super(entitle_import, self).execute(*keys, **options)
@ -729,9 +729,10 @@ class entitle_sync(LDAPUpdate):
return dn
def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
if isinstance(exc, errors.EmptyModlist):
# If there is nothing to change we are already synchronized.
return
if call_func.func_name == 'update_entry':
if isinstance(exc, errors.EmptyModlist):
# If there is nothing to change we are already synchronized.
return
raise exc
api.register(entitle_sync)

View File

@ -211,9 +211,10 @@ class group_mod(LDAPUpdate):
def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
# Check again for GID requirement in case someone tried to clear it
# using --setattr.
if isinstance(exc, errors.ObjectclassViolation):
if 'gidNumber' in exc.message and 'posixGroup' in exc.message:
raise errors.RequirementError(name='gid')
if call_func.func_name == 'update_entry':
if isinstance(exc, errors.ObjectclassViolation):
if 'gidNumber' in exc.message and 'posixGroup' in exc.message:
raise errors.RequirementError(name='gid')
raise exc
api.register(group_mod)

View File

@ -374,20 +374,19 @@ class permission_mod(LDAPUpdate):
return dn
def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
if isinstance(exc, errors.EmptyModlist):
aciupdate = getattr(context, 'aciupdate')
opts = copy.copy(options)
# Clear the aci attributes out of the permission entry
for o in self.obj.aci_attributes + ['all', 'raw', 'rights']:
try:
del opts[o]
except:
pass
if len(opts) > 0 and not aciupdate:
raise exc
else:
raise exc
if call_func.func_name == 'update_entry':
if isinstance(exc, errors.EmptyModlist):
aciupdate = getattr(context, 'aciupdate')
opts = copy.copy(options)
# Clear the aci attributes out of the permission entry
for o in self.obj.aci_attributes + ['all', 'raw', 'rights']:
try:
del opts[o]
except:
pass
if len(opts) == 0 or aciupdate:
return
raise exc
def post_callback(self, ldap, dn, entry_attrs, *keys, **options):
# rename the underlying ACI after the change to permission

View File

@ -414,11 +414,12 @@ class pwpolicy_mod(LDAPUpdate):
return dn
def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
if isinstance(exc, errors.EmptyModlist):
entry_attrs = call_args[1]
cosupdate = getattr(context, 'cosupdate')
if not entry_attrs or cosupdate:
return
if call_func.func_name == 'update_entry':
if isinstance(exc, errors.EmptyModlist):
entry_attrs = call_args[1]
cosupdate = getattr(context, 'cosupdate')
if not entry_attrs or cosupdate:
return
raise exc
api.register(pwpolicy_mod)

View File

@ -0,0 +1,66 @@
# Authors:
# Petr Viktorin <pviktori@redhat.com>
#
# Copyright (C) 2012 Red Hat
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
Test the `ipalib.plugins.baseldap` module.
"""
from ipalib import errors
from ipalib.plugins import baseldap
def test_exc_wrapper():
"""Test the CallbackInterface._exc_wrapper helper method"""
handled_exceptions = []
class test_callback(baseldap.CallbackInterface):
"""Fake IPA method"""
def test_fail(self):
self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
def fail(self, *args, **kwargs):
assert args == (1, 2)
assert kwargs == dict(a=1, b=2)
raise errors.ExecutionError('failure')
instance = test_callback()
# Test with one callback first
@test_callback.register_exc_callback
def handle_exception(self, keys, options, e, call_func, *args, **kwargs):
assert args == (1, 2)
assert kwargs == dict(a=1, b=2)
handled_exceptions.append(type(e))
instance.test_fail()
assert handled_exceptions == [errors.ExecutionError]
# Test with another callback added
handled_exceptions = []
def dont_handle(self, keys, options, e, call_func, *args, **kwargs):
assert args == (1, 2)
assert kwargs == dict(a=1, b=2)
handled_exceptions.append(None)
raise e
test_callback.register_exc_callback(dont_handle, first=True)
instance.test_fail()
assert handled_exceptions == [None, errors.ExecutionError]