Replace getList by a get_entries method

The find_entries method is cumbersome to use: it requires keyword arguments
for simple uses, and callers are tempted to ignore the 'truncated' flag
it returns.
Introduce a simpler method, get_entries, that returns the found
list directly, and raises an errors if the list is truncated.
Replace the getList method by get_entries.

Part of the work for: https://fedorahosted.org/freeipa/ticket/2660
This commit is contained in:
Petr Viktorin 2013-01-21 08:39:09 -05:00 committed by Martin Kosek
parent f5c404c65d
commit 4779865ea3
8 changed files with 91 additions and 60 deletions

View File

@ -211,7 +211,7 @@ def list_replicas(realm, host, replica, dirman_passwd, verbose):
conn.do_simple_bind(bindpw=dirman_passwd)
dn = DN(('cn', 'masters'), ('cn', 'ipa'), ('cn', 'etc'), ipautil.realm_to_suffix(realm))
entries = conn.getList(dn, ldap.SCOPE_ONELEVEL)
entries = conn.get_entries(dn, ldap.SCOPE_ONELEVEL)
for ent in entries:
try:

View File

@ -123,9 +123,8 @@ def main():
# List available Managed Entry Plugins
managed_entries = None
try:
entries = conn.getList(
managed_entry_definitions_dn, ldap.SCOPE_SUBTREE, filter
)
entries = conn.get_entries(
managed_entry_definitions_dn, ldap.SCOPE_SUBTREE, filter)
except Exception, e:
root_logger.debug("Search for managed entries failed: %s" % str(e))
sys.exit("Unable to find managed entries at %s" % managed_entry_definitions_dn)

View File

@ -156,7 +156,7 @@ def list_replicas(realm, host, replica, dirman_passwd, verbose):
dn = DN(('cn', 'masters'), ('cn', 'ipa'), ('cn', 'etc'), ipautil.realm_to_suffix(realm))
try:
entries = conn.getList(dn, ldap.SCOPE_ONELEVEL)
entries = conn.get_entries(dn, ldap.SCOPE_ONELEVEL)
except:
print "Failed to read master data from '%s': %s" % (host, str(e))
return
@ -166,7 +166,7 @@ def list_replicas(realm, host, replica, dirman_passwd, verbose):
dn = DN(('cn', 'replicas'), ('cn', 'ipa'), ('cn', 'etc'), ipautil.realm_to_suffix(realm))
try:
entries = conn.getList(dn, ldap.SCOPE_ONELEVEL)
entries = conn.get_entries(dn, ldap.SCOPE_ONELEVEL)
except:
pass
else:
@ -195,8 +195,9 @@ def list_replicas(realm, host, replica, dirman_passwd, verbose):
repl = replication.ReplicationManager(realm, winsync_peer,
dirman_passwd)
cn, dn = repl.agreement_dn(replica)
entries = repl.conn.getList(dn, ldap.SCOPE_BASE,
"(objectclass=nsDSWindowsReplicationAgreement)")
entries = repl.conn.get_entries(
dn, ldap.SCOPE_BASE,
"(objectclass=nsDSWindowsReplicationAgreement)")
ent_type = 'winsync'
else:
repl = replication.ReplicationManager(realm, replica,
@ -304,7 +305,7 @@ def del_link(realm, replica1, replica2, dirman_passwd, force=False):
try:
dn = DN(('cn', replica2), ('cn', 'replicas'), ('cn', 'ipa'), ('cn', 'etc'),
ipautil.realm_to_suffix(realm))
entries = repl1.conn.getList(dn, ldap.SCOPE_SUBTREE)
entries = repl1.conn.get_entries(dn, ldap.SCOPE_SUBTREE)
if entries:
entries.sort(key=len, reverse=True)
for dn in entries:
@ -455,7 +456,7 @@ def list_clean_ruv(realm, host, dirman_passwd, verbose):
repl = replication.ReplicationManager(realm, host, dirman_passwd)
dn = DN(('cn', 'cleanallruv'),('cn', 'tasks'), ('cn', 'config'))
try:
entries = repl.conn.getList(dn, ldap.SCOPE_ONELEVEL)
entries = repl.conn.get_entries(dn, ldap.SCOPE_ONELEVEL)
except errors.NotFound:
print "No CLEANALLRUV tasks running"
else:
@ -472,7 +473,7 @@ def list_clean_ruv(realm, host, dirman_passwd, verbose):
dn = DN(('cn', 'abort cleanallruv'),('cn', 'tasks'), ('cn', 'config'))
try:
entries = repl.conn.getList(dn, ldap.SCOPE_ONELEVEL)
entries = repl.conn.get_entries(dn, ldap.SCOPE_ONELEVEL)
except errors.NotFound:
print "No abort CLEANALLRUV tasks running"
else:
@ -586,7 +587,7 @@ def del_master(realm, hostname, options):
if force_del:
dn = DN(('cn', 'masters'), ('cn', 'ipa'), ('cn', 'etc'), thisrepl.suffix)
entries = thisrepl.conn.getList(dn, ldap.SCOPE_ONELEVEL)
entries = thisrepl.conn.get_entries(dn, ldap.SCOPE_ONELEVEL)
replica_names = []
for entry in entries:
replica_names.append(entry.single_value('cn'))
@ -616,7 +617,7 @@ def del_master(realm, hostname, options):
if delrepl and not winsync:
masters_dn = DN(('cn', 'masters'), ('cn', 'ipa'), ('cn', 'etc'), ipautil.realm_to_suffix(realm))
try:
masters = delrepl.conn.getList(masters_dn, ldap.SCOPE_ONELEVEL)
masters = delrepl.conn.get_entries(masters_dn, ldap.SCOPE_ONELEVEL)
except Exception, e:
masters = []
print "Failed to read masters data from '%s': %s" % (delrepl.hostname, convert_error(e))
@ -639,7 +640,8 @@ def del_master(realm, hostname, options):
for master_cn in [m.getValue('cn') for m in masters]:
master_dn = DN(('cn', master_cn), ('cn', 'masters'), ('cn', 'ipa'), ('cn', 'etc'), ipautil.realm_to_suffix(realm))
services = delrepl.conn.getList(master_dn, ldap.SCOPE_ONELEVEL)
services = delrepl.conn.get_entries(master_dn,
delrepl.conn.SCOPE_ONELEVEL)
services_cns = [s.getValue('cn') for s in services]
if master_cn == hostname:

View File

@ -261,9 +261,9 @@ class ADTRUSTInstance(service.Service):
"""
try:
res = self.admin_conn.getList(DN(api.env.container_ranges, self.suffix),
ldap.SCOPE_ONELEVEL,
"(objectclass=ipaDomainIDRange)")
res = self.admin_conn.get_entries(
DN(api.env.container_ranges, self.suffix),
ldap.SCOPE_ONELEVEL, "(objectclass=ipaDomainIDRange)")
if len(res) != 1:
root_logger.critical("Found more than one ID range for the " \
"local domain.")

View File

@ -509,7 +509,7 @@ class LDAPUpdate:
sattrs = ["*", "aci", "attributeTypes", "objectClasses"]
scope = ldap.SCOPE_BASE
return self.conn.getList(dn, scope, searchfilter, sattrs)
return self.conn.get_entries(dn, scope, searchfilter, sattrs)
def _apply_update_disposition(self, updates, entry):
"""

View File

@ -251,8 +251,9 @@ class ReplicationManager(object):
"""
filt = self.get_agreement_filter()
try:
ents = self.conn.getList(DN(('cn', 'mapping tree'), ('cn', 'config')),
ldap.SCOPE_SUBTREE, filt)
ents = self.conn.get_entries(
DN(('cn', 'mapping tree'), ('cn', 'config')),
ldap.SCOPE_SUBTREE, filt)
except errors.NotFound:
ents = []
return ents
@ -269,8 +270,9 @@ class ReplicationManager(object):
filt = self.get_agreement_filter(IPA_REPLICA)
try:
ents = self.conn.getList(DN(('cn', 'mapping tree'), ('cn', 'config')),
ldap.SCOPE_SUBTREE, filt)
ents = self.conn.get_entries(
DN(('cn', 'mapping tree'), ('cn', 'config')),
ldap.SCOPE_SUBTREE, filt)
except errors.NotFound:
return res
@ -291,8 +293,9 @@ class ReplicationManager(object):
filt = self.get_agreement_filter(host=hostname)
try:
entries = self.conn.getList(DN(('cn', 'mapping tree'), ('cn', 'config')),
ldap.SCOPE_SUBTREE, filt)
entries = self.conn.get_entries(
DN(('cn', 'mapping tree'), ('cn', 'config')),
ldap.SCOPE_SUBTREE, filt)
except errors.NotFound:
return None
@ -1031,7 +1034,7 @@ class ReplicationManager(object):
newschedule = '2358-2359 0'
filter = self.get_agreement_filter(host=hostname)
entries = conn.getList(
entries = conn.get_entries(
DN(('cn', 'config')), ldap.SCOPE_SUBTREE, filter)
if len(entries) == 0:
root_logger.error("Unable to find replication agreement for %s" %
@ -1086,9 +1089,9 @@ class ReplicationManager(object):
# delete master kerberos key and all its svc principals
try:
filter='(krbprincipalname=*/%s@%s)' % (replica, realm)
entries = self.conn.getList(self.suffix, ldap.SCOPE_SUBTREE,
filterstr=filter)
entries = self.conn.get_entries(
self.suffix, ldap.SCOPE_SUBTREE,
filter='(krbprincipalname=*/%s@%s)' % (replica, realm))
if entries:
entries.sort(key=len, reverse=True)
for dn in entries:
@ -1128,8 +1131,9 @@ class ReplicationManager(object):
# delete master entry with all active services
try:
dn = DN(('cn', replica), ('cn', 'masters'), ('cn', 'ipa'), ('cn', 'etc'), self.suffix)
entries = self.conn.getList(dn, ldap.SCOPE_SUBTREE)
dn = DN(('cn', replica), ('cn', 'masters'), ('cn', 'ipa'),
('cn', 'etc'), self.suffix)
entries = self.conn.get_entries(dn, ldap.SCOPE_SUBTREE)
if entries:
entries.sort(key=len, reverse=True)
for dn in entries:
@ -1145,8 +1149,8 @@ class ReplicationManager(object):
try:
basedn = DN(('cn', 'etc'), self.suffix)
filter = '(dnaHostname=%s)' % replica
entries = self.conn.getList(basedn, ldap.SCOPE_SUBTREE,
filterstr=filter)
entries = self.conn.get_entries(
basedn, ldap.SCOPE_SUBTREE, filter=filter)
if len(entries) != 0:
for e in entries:
self.conn.deleteEntry(e.dn)

View File

@ -1069,6 +1069,24 @@ class LDAPConnection(object):
)
return self.combine_filters(flts, rules)
def get_entries(self, base_dn, scope=None, filter=None, attrs_list=None):
"""Return a list of matching entries.
Raises an error if the list is truncated by the server
:param base_dn: dn of the entry at which to start the search
:param scope: search scope, see LDAP docs (default ldap2.SCOPE_SUBTREE)
:param filter: LDAP filter to apply
:param attrs_list: ist of attributes to return, all if None (default)
Use the find_entries method for more options.
"""
entries, truncated = self.find_entries(
base_dn=base_dn, scope=scope, filter=filter, attrs_list=attrs_list)
if truncated:
raise errors.LimitsExceeded()
return entries
def find_entries(self, filter=None, attrs_list=None, base_dn=None,
scope=_ldap.SCOPE_SUBTREE, time_limit=None,
size_limit=None, normalize=True, search_refs=False):
@ -1629,16 +1647,6 @@ class IPAdmin(LDAPConnection):
)
return result[0]
def getList(self, base, scope, filterstr='(objectClass=*)', attrlist=None):
# FIXME: for backwards compatibility only
result, truncated = self.find_entries(
filter=filterstr,
attrs_list=attrlist,
base_dn=base,
scope=scope,
)
return result
def addEntry(self, entry):
# FIXME: for backwards compatibility only
self.add_entry(entry.dn, entry)

View File

@ -89,10 +89,12 @@ class test_update(unittest.TestCase):
self.assertTrue(modified)
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
def test_1_add(self):
"""
@ -102,7 +104,8 @@ class test_update(unittest.TestCase):
self.assertTrue(modified)
entries = self.ld.getList(self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
@ -112,7 +115,8 @@ class test_update(unittest.TestCase):
self.assertEqual(entry.single_value('cn'), 'test')
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
@ -133,7 +137,8 @@ class test_update(unittest.TestCase):
modified = self.updater.update([self.testdir + "2_update.update"])
self.assertTrue(modified)
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
self.assertEqual(entry.single_value('gecos'), 'Test User')
@ -145,7 +150,8 @@ class test_update(unittest.TestCase):
modified = self.updater.update([self.testdir + "3_update.update"])
self.assertTrue(modified)
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
self.assertEqual(entry.single_value('gecos'), 'Test User New')
@ -157,7 +163,8 @@ class test_update(unittest.TestCase):
modified = self.updater.update([self.testdir + "4_update.update"])
self.assertTrue(modified)
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
self.assertEqual(entry.single_value('gecos'), 'Test User New2')
@ -169,7 +176,8 @@ class test_update(unittest.TestCase):
modified = self.updater.update([self.testdir + "5_update.update"])
self.assertTrue(modified)
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
self.assertEqual(sorted(entry.get('cn')), sorted(['Test User', 'Test User New']))
@ -181,7 +189,8 @@ class test_update(unittest.TestCase):
modified = self.updater.update([self.testdir + "6_update.update"])
self.assertTrue(modified)
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
self.assertEqual(sorted(entry.get('cn')), sorted(['Test User']))
@ -193,7 +202,8 @@ class test_update(unittest.TestCase):
modified = self.updater.update([self.testdir + "6_update.update"])
self.assertFalse(modified)
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
self.assertEqual(sorted(entry.get('cn')), sorted(['Test User']))
@ -211,10 +221,12 @@ class test_update(unittest.TestCase):
self.assertTrue(modified)
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
def test_8_badsyntax(self):
"""
@ -239,10 +251,12 @@ class test_update(unittest.TestCase):
# First make sure we're clean
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
update = {
@ -274,7 +288,8 @@ class test_update(unittest.TestCase):
modified = self.updater.update_from_dict(update)
self.assertTrue(modified)
entries = self.ld.getList(self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
@ -284,7 +299,8 @@ class test_update(unittest.TestCase):
self.assertEqual(entry.single_value('cn'), 'test')
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
self.assertEqual(len(entries), 1)
entry = entries[0]
@ -314,7 +330,9 @@ class test_update(unittest.TestCase):
self.assertTrue(modified)
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.container_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
with self.assertRaises(errors.NotFound):
entries = self.ld.getList(self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])
entries = self.ld.get_entries(
self.user_dn, ldap.SCOPE_BASE, 'objectclass=*', ['*'])