Rework the CallbackInterface

Fix several problems with the callback interface:
- Automatically registered callbacks (i.e. methods named
    exc_callback, pre_callback etc) were registered on every
    instantiation.
    Fix: Do not register callbacks in __init__; instead return the
    method when asked for it.
- The calling code had to distinguish between bound methods and
    plain functions by checking the 'im_self' attribute.
    Fix: Always return the "default" callback as an unbound method.
    Registered callbacks now always take the extra `self` argument,
    whether they happen to be bound methods or not.
    Calling code now always needs to pass the `self` argument.
- Did not work well with inheritance: due to the fact that Python
    looks up missing attributes in superclasses, callbacks could
    get attached to a superclass if it was instantiated early enough. *
    Fix: Instead of attribute lookup, use a dictionary with class keys.
- The interface included the callback types, which are LDAP-specific.
    Fix: Create generic register_callback and get_callback mehods,
    move LDAP-specific code to BaseLDAPCommand

Update code that calls the callbacks.
Add tests.
Remove lint exceptions for CallbackInterface.

* https://fedorahosted.org/freeipa/ticket/2674
This commit is contained in:
Petr Viktorin 2012-04-25 10:31:10 -04:00 committed by Martin Kosek
parent f52fa2a018
commit 9960149e3f
4 changed files with 235 additions and 217 deletions

View File

@ -1195,8 +1195,13 @@ class cli(backend.Executioner):
param.label, param.confirm
)
for callback in getattr(cmd, 'INTERACTIVE_PROMPT_CALLBACKS', []):
callback(kw)
try:
callbacks = cmd.get_callbacks('interactive_prompt')
except AttributeError:
pass
else:
for callback in callbacks:
callback(cmd, kw)
def load_files(self, cmd, kw):
"""

View File

@ -690,93 +690,57 @@ def _check_limit_object_class(attributes, attrs, allow_only):
if len(limitattrs) > 0 and allow_only:
raise errors.ObjectclassViolation(info='attribute "%(attribute)s" not allowed' % dict(attribute=limitattrs[0]))
class CallbackInterface(Method):
"""Callback registration interface
This class's subclasses allow different types of callbacks to be added and
removed to them.
Registering a callback is done either by ``register_callback``, or by
defining a ``<type>_callback`` method.
Subclasses should define the `_callback_registry` attribute as a dictionary
mapping allowed callback types to (initially) empty dictionaries.
"""
Callback registration interface
"""
def __init__(self):
#pylint: disable=E1003
if not hasattr(self.__class__, 'PRE_CALLBACKS'):
self.__class__.PRE_CALLBACKS = []
if not hasattr(self.__class__, 'POST_CALLBACKS'):
self.__class__.POST_CALLBACKS = []
if not hasattr(self.__class__, 'EXC_CALLBACKS'):
self.__class__.EXC_CALLBACKS = []
if not hasattr(self.__class__, 'INTERACTIVE_PROMPT_CALLBACKS'):
self.__class__.INTERACTIVE_PROMPT_CALLBACKS = []
if hasattr(self, 'pre_callback'):
self.register_pre_callback(self.pre_callback, True)
if hasattr(self, 'post_callback'):
self.register_post_callback(self.post_callback, True)
if hasattr(self, 'exc_callback'):
self.register_exc_callback(self.exc_callback, True)
if hasattr(self, 'interactive_prompt_callback'):
self.register_interactive_prompt_callback(
self.interactive_prompt_callback, True) #pylint: disable=E1101
super(Method, self).__init__()
_callback_registry = dict()
@classmethod
def register_pre_callback(klass, callback, first=False):
assert callable(callback)
if not hasattr(klass, 'PRE_CALLBACKS'):
klass.PRE_CALLBACKS = []
if first:
klass.PRE_CALLBACKS.insert(0, callback)
else:
klass.PRE_CALLBACKS.append(callback)
@classmethod
def register_post_callback(klass, callback, first=False):
assert callable(callback)
if not hasattr(klass, 'POST_CALLBACKS'):
klass.POST_CALLBACKS = []
if first:
klass.POST_CALLBACKS.insert(0, callback)
else:
klass.POST_CALLBACKS.append(callback)
@classmethod
def register_exc_callback(klass, callback, first=False):
assert callable(callback)
if not hasattr(klass, 'EXC_CALLBACKS'):
klass.EXC_CALLBACKS = []
if first:
klass.EXC_CALLBACKS.insert(0, callback)
else:
klass.EXC_CALLBACKS.append(callback)
@classmethod
def register_interactive_prompt_callback(klass, callback, first=False):
assert callable(callback)
if not hasattr(klass, 'INTERACTIVE_PROMPT_CALLBACKS'):
klass.INTERACTIVE_PROMPT_CALLBACKS = []
if first:
klass.INTERACTIVE_PROMPT_CALLBACKS.insert(0, callback)
else:
klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback)
def _exc_wrapper(self, keys, options, call_func):
"""Function wrapper that automatically calls exception callbacks"""
def wrapped(*call_args, **call_kwargs):
# call call_func first
func = call_func
callbacks = list(getattr(self, 'EXC_CALLBACKS', []))
while True:
def get_callbacks(cls, callback_type):
"""Yield callbacks of the given type"""
# Use one shared callback registry, keyed on class, to avoid problems
# with missing attributes being looked up in superclasses
callbacks = cls._callback_registry[callback_type].get(cls, [None])
for callback in callbacks:
if callback is None:
try:
return func(*call_args, **call_kwargs)
except errors.ExecutionError, e:
if not callbacks:
raise
# call exc_callback in the next loop
callback = callbacks.pop(0)
if hasattr(callback, 'im_self'):
def exc_func(*args, **kwargs):
return callback(keys, options, e, call_func, *args, **kwargs)
yield getattr(cls, '%s_callback' % callback_type)
except AttributeError:
pass
else:
def exc_func(*args, **kwargs):
return callback(self, keys, options, e, call_func, *args, **kwargs)
func = exc_func
return wrapped
yield callback
@classmethod
def register_callback(cls, callback_type, callback, first=False):
"""Register a callback
:param callback_type: The callback type (e.g. 'pre', 'post')
:param callback: The callable added
:param first: If true, the new callback will be added before all
existing callbacks; otherwise it's added after them
Note that callbacks registered this way will be attached to this class
only, not to its subclasses.
"""
assert callable(callback)
try:
callbacks = cls._callback_registry[callback_type][cls]
except KeyError:
callbacks = cls._callback_registry[callback_type][cls] = [None]
if first:
callbacks.insert(0, callback)
else:
callbacks.append(callback)
class BaseLDAPCommand(CallbackInterface, Command):
@ -802,6 +766,8 @@ last, after all sets and adds."""),
exclude='webui',
)
_callback_registry = dict(pre={}, post={}, exc={}, interactive_prompt={})
def _convert_2_dict(self, attrs):
"""
Convert a string in the form of name/value pairs into a dictionary.
@ -961,6 +927,45 @@ last, after all sets and adds."""),
elif isinstance(entry_attrs[attr], (tuple, list)) and len(entry_attrs[attr]) == 1:
entry_attrs[attr] = entry_attrs[attr][0]
@classmethod
def register_pre_callback(cls, callback, first=False):
"""Shortcut for register_callback('pre', ...)"""
cls.register_callback('pre', callback, first)
@classmethod
def register_post_callback(cls, callback, first=False):
"""Shortcut for register_callback('post', ...)"""
cls.register_callback('post', callback, first)
@classmethod
def register_exc_callback(cls, callback, first=False):
"""Shortcut for register_callback('exc', ...)"""
cls.register_callback('exc', callback, first)
@classmethod
def register_interactive_prompt_callback(cls, callback, first=False):
"""Shortcut for register_callback('interactive_prompt', ...)"""
cls.register_callback('interactive_prompt', callback, first)
def _exc_wrapper(self, keys, options, call_func):
"""Function wrapper that automatically calls exception callbacks"""
def wrapped(*call_args, **call_kwargs):
# call call_func first
func = call_func
callbacks = list(self.get_callbacks('exc'))
while True:
try:
return func(*call_args, **call_kwargs)
except errors.ExecutionError, e:
if not callbacks:
raise
# call exc_callback in the next loop
callback = callbacks.pop(0)
def exc_func(*args, **kwargs):
return callback(
self, keys, options, e, call_func, *args, **kwargs)
func = exc_func
return wrapped
class LDAPCreate(BaseLDAPCommand, crud.Create):
"""
@ -1012,15 +1017,9 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
set(self.obj.default_attributes + entry_attrs.keys())
)
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
for callback in self.get_callbacks('pre'):
dn = callback(
ldap, dn, entry_attrs, attrs_list, *keys, **options
)
else:
dn = callback(
self, ldap, dn, entry_attrs, attrs_list, *keys, **options
)
self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
_check_single_value_attrs(self.params, entry_attrs)
ldap.get_schema()
@ -1064,10 +1063,7 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
except errors.NotFound:
self.obj.handle_not_found(*keys)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, entry_attrs, *keys, **options)
else:
for callback in self.get_callbacks('post'):
dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
entry_attrs['dn'] = dn
@ -1173,10 +1169,7 @@ class LDAPRetrieve(LDAPQuery):
else:
attrs_list = list(self.obj.default_attributes)
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, attrs_list, *keys, **options)
else:
for callback in self.get_callbacks('pre'):
dn = callback(self, ldap, dn, attrs_list, *keys, **options)
try:
@ -1189,10 +1182,7 @@ class LDAPRetrieve(LDAPQuery):
if options.get('rights', False) and options.get('all', False):
entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, entry_attrs, *keys, **options)
else:
for callback in self.get_callbacks('post'):
dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
self.obj.convert_attribute_members(entry_attrs, *keys, **options)
@ -1268,15 +1258,9 @@ class LDAPUpdate(LDAPQuery, crud.Update):
_check_single_value_attrs(self.params, entry_attrs)
_check_empty_attrs(self.obj.params, entry_attrs)
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
for callback in self.get_callbacks('pre'):
dn = callback(
ldap, dn, entry_attrs, attrs_list, *keys, **options
)
else:
dn = callback(
self, ldap, dn, entry_attrs, attrs_list, *keys, **options
)
self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
ldap.get_schema()
_check_limit_object_class(self.api.Backend.ldap2.schema.attribute_types(self.obj.limit_object_classes), entry_attrs.keys(), allow_only=True)
@ -1323,10 +1307,7 @@ class LDAPUpdate(LDAPQuery, crud.Update):
if options.get('rights', False) and options.get('all', False):
entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, entry_attrs, *keys, **options)
else:
for callback in self.get_callbacks('post'):
dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
self.obj.convert_attribute_members(entry_attrs, *keys, **options)
@ -1362,10 +1343,7 @@ class LDAPDelete(LDAPMultiQuery):
nkeys = keys[:-1] + (pkey, )
dn = self.obj.get_dn(*nkeys, **options)
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, *nkeys, **options)
else:
for callback in self.get_callbacks('pre'):
dn = callback(self, ldap, dn, *nkeys, **options)
def delete_subtree(base_dn):
@ -1387,10 +1365,7 @@ class LDAPDelete(LDAPMultiQuery):
delete_subtree(dn)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
result = callback(ldap, dn, *nkeys, **options)
else:
for callback in self.get_callbacks('post'):
result = callback(self, ldap, dn, *nkeys, **options)
return result
@ -1503,13 +1478,8 @@ class LDAPAddMember(LDAPModMember):
dn = self.obj.get_dn(*keys, **options)
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, member_dns, failed, *keys, **options)
else:
dn = callback(
self, ldap, dn, member_dns, failed, *keys, **options
)
for callback in self.get_callbacks('pre'):
dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
completed = 0
for (attr, objs) in member_dns.iteritems():
@ -1542,16 +1512,10 @@ class LDAPAddMember(LDAPModMember):
except errors.NotFound:
self.obj.handle_not_found(*keys)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
(completed, dn) = callback(
ldap, completed, failed, dn, entry_attrs, *keys, **options
)
else:
for callback in self.get_callbacks('post'):
(completed, dn) = callback(
self, ldap, completed, failed, dn, entry_attrs, *keys,
**options
)
**options)
entry_attrs['dn'] = dn
self.obj.convert_attribute_members(entry_attrs, *keys, **options)
@ -1602,13 +1566,8 @@ class LDAPRemoveMember(LDAPModMember):
dn = self.obj.get_dn(*keys, **options)
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, member_dns, failed, *keys, **options)
else:
dn = callback(
self, ldap, dn, member_dns, failed, *keys, **options
)
for callback in self.get_callbacks('pre'):
dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
completed = 0
for (attr, objs) in member_dns.iteritems():
@ -1644,16 +1603,10 @@ class LDAPRemoveMember(LDAPModMember):
except errors.NotFound:
self.obj.handle_not_found(*keys)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
(completed, dn) = callback(
ldap, completed, failed, dn, entry_attrs, *keys, **options
)
else:
for callback in self.get_callbacks('post'):
(completed, dn) = callback(
self, ldap, completed, failed, dn, entry_attrs, *keys,
**options
)
**options)
entry_attrs['dn'] = dn
@ -1838,15 +1791,9 @@ class LDAPSearch(BaseLDAPCommand, crud.Search):
)
scope = ldap.SCOPE_ONELEVEL
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
for callback in self.get_callbacks('pre'):
(filter, base_dn, scope) = callback(
ldap, filter, attrs_list, base_dn, scope, *args, **options
)
else:
(filter, base_dn, scope) = callback(
self, ldap, filter, attrs_list, base_dn, scope, *args, **options
)
self, ldap, filter, attrs_list, base_dn, scope, *args, **options)
try:
(entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)(
@ -1857,10 +1804,7 @@ class LDAPSearch(BaseLDAPCommand, crud.Search):
except errors.NotFound:
(entries, truncated) = ([], False)
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
truncated = callback(ldap, entries, truncated, *args, **options)
else:
for callback in self.get_callbacks('post'):
truncated = callback(self, ldap, entries, truncated, *args, **options)
if self.sort_result_entries:
@ -1965,13 +1909,8 @@ class LDAPAddReverseMember(LDAPModReverseMember):
result = self.api.Command[self.show_command](keys[-1])['result']
dn = result['dn']
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, *keys, **options)
else:
dn = callback(
self, ldap, dn, *keys, **options
)
for callback in self.get_callbacks('pre'):
dn = callback(self, ldap, dn, *keys, **options)
if options.get('all', False):
attrs_list = ['*'] + self.obj.default_attributes
@ -2006,16 +1945,10 @@ class LDAPAddReverseMember(LDAPModReverseMember):
except Exception, e:
raise errors.ReverseMemberError(verb=_('added'), exc=str(e))
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
(completed, dn) = callback(
ldap, completed, failed, dn, entry_attrs, *keys, **options
)
else:
for callback in self.get_callbacks('post'):
(completed, dn) = callback(
self, ldap, completed, failed, dn, entry_attrs, *keys,
**options
)
**options)
entry_attrs['dn'] = dn
return dict(
@ -2072,13 +2005,8 @@ class LDAPRemoveReverseMember(LDAPModReverseMember):
result = self.api.Command[self.show_command](keys[-1])['result']
dn = result['dn']
for callback in self.PRE_CALLBACKS:
if hasattr(callback, 'im_self'):
dn = callback(ldap, dn, *keys, **options)
else:
dn = callback(
self, ldap, dn, *keys, **options
)
for callback in self.get_callbacks('pre'):
dn = callback(self, ldap, dn, *keys, **options)
if options.get('all', False):
attrs_list = ['*'] + self.obj.default_attributes
@ -2113,16 +2041,10 @@ class LDAPRemoveReverseMember(LDAPModReverseMember):
except Exception, e:
raise errors.ReverseMemberError(verb=_('removed'), exc=str(e))
for callback in self.POST_CALLBACKS:
if hasattr(callback, 'im_self'):
(completed, dn) = callback(
ldap, completed, failed, dn, entry_attrs, *keys, **options
)
else:
for callback in self.get_callbacks('post'):
(completed, dn) = callback(
self, ldap, completed, failed, dn, entry_attrs, *keys,
**options
)
**options)
entry_attrs['dn'] = dn
return dict(

View File

@ -51,8 +51,6 @@ class IPATypeChecker(TypeChecker):
'ipalib.plugable.Plugin': ['Command', 'Object', 'Method', 'Property',
'Backend', 'env', 'debug', 'info', 'warning', 'error', 'critical',
'exception', 'context', 'log'],
'ipalib.plugins.baseldap.CallbackInterface': ['pre_callback',
'post_callback', 'exc_callback'],
'ipalib.plugins.misc.env': ['env'],
'ipalib.parameters.Param': ['cli_name', 'cli_short_name', 'label',
'doc', 'required', 'multivalue', 'primary_key', 'normalizer',

View File

@ -24,11 +24,12 @@ Test the `ipalib.plugins.baseldap` module.
from ipalib import errors
from ipalib.plugins import baseldap
def test_exc_wrapper():
"""Test the CallbackInterface._exc_wrapper helper method"""
handled_exceptions = []
class test_callback(baseldap.CallbackInterface):
class test_callback(baseldap.BaseLDAPCommand):
"""Fake IPA method"""
def test_fail(self):
self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
@ -64,3 +65,95 @@ def test_exc_wrapper():
instance.test_fail()
assert handled_exceptions == [None, errors.ExecutionError]
def test_callback_registration():
class callbacktest_base(baseldap.CallbackInterface):
_callback_registry = dict(test={})
def test_callback(self, param):
messages.append(('Base test_callback', param))
def registered_callback(self, param):
messages.append(('Base registered callback', param))
callbacktest_base.register_callback('test', registered_callback)
class SomeClass(object):
def registered_callback(self, command, param):
messages.append(('Registered callback from another class', param))
callbacktest_base.register_callback('test', SomeClass().registered_callback)
class callbacktest_subclass(callbacktest_base):
pass
def subclass_callback(self, param):
messages.append(('Subclass registered callback', param))
callbacktest_subclass.register_callback('test', subclass_callback)
messages = []
instance = callbacktest_base()
for callback in instance.get_callbacks('test'):
callback(instance, 42)
assert messages == [
('Base test_callback', 42),
('Base registered callback', 42),
('Registered callback from another class', 42)]
messages = []
instance = callbacktest_subclass()
for callback in instance.get_callbacks('test'):
callback(instance, 42)
assert messages == [
('Base test_callback', 42),
('Subclass registered callback', 42)]
def test_exc_callback_registration():
messages = []
class callbacktest_base(baseldap.BaseLDAPCommand):
"""A method superclass with an exception callback"""
def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
"""Let the world know we saw the error, but don't handle it"""
messages.append('Base exc_callback')
raise exc
def test_fail(self):
"""Raise a handled exception"""
try:
self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
except Exception:
pass
def fail(self, *args, **kwargs):
"""Raise an error"""
raise errors.ExecutionError('failure')
base_instance = callbacktest_base()
class callbacktest_subclass(callbacktest_base):
pass
@callbacktest_subclass.register_exc_callback
def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
"""Subclass's private exception callback"""
messages.append('Subclass registered callback')
raise exc
subclass_instance = callbacktest_subclass()
# Make sure exception in base class is only handled by the base class
base_instance.test_fail()
assert messages == ['Base exc_callback']
@callbacktest_base.register_exc_callback
def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
"""Callback on super class; doesn't affect the subclass"""
messages.append('Superclass registered callback')
raise exc
# Make sure exception in subclass is only handled by both
messages = []
subclass_instance.test_fail()
assert messages == ['Base exc_callback', 'Subclass registered callback']