Implement simple LDAP cache layer

Insert a class before LDAPClient to cache the return value
of get_entry() and certain exceptions (NotFound and
EmptyResult). The cache uses an OrderedDict for the cases
where a large cache might result an LRU model can be used.

The cache be enabled (default) or disabled using
ldap_cache=True/False.

This cache is per-request so is not expected to grow
particularly large except in the case of a large batch
command.

The key to the cache entry is the dn of the object
being requested.

Any write to or referencing a cached dn is evicted from
the cache.

The set of attributes is somewhat taken into consideration.
"*" does not always match everything being asked for by
a plugin so unless the requested set of attributes is a
direct subset of what is cached it will be re-fetched. Err
on the side of safety.

Despite this rather conserative approach to caching 29%
of queries are saved with ipatests/xmlrpc_tests/*

https://pagure.io/freeipa/issue/8798

Signed-off-by: Rob Crittenden <rcritten@redhat.com>
Reviewed-By: Rafael Guterres Jeffman <rjeffman@redhat.com>
This commit is contained in:
Rob Crittenden
2021-03-26 12:39:20 -04:00
parent 8365d5e734
commit a4675f6f50
3 changed files with 260 additions and 13 deletions

View File

@@ -31,6 +31,8 @@ import os
import pwd
import warnings
from collections import OrderedDict
from cryptography import x509 as crypto_x509
from cryptography.hazmat.primitives import serialization
@@ -47,7 +49,7 @@ from ipalib.constants import LDAP_GENERALIZED_TIME_FORMAT
# pylint: enable=ipa-forbidden-import
from ipaplatform.paths import paths
from ipapython.ipautil import format_netloc, CIDict
from ipapython.dn import DN
from ipapython.dn import DN, RDN
from ipapython.dnsutil import DNSName
from ipapython.kerberos import Principal
@@ -1698,6 +1700,7 @@ class LDAPClient:
modlist = entry.generate_modlist()
if not modlist:
raise errors.EmptyModlist()
logger.debug("update_entry modlist %s", modlist)
# pass arguments to python-ldap
with self.error_handler():
@@ -1749,3 +1752,237 @@ def get_ldap_uri(host='', port=389, cacert=None, ldapi=False, realm=None,
return 'ldap://%s' % format_netloc(host, port)
else:
raise ValueError('Protocol %r not supported' % protocol)
class CacheEntry:
def __init__(self, entry=None, attrs_list=None, exception=None,
get_effective_rights=False, all=False):
self.entry = entry
self.attrs_list = attrs_list
self.exception = exception
self.all = all
class LDAPCache(LDAPClient):
"""A very basic LRU Cache using an OrderedDict"""
def __init__(self, ldap_uri, start_tls=False, force_schema_updates=False,
no_schema=False, decode_attrs=True, cacert=None,
sasl_nocanon=True, enable_cache=True, cache_size=100):
self.cache = OrderedDict()
self._enable_cache = True # initialize to zero to satisfy pylint
object.__setattr__(self, '_cache_misses', 0)
object.__setattr__(self, '_cache_hits', 0)
object.__setattr__(self, '_enable_cache',
enable_cache and cache_size > 0)
object.__setattr__(self, '_cache_size', cache_size)
super(LDAPCache, self).__init__(
ldap_uri, start_tls, force_schema_updates, no_schema,
decode_attrs, cacert, sasl_nocanon
)
@property
def hit(self):
return self._cache_hits # pylint: disable=no-member
@property
def miss(self):
return self._cache_misses # pylint: disable=no-member
@property
def max_entries(self):
return self._cache_size # pylint: disable=no-member
def emit(self, msg, *args, **kwargs):
if self._enable_cache:
logger.debug(msg, *args, **kwargs)
def add_cache_entry(self, dn, attrs_list=None, get_all=False,
entry=None, exception=None):
# idnsname - caching prevents delete when mod value to None
# cospriority - in a Class of Service object, uncacheable
# TODO - usercertificate was banned at one point and I don't remember
# why...
BANNED_ATTRS = {'idnsname', 'cospriority'}
if not self._enable_cache:
return
self.remove_cache_entry(dn)
if (
DN('cn=config') in dn
or DN('cn=kerberos') in dn
or DN('o=ipaca') in dn
):
return
if exception:
self.emit("EXC: Caching exception %s", exception)
self.cache[dn] = CacheEntry(exception=exception)
else:
if not BANNED_ATTRS.intersection(attrs_list):
self.cache[dn] = CacheEntry(
entry=entry.copy(),
attrs_list=deepcopy(attrs_list),
all=get_all,
)
else:
return
self.cache.move_to_end(dn)
if len(self.cache) > self.max_entries:
(dn, entry) = self.cache.popitem(last=False)
self.emit("LRU: removed %s", dn)
def clear_cache(self):
self.cache_status('FINAL')
object.__setattr__(self, 'cache', OrderedDict())
object.__setattr__(self, '_cache_hits', 0)
object.__setattr__(self, '_cache_misses', 0)
def cache_status(self, type):
self.emit("%s: Hits %d Misses %d Size %d",
type, self.hit, self.miss, len(self.cache))
def remove_cache_entry(self, dn):
assert isinstance(dn, DN)
self.emit('DROP: %s', dn)
if dn in self.cache:
del self.cache[dn]
else:
self.emit('DROP: not in cache %s', dn)
# Begin LDAPClient methods
def add_entry(self, entry):
self.emit('add_entry')
self.remove_cache_entry(entry.dn)
super(LDAPCache, self).add_entry(entry)
def update_entry(self, entry):
self.emit('update_entry')
self.remove_cache_entry(entry.dn)
super(LDAPCache, self).update_entry(entry)
def delete_entry(self, entry_or_dn):
self.emit('delete_entry')
if isinstance(entry_or_dn, DN):
dn = entry_or_dn
else:
dn = entry_or_dn.dn
self.remove_cache_entry(dn)
super(LDAPCache, self).delete_entry(dn)
def move_entry(self, dn, new_dn, del_old=True):
self.emit('move_entry')
self.remove_cache_entry(dn)
self.remove_cache_entry(new_dn)
super(LDAPCache, self).move_entry(dn, new_dn, del_old)
def modify_s(self, dn, modlist):
self.emit('modify_s')
if not isinstance(dn, DN):
dn = DN(dn)
self.emit('modlist %s', modlist)
for (_op, attr, mod_dn) in modlist:
if attr.lower() in ('member',
'ipaallowedtoperform_write_keys',
'managedby_host'):
for d in mod_dn:
if not isinstance(d, (DN, RDN)):
d = DN(d.decode('utf-8'))
self.emit('modify_s %s', d)
self.remove_cache_entry(d)
self.emit('modify_s %s', dn)
self.remove_cache_entry(dn)
return super(LDAPCache, self).modify_s(dn, modlist)
def get_entry(self, dn, attrs_list=None, time_limit=None,
size_limit=None, get_effective_rights=False):
# pylint: disable=no-member
if not self._enable_cache:
return super(LDAPCache, self).get_entry(
dn, attrs_list, time_limit, size_limit, get_effective_rights
)
self.emit("Cache lookup: %s", dn)
entry = self.cache.get(dn)
if get_effective_rights and entry:
# We don't cache this so do the query but don't drop the
# entry.
entry = None
if entry and entry.exception:
hits = self._cache_hits + 1 # pylint: disable=no-member
object.__setattr__(self, '_cache_hits', hits)
self.emit("HIT: Re-raising %s", entry.exception)
self.cache_status('HIT')
raise entry.exception
self.emit("Requested attrs_list %s", attrs_list)
if entry:
self.emit("Cached attrs_list %s", entry.attrs_list)
if not attrs_list:
attrs_list = ['*']
elif attrs_list == ['']:
attrs_list = ['dn']
get_all = False
if (
entry
and (attrs_list in (['*'], ['']))
and (entry.attrs_list in (['*'], ['']))
):
get_all = True
if entry and entry.all and get_all:
# self.hit # pylint: disable=pointless-statement
hits = self._cache_hits + 1 # pylint: disable=no-member
object.__setattr__(self, '_cache_hits', hits)
self.cache_status('HIT')
return entry.entry
# Be sure we have all the requested attributes before returning
# a cached entry.
if entry and attrs_list:
req_attrs = set(attr.lower() for attr in set(attrs_list))
cache_attrs = set(attr.lower() for attr in entry.attrs_list)
if (req_attrs.issubset(cache_attrs)):
hits = self._cache_hits + 1 # pylint: disable=no-member
object.__setattr__(self, '_cache_hits', hits)
self.cache_status('HIT')
return entry.entry
try:
entry = super(LDAPCache, self).get_entry(
dn, attrs_list, time_limit, size_limit, get_effective_rights
)
except (errors.NotFound, errors.EmptyResult) as e:
# only cache these exceptions
self.add_cache_entry(dn, exception=e)
misses = self._cache_misses + 1 # pylint: disable=no-member
object.__setattr__(self, '_cache_misses', misses)
self.cache_status('MISS: %s' % e)
raise
# pylint: disable=try-except-raise
except Exception:
# re-raise anything we aren't caching
raise
else:
self.add_cache_entry(dn, attrs_list=attrs_list, get_all=get_all,
entry=entry)
misses = self._cache_misses + 1 # pylint: disable=no-member
object.__setattr__(self, '_cache_misses', misses)
self.cache_status('MISS')
return entry