diff --git a/ipalib/ipa_types.py b/ipalib/ipa_types.py index 670c4dd64..19950ead5 100644 --- a/ipalib/ipa_types.py +++ b/ipalib/ipa_types.py @@ -155,3 +155,19 @@ class Unicode(Type): if self.max_length is not None and len(value) > self.max_length: return 'Can be at most %d characters long' % self.max_length + + +class Enum(Type): + def __init__(self, *values): + if len(values) < 1: + raise ValueError('%s requires at least one value' % self.name) + type_ = type(values[0]) + if type_ not in (unicode, int, float): + raise TypeError( + '%r: %r not unicode, int, nor float' % (values[0], type_) + ) + for val in values[1:]: + if type(val) is not type_: + raise TypeError('%r: %r is not %r' % (val, type(val), type_)) + self.values = values + super(Enum, self).__init__(type_) diff --git a/ipalib/tests/test_ipa_types.py b/ipalib/tests/test_ipa_types.py index 5d31b8446..6ae94c416 100644 --- a/ipalib/tests/test_ipa_types.py +++ b/ipalib/tests/test_ipa_types.py @@ -331,3 +331,34 @@ class test_Unicode(ClassChecker): assert o.validate(u'a___b') == 'Can be at most 4 characters long' assert o.validate(u'a-b') == 'Must match %r' % pat assert o.validate(u'a--b') == 'Must match %r' % pat + + +class test_Enum(ClassChecker): + _cls = ipa_types.Enum + + def test_class(self): + assert self.cls.__bases__ == (ipa_types.Type,) + + def test_init(self): + for t in (unicode, int, float): + vals = (t(1),) + o = self.cls(*vals) + assert o.__islocked__() is True + assert read_only(o, 'type') is t + assert read_only(o, 'name') is 'Enum' + assert read_only(o, 'values') == vals + + # Check that ValueError is raised when no values are given: + e = raises(ValueError, self.cls) + assert str(e) == 'Enum requires at least one value' + + # Check that TypeError is raised when type of first value is not + # allowed: + e = raises(TypeError, self.cls, 'hello') + assert str(e) == '%r: %r not unicode, int, nor float' % ('hello', str) + #self.cls('hello') + + # Check that TypeError is raised when subsequent values aren't same + # type as first: + e = raises(TypeError, self.cls, u'hello', 'world') + assert str(e) == '%r: %r is not %r' % ('world', str, unicode)