mirror of
https://salsa.debian.org/freeipa-team/freeipa.git
synced 2025-02-25 18:55:28 -06:00
ipapython.dn: Use rich comparisons
__cmp__ and cmp were removed from Python 3. Reviewed-By: David Kupka <dkupka@redhat.com> Reviewed-By: Jan Cholasta <jcholast@redhat.com> Reviewed-By: Martin Basti <mbasti@redhat.com>
This commit is contained in:
committed by
Jan Cholasta
parent
c9ca8de7a2
commit
ed96f8d9ba
117
ipapython/dn.py
117
ipapython/dn.py
@@ -420,6 +420,7 @@ to the constructor. The result may share underlying structure.
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import functools
|
||||||
|
|
||||||
from ldap.dn import str2dn, dn2str
|
from ldap.dn import str2dn, dn2str
|
||||||
from ldap import DECODING_ERROR
|
from ldap import DECODING_ERROR
|
||||||
@@ -449,9 +450,11 @@ def _adjust_indices(start, end, length):
|
|||||||
|
|
||||||
|
|
||||||
def _normalize_ava_input(val):
|
def _normalize_ava_input(val):
|
||||||
if not isinstance(val, six.string_types):
|
if six.PY3 and isinstance(val, bytes):
|
||||||
val = unicode(val).encode('utf-8')
|
raise TypeError('expected str, got bytes: %s' % val)
|
||||||
elif isinstance(val, unicode):
|
elif not isinstance(val, six.string_types):
|
||||||
|
val = val_encode(six.text_type(val))
|
||||||
|
elif six.PY2 and isinstance(val, unicode):
|
||||||
val = val.encode('utf-8')
|
val = val.encode('utf-8')
|
||||||
return val
|
return val
|
||||||
|
|
||||||
@@ -512,29 +515,47 @@ def get_ava(*args):
|
|||||||
def sort_avas(rdn):
|
def sort_avas(rdn):
|
||||||
if len(rdn) <= 1:
|
if len(rdn) <= 1:
|
||||||
return
|
return
|
||||||
rdn.sort(cmp=cmp_avas)
|
rdn.sort(key=ava_key)
|
||||||
|
|
||||||
|
|
||||||
def cmp_avas(a, b):
|
def ava_key(ava):
|
||||||
r = cmp(a[0].lower(), b[0].lower())
|
return ava[0].lower(), ava[1].lower()
|
||||||
if r == 0:
|
|
||||||
r = cmp(a[1].lower(), b[1].lower())
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
def cmp_rdns(a, b):
|
def cmp_rdns(a, b):
|
||||||
|
key_a = rdn_key(a)
|
||||||
|
key_b = rdn_key(b)
|
||||||
|
if key_a == key_b:
|
||||||
|
return 0
|
||||||
|
elif key_a < key_b:
|
||||||
|
return -1
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
l = len(a)
|
|
||||||
r = cmp(l, len(b))
|
|
||||||
if r != 0:
|
|
||||||
return r
|
|
||||||
|
|
||||||
for i, ava_a in enumerate(a):
|
def rdn_key(rdn):
|
||||||
r = cmp_avas(ava_a, b[i])
|
return (len(rdn),) + tuple(ava_key(k) for k in rdn)
|
||||||
if r != 0:
|
|
||||||
return r
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
|
if six.PY2:
|
||||||
|
# Python 2: Input/output is unicode; we store UTF-8 bytes
|
||||||
|
def val_encode(s):
|
||||||
|
return s.encode('utf-8')
|
||||||
|
|
||||||
|
def val_decode(s):
|
||||||
|
return s.decode('utf-8')
|
||||||
|
else:
|
||||||
|
# Python 3: Everything is unicode (str)
|
||||||
|
def val_encode(s):
|
||||||
|
if isinstance(s, bytes):
|
||||||
|
raise TypeError('expected str, got bytes: %s' % s)
|
||||||
|
return s
|
||||||
|
|
||||||
|
def val_decode(s):
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
@functools.total_ordering
|
||||||
class AVA(object):
|
class AVA(object):
|
||||||
'''
|
'''
|
||||||
AVA(arg0, ...)
|
AVA(arg0, ...)
|
||||||
@@ -593,23 +614,23 @@ class AVA(object):
|
|||||||
self._ava = get_ava(*args)
|
self._ava = get_ava(*args)
|
||||||
|
|
||||||
def _get_attr(self):
|
def _get_attr(self):
|
||||||
return self._ava[0].decode('utf-8')
|
return val_decode(self._ava[0])
|
||||||
|
|
||||||
def _set_attr(self, new_attr):
|
def _set_attr(self, new_attr):
|
||||||
try:
|
try:
|
||||||
self._ava[0] = _normalize_ava_input(new_attr)
|
self._ava[0] = _normalize_ava_input(new_attr)
|
||||||
except Exception, e:
|
except Exception as e:
|
||||||
raise ValueError('unable to convert attr "%s": %s' % (new_attr, e))
|
raise ValueError('unable to convert attr "%s": %s' % (new_attr, e))
|
||||||
|
|
||||||
attr = property(_get_attr)
|
attr = property(_get_attr)
|
||||||
|
|
||||||
def _get_value(self):
|
def _get_value(self):
|
||||||
return self._ava[1].decode('utf-8')
|
return val_decode(self._ava[1])
|
||||||
|
|
||||||
def _set_value(self, new_value):
|
def _set_value(self, new_value):
|
||||||
try:
|
try:
|
||||||
self._ava[1] = _normalize_ava_input(new_value)
|
self._ava[1] = _normalize_ava_input(new_value)
|
||||||
except Exception, e:
|
except Exception as e:
|
||||||
raise ValueError('unable to convert value "%s": %s' % (new_value, e))
|
raise ValueError('unable to convert value "%s": %s' % (new_value, e))
|
||||||
|
|
||||||
value = property(_get_value)
|
value = property(_get_value)
|
||||||
@@ -669,20 +690,21 @@ class AVA(object):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Perform comparison between objects of same type
|
# Perform comparison between objects of same type
|
||||||
return cmp_avas(self._ava, other._ava) == 0
|
return ava_key(self._ava) == ava_key(other._ava)
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
def __cmp__(self, other):
|
def __lt__(self, other):
|
||||||
'comparison is case insensitive, see __eq__ doc for explanation'
|
'comparison is case insensitive, see __eq__ doc for explanation'
|
||||||
|
|
||||||
if not isinstance(other, AVA):
|
if not isinstance(other, AVA):
|
||||||
raise TypeError("expected AVA but got %s" % (other.__class__.__name__))
|
raise TypeError("expected AVA but got %s" % (other.__class__.__name__))
|
||||||
|
|
||||||
return cmp_avas(self._ava, other._ava)
|
return ava_key(self._ava) < ava_key(other._ava)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.total_ordering
|
||||||
class RDN(object):
|
class RDN(object):
|
||||||
'''
|
'''
|
||||||
RDN(arg0, ...)
|
RDN(arg0, ...)
|
||||||
@@ -843,8 +865,8 @@ class RDN(object):
|
|||||||
return [self._get_ava(ava) for ava in self._avas[key]]
|
return [self._get_ava(ava) for ava in self._avas[key]]
|
||||||
elif isinstance(key, six.string_types):
|
elif isinstance(key, six.string_types):
|
||||||
for ava in self._avas:
|
for ava in self._avas:
|
||||||
if key == ava[0].decode('utf-8'):
|
if key == val_decode(ava[0]):
|
||||||
return ava[1].decode('utf-8')
|
return val_decode(ava[1])
|
||||||
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
|
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
|
||||||
else:
|
else:
|
||||||
raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
|
raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
|
||||||
@@ -853,25 +875,25 @@ class RDN(object):
|
|||||||
def _get_attr(self):
|
def _get_attr(self):
|
||||||
if len(self._avas) == 0:
|
if len(self._avas) == 0:
|
||||||
raise IndexError("No AVA's in this RDN")
|
raise IndexError("No AVA's in this RDN")
|
||||||
return self._avas[0][0].decode('utf-8')
|
return val_decode(self._avas[0][0])
|
||||||
|
|
||||||
def _set_attr(self, new_attr):
|
def _set_attr(self, new_attr):
|
||||||
if len(self._avas) == 0:
|
if len(self._avas) == 0:
|
||||||
raise IndexError("No AVA's in this RDN")
|
raise IndexError("No AVA's in this RDN")
|
||||||
|
|
||||||
self._avas[0][0] = unicode(new_attr).encode('utf-8')
|
self._avas[0][0] = val_encode(six.text_type(new_attr))
|
||||||
|
|
||||||
attr = property(_get_attr)
|
attr = property(_get_attr)
|
||||||
|
|
||||||
def _get_value(self):
|
def _get_value(self):
|
||||||
if len(self._avas) == 0:
|
if len(self._avas) == 0:
|
||||||
raise IndexError("No AVA's in this RDN")
|
raise IndexError("No AVA's in this RDN")
|
||||||
return self._avas[0][1].decode('utf-8')
|
return val_decode(self._avas[0][1])
|
||||||
|
|
||||||
def _set_value(self, new_value):
|
def _set_value(self, new_value):
|
||||||
if len(self._avas) == 0:
|
if len(self._avas) == 0:
|
||||||
raise IndexError("No AVA's in this RDN")
|
raise IndexError("No AVA's in this RDN")
|
||||||
self._avas[0][1] = unicode(new_value).encode('utf-8')
|
self._avas[0][1] = val_encode(six.text_type(new_value))
|
||||||
|
|
||||||
value = property(_get_value)
|
value = property(_get_value)
|
||||||
|
|
||||||
@@ -898,16 +920,16 @@ class RDN(object):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Perform comparison between objects of same type
|
# Perform comparison between objects of same type
|
||||||
return cmp_rdns(self._avas, other._avas) == 0
|
return rdn_key(self._avas) == rdn_key(other._avas)
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
def __cmp__(self, other):
|
def __lt__(self, other):
|
||||||
if not isinstance(other, RDN):
|
if not isinstance(other, RDN):
|
||||||
raise TypeError("expected RDN but got %s" % (other.__class__.__name__))
|
raise TypeError("expected RDN but got %s" % (other.__class__.__name__))
|
||||||
|
|
||||||
return cmp_rdns(self._avas, other._avas)
|
return rdn_key(self._avas) < rdn_key(other._avas)
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
result = self.__class__(self)
|
result = self.__class__(self)
|
||||||
@@ -927,6 +949,7 @@ class RDN(object):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@functools.total_ordering
|
||||||
class DN(object):
|
class DN(object):
|
||||||
'''
|
'''
|
||||||
DN(arg0, ...)
|
DN(arg0, ...)
|
||||||
@@ -1088,8 +1111,8 @@ class DN(object):
|
|||||||
def _rdns_from_value(self, value):
|
def _rdns_from_value(self, value):
|
||||||
if isinstance(value, six.string_types):
|
if isinstance(value, six.string_types):
|
||||||
try:
|
try:
|
||||||
if isinstance(value, unicode):
|
if isinstance(value, six.text_type):
|
||||||
value = value.encode('utf-8')
|
value = val_encode(value)
|
||||||
rdns = str2dn(value)
|
rdns = str2dn(value)
|
||||||
except DECODING_ERROR:
|
except DECODING_ERROR:
|
||||||
raise ValueError("malformed RDN string = \"%s\"" % value)
|
raise ValueError("malformed RDN string = \"%s\"" % value)
|
||||||
@@ -1124,7 +1147,7 @@ class DN(object):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
try:
|
try:
|
||||||
return dn2str(self.rdns)
|
return dn2str(self.rdns)
|
||||||
except Exception, e:
|
except Exception as e:
|
||||||
print(len(self.rdns))
|
print(len(self.rdns))
|
||||||
print(self.rdns)
|
print(self.rdns)
|
||||||
raise
|
raise
|
||||||
@@ -1153,8 +1176,8 @@ class DN(object):
|
|||||||
elif isinstance(key, six.string_types):
|
elif isinstance(key, six.string_types):
|
||||||
for rdn in self.rdns:
|
for rdn in self.rdns:
|
||||||
for ava in rdn:
|
for ava in rdn:
|
||||||
if key == ava[0].decode('utf-8'):
|
if key == val_decode(ava[0]):
|
||||||
return ava[1].decode('utf-8')
|
return val_decode(ava[1])
|
||||||
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
|
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
|
||||||
else:
|
else:
|
||||||
raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \
|
raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \
|
||||||
@@ -1187,23 +1210,25 @@ class DN(object):
|
|||||||
if not isinstance(other, DN):
|
if not isinstance(other, DN):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if len(self) != len(other):
|
||||||
|
return False
|
||||||
|
|
||||||
# Perform comparison between objects of same type
|
# Perform comparison between objects of same type
|
||||||
return self.__cmp__(other) == 0
|
return self._cmp_sequence(other, 0, len(self)) == 0
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
def __cmp__(self, other):
|
def __lt__(self, other):
|
||||||
if not isinstance(other, DN):
|
if not isinstance(other, DN):
|
||||||
raise TypeError("expected DN but got %s" % (other.__class__.__name__))
|
raise TypeError("expected DN but got %s" % (other.__class__.__name__))
|
||||||
|
|
||||||
result = cmp(len(self), len(other))
|
if len(self) != len(other):
|
||||||
if result != 0:
|
return len(self) < len(other)
|
||||||
return result
|
|
||||||
return self._cmp_sequence(other, 0, len(self))
|
return self._cmp_sequence(other, 0, len(self)) < 0
|
||||||
|
|
||||||
def _cmp_sequence(self, pattern, self_start, pat_len):
|
def _cmp_sequence(self, pattern, self_start, pat_len):
|
||||||
|
|
||||||
self_idx = self_start
|
self_idx = self_start
|
||||||
self_len = len(self)
|
self_len = len(self)
|
||||||
pat_idx = 0
|
pat_idx = 0
|
||||||
|
|||||||
@@ -4,11 +4,33 @@ import unittest
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from ipapython.dn import *
|
from ipapython.dn import DN, RDN, AVA
|
||||||
|
|
||||||
if six.PY3:
|
if six.PY3:
|
||||||
unicode = str
|
unicode = str
|
||||||
|
|
||||||
|
def cmp(a, b):
|
||||||
|
if a == b:
|
||||||
|
assert not a < b
|
||||||
|
assert not a > b
|
||||||
|
assert not a != b
|
||||||
|
assert a <= b
|
||||||
|
assert a >= b
|
||||||
|
return 0
|
||||||
|
elif a < b:
|
||||||
|
assert not a > b
|
||||||
|
assert a != b
|
||||||
|
assert a <= b
|
||||||
|
assert not a >= b
|
||||||
|
return -1
|
||||||
|
else:
|
||||||
|
assert a > b
|
||||||
|
assert a != b
|
||||||
|
assert not a <= b
|
||||||
|
assert a >= b
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def expected_class(klass, component):
|
def expected_class(klass, component):
|
||||||
if klass is AVA:
|
if klass is AVA:
|
||||||
if component == 'self':
|
if component == 'self':
|
||||||
|
|||||||
Reference in New Issue
Block a user