# # Copyright (C) 2018 FreeIPA Contributors. See COPYING for license # import dns.name import dns.rdataclass import dns.rdatatype from dns.rdtypes.IN.SRV import SRV from dns.rdtypes.ANY.URI import URI from ipapython import dnsutil import pytest def mksrv(priority, weight, port, target): return SRV( rdclass=dns.rdataclass.IN, rdtype=dns.rdatatype.SRV, priority=priority, weight=weight, port=port, target=dns.name.from_text(target) ) def mkuri(priority, weight, target): return URI( rdclass=dns.rdataclass.IN, rdtype=dns.rdatatype.URI, priority=priority, weight=weight, target=target ) class TestSortSRV: def test_empty(self): assert dnsutil.sort_prio_weight([]) == [] def test_one(self): h1 = mksrv(1, 0, 443, u"host1") assert dnsutil.sort_prio_weight([h1]) == [h1] h2 = mksrv(10, 5, 443, u"host2") assert dnsutil.sort_prio_weight([h2]) == [h2] def test_prio(self): h1 = mksrv(1, 0, 443, u"host1") h2 = mksrv(2, 0, 443, u"host2") h3 = mksrv(3, 0, 443, u"host3") assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3] assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3] assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2] h380 = mksrv(4, 0, 80, u"host3") assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380] def assert_permutations(self, answers, permutations): seen = set() for _unused in range(1000): result = tuple(dnsutil.sort_prio_weight(answers)) assert result in permutations seen.add(result) if seen == permutations: break else: pytest.fail("sorting didn't exhaust all permutations.") def test_sameprio(self): h1 = mksrv(1, 0, 443, u"host1") h2 = mksrv(1, 0, 443, u"host2") permutations = { (h1, h2), (h2, h1), } self.assert_permutations([h1, h2], permutations) def test_weight(self): h1 = mksrv(1, 0, 443, u"host1") h2_w15 = mksrv(2, 15, 443, u"host2") h3_w10 = mksrv(2, 10, 443, u"host3") permutations = { (h1, h2_w15, h3_w10), (h1, h3_w10, h2_w15), } self.assert_permutations([h1, h2_w15, h3_w10], permutations) def test_large(self): records = tuple( mksrv(1, i, 443, "host{}".format(i)) for i in range(1000) ) assert len(dnsutil.sort_prio_weight(records)) == len(records) class TestSortURI: def test_prio(self): h1 = mkuri(1, 0, u"https://host1/api") h2 = mkuri(2, 0, u"https://host2/api") h3 = mkuri(3, 0, u"https://host3/api") assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3] assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3] assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]