ipapython.dn: Use rich comparisons

__cmp__ and cmp were removed from Python 3.

Reviewed-By: David Kupka <dkupka@redhat.com>
Reviewed-By: Jan Cholasta <jcholast@redhat.com>
Reviewed-By: Martin Basti <mbasti@redhat.com>
This commit is contained in:
Petr Viktorin
2015-09-18 15:28:23 +02:00
committed by Jan Cholasta
parent c9ca8de7a2
commit ed96f8d9ba
2 changed files with 94 additions and 47 deletions

View File

@@ -420,6 +420,7 @@ to the constructor. The result may share underlying structure.
from __future__ import print_function
import sys
import functools
from ldap.dn import str2dn, dn2str
from ldap import DECODING_ERROR
@@ -449,9 +450,11 @@ def _adjust_indices(start, end, length):
def _normalize_ava_input(val):
if not isinstance(val, six.string_types):
val = unicode(val).encode('utf-8')
elif isinstance(val, unicode):
if six.PY3 and isinstance(val, bytes):
raise TypeError('expected str, got bytes: %s' % val)
elif not isinstance(val, six.string_types):
val = val_encode(six.text_type(val))
elif six.PY2 and isinstance(val, unicode):
val = val.encode('utf-8')
return val
@@ -512,29 +515,47 @@ def get_ava(*args):
def sort_avas(rdn):
if len(rdn) <= 1:
return
rdn.sort(cmp=cmp_avas)
rdn.sort(key=ava_key)
def cmp_avas(a, b):
r = cmp(a[0].lower(), b[0].lower())
if r == 0:
r = cmp(a[1].lower(), b[1].lower())
return r
def ava_key(ava):
return ava[0].lower(), ava[1].lower()
def cmp_rdns(a, b):
key_a = rdn_key(a)
key_b = rdn_key(b)
if key_a == key_b:
return 0
elif key_a < key_b:
return -1
else:
return 1
l = len(a)
r = cmp(l, len(b))
if r != 0:
return r
for i, ava_a in enumerate(a):
r = cmp_avas(ava_a, b[i])
if r != 0:
return r
return 0
def rdn_key(rdn):
return (len(rdn),) + tuple(ava_key(k) for k in rdn)
if six.PY2:
# Python 2: Input/output is unicode; we store UTF-8 bytes
def val_encode(s):
return s.encode('utf-8')
def val_decode(s):
return s.decode('utf-8')
else:
# Python 3: Everything is unicode (str)
def val_encode(s):
if isinstance(s, bytes):
raise TypeError('expected str, got bytes: %s' % s)
return s
def val_decode(s):
return s
@functools.total_ordering
class AVA(object):
'''
AVA(arg0, ...)
@@ -593,23 +614,23 @@ class AVA(object):
self._ava = get_ava(*args)
def _get_attr(self):
return self._ava[0].decode('utf-8')
return val_decode(self._ava[0])
def _set_attr(self, new_attr):
try:
self._ava[0] = _normalize_ava_input(new_attr)
except Exception, e:
except Exception as e:
raise ValueError('unable to convert attr "%s": %s' % (new_attr, e))
attr = property(_get_attr)
def _get_value(self):
return self._ava[1].decode('utf-8')
return val_decode(self._ava[1])
def _set_value(self, new_value):
try:
self._ava[1] = _normalize_ava_input(new_value)
except Exception, e:
except Exception as e:
raise ValueError('unable to convert value "%s": %s' % (new_value, e))
value = property(_get_value)
@@ -669,20 +690,21 @@ class AVA(object):
return False
# Perform comparison between objects of same type
return cmp_avas(self._ava, other._ava) == 0
return ava_key(self._ava) == ava_key(other._ava)
def __ne__(self, other):
return not self.__eq__(other)
def __cmp__(self, other):
def __lt__(self, other):
'comparison is case insensitive, see __eq__ doc for explanation'
if not isinstance(other, AVA):
raise TypeError("expected AVA but got %s" % (other.__class__.__name__))
return cmp_avas(self._ava, other._ava)
return ava_key(self._ava) < ava_key(other._ava)
@functools.total_ordering
class RDN(object):
'''
RDN(arg0, ...)
@@ -843,8 +865,8 @@ class RDN(object):
return [self._get_ava(ava) for ava in self._avas[key]]
elif isinstance(key, six.string_types):
for ava in self._avas:
if key == ava[0].decode('utf-8'):
return ava[1].decode('utf-8')
if key == val_decode(ava[0]):
return val_decode(ava[1])
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
else:
raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
@@ -853,25 +875,25 @@ class RDN(object):
def _get_attr(self):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
return self._avas[0][0].decode('utf-8')
return val_decode(self._avas[0][0])
def _set_attr(self, new_attr):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
self._avas[0][0] = unicode(new_attr).encode('utf-8')
self._avas[0][0] = val_encode(six.text_type(new_attr))
attr = property(_get_attr)
def _get_value(self):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
return self._avas[0][1].decode('utf-8')
return val_decode(self._avas[0][1])
def _set_value(self, new_value):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
self._avas[0][1] = unicode(new_value).encode('utf-8')
self._avas[0][1] = val_encode(six.text_type(new_value))
value = property(_get_value)
@@ -898,16 +920,16 @@ class RDN(object):
return False
# Perform comparison between objects of same type
return cmp_rdns(self._avas, other._avas) == 0
return rdn_key(self._avas) == rdn_key(other._avas)
def __ne__(self, other):
return not self.__eq__(other)
def __cmp__(self, other):
def __lt__(self, other):
if not isinstance(other, RDN):
raise TypeError("expected RDN but got %s" % (other.__class__.__name__))
return cmp_rdns(self._avas, other._avas)
return rdn_key(self._avas) < rdn_key(other._avas)
def __add__(self, other):
result = self.__class__(self)
@@ -927,6 +949,7 @@ class RDN(object):
return result
@functools.total_ordering
class DN(object):
'''
DN(arg0, ...)
@@ -1088,8 +1111,8 @@ class DN(object):
def _rdns_from_value(self, value):
if isinstance(value, six.string_types):
try:
if isinstance(value, unicode):
value = value.encode('utf-8')
if isinstance(value, six.text_type):
value = val_encode(value)
rdns = str2dn(value)
except DECODING_ERROR:
raise ValueError("malformed RDN string = \"%s\"" % value)
@@ -1124,7 +1147,7 @@ class DN(object):
def __str__(self):
try:
return dn2str(self.rdns)
except Exception, e:
except Exception as e:
print(len(self.rdns))
print(self.rdns)
raise
@@ -1153,8 +1176,8 @@ class DN(object):
elif isinstance(key, six.string_types):
for rdn in self.rdns:
for ava in rdn:
if key == ava[0].decode('utf-8'):
return ava[1].decode('utf-8')
if key == val_decode(ava[0]):
return val_decode(ava[1])
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
else:
raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \
@@ -1187,23 +1210,25 @@ class DN(object):
if not isinstance(other, DN):
return False
if len(self) != len(other):
return False
# Perform comparison between objects of same type
return self.__cmp__(other) == 0
return self._cmp_sequence(other, 0, len(self)) == 0
def __ne__(self, other):
return not self.__eq__(other)
def __cmp__(self, other):
def __lt__(self, other):
if not isinstance(other, DN):
raise TypeError("expected DN but got %s" % (other.__class__.__name__))
result = cmp(len(self), len(other))
if result != 0:
return result
return self._cmp_sequence(other, 0, len(self))
if len(self) != len(other):
return len(self) < len(other)
return self._cmp_sequence(other, 0, len(self)) < 0
def _cmp_sequence(self, pattern, self_start, pat_len):
self_idx = self_start
self_len = len(self)
pat_idx = 0

View File

@@ -4,11 +4,33 @@ import unittest
import six
from ipapython.dn import *
from ipapython.dn import DN, RDN, AVA
if six.PY3:
unicode = str
def cmp(a, b):
if a == b:
assert not a < b
assert not a > b
assert not a != b
assert a <= b
assert a >= b
return 0
elif a < b:
assert not a > b
assert a != b
assert a <= b
assert not a >= b
return -1
else:
assert a > b
assert a != b
assert not a <= b
assert a >= b
return 1
def expected_class(klass, component):
if klass is AVA:
if component == 'self':