diff --git a/ipaclient/install/ipadiscovery.py b/ipaclient/install/ipadiscovery.py index 363970c86..a8283d8e6 100644 --- a/ipaclient/install/ipadiscovery.py +++ b/ipaclient/install/ipadiscovery.py @@ -20,7 +20,6 @@ from __future__ import absolute_import import logging -import operator import socket import six @@ -28,6 +27,7 @@ import six from dns import resolver, rdatatype from dns.exception import DNSException from ipalib import errors +from ipapython.dnsutil import query_srv from ipapython import ipaldap from ipaplatform.paths import paths from ipapython.ipautil import valid_ip, realm_to_suffix @@ -498,8 +498,7 @@ class IPADiscovery(object): logger.debug("Search DNS for SRV record of %s", qname) try: - answers = resolver.query(qname, rdatatype.SRV) - answers = sorted(answers, key=operator.attrgetter('priority')) + answers = query_srv(qname) except DNSException as e: logger.debug("DNS record not found: %s", e.__class__.__name__) answers = [] diff --git a/ipalib/rpc.py b/ipalib/rpc.py index c6a8989f5..17368f160 100644 --- a/ipalib/rpc.py +++ b/ipalib/rpc.py @@ -45,7 +45,6 @@ import gzip from cryptography import x509 as crypto_x509 import gssapi -from dns import resolver, rdatatype from dns.exception import DNSException from ssl import SSLError import six @@ -61,7 +60,7 @@ from ipalib.x509 import Encoding as x509_Encoding from ipapython import ipautil from ipapython import session_storage from ipapython.cookie import Cookie -from ipapython.dnsutil import DNSName +from ipapython.dnsutil import DNSName, query_srv from ipalib.text import _ from ipalib.util import create_https_connection from ipalib.krb_utils import KRB5KDC_ERR_S_PRINCIPAL_UNKNOWN, KRB5KRB_AP_ERR_TKT_EXPIRED, \ @@ -878,7 +877,7 @@ class RPCClient(Connectible): name = '_ldap._tcp.%s.' % self.env.domain try: - answers = resolver.query(name, rdatatype.SRV) + answers = query_srv(name) except DNSException: answers = [] @@ -886,17 +885,11 @@ class RPCClient(Connectible): server = str(answer.target).rstrip(".") servers.append('https://%s%s' % (ipautil.format_netloc(server), path)) - servers = list(set(servers)) - # the list/set conversion won't preserve order so stick in the - # local config file version here. - cfg_server = rpc_uri - if cfg_server in servers: - # make sure the configured master server is there just once and - # it is the first one - servers.remove(cfg_server) - servers.insert(0, cfg_server) - else: - servers.insert(0, cfg_server) + # make sure the configured master server is there just once and + # it is the first one. + if rpc_uri in servers: + servers.remove(rpc_uri) + servers.insert(0, rpc_uri) return servers diff --git a/ipalib/util.py b/ipalib/util.py index af583373d..dd83f1f46 100644 --- a/ipalib/util.py +++ b/ipalib/util.py @@ -981,14 +981,13 @@ def detect_dns_zone_realm_type(api, domain): try: # The presence of this record is enough, return foreign in such case - result = resolver.query(ad_specific_record_name, rdatatype.SRV) + resolver.query(ad_specific_record_name, rdatatype.SRV) + except DNSException: + # If we could not detect type with certainty, return unknown + return 'unknown' + else: return 'foreign' - except DNSException: - pass - - # If we could not detect type with certainity, return unknown - return 'unknown' def has_managed_topology(api): domainlevel = api.Command['domainlevel_get']().get('result', DOMAIN_LEVEL_0) diff --git a/ipapython/config.py b/ipapython/config.py index c3360779f..f701122bd 100644 --- a/ipapython/config.py +++ b/ipapython/config.py @@ -26,7 +26,6 @@ from copy import copy import socket import functools -from dns import resolver, rdatatype from dns.exception import DNSException import dns.name # pylint: disable=import-error @@ -36,6 +35,7 @@ from six.moves.urllib.parse import urlsplit from ipaplatform.paths import paths from ipapython.dn import DN +from ipapython.dnsutil import query_srv from ipapython.ipautil import CheckedIPAddress, CheckedIPAddressLoopback @@ -210,7 +210,7 @@ def __discover_config(discover_server = True): name = "_ldap._tcp." + domain try: - servers = resolver.query(name, rdatatype.SRV) + servers = query_srv(name) except DNSException: # try cycling on domain components of FQDN try: @@ -225,7 +225,7 @@ def __discover_config(discover_server = True): return False name = "_ldap._tcp.%s" % domain try: - servers = resolver.query(name, rdatatype.SRV) + servers = query_srv(name) break except DNSException: pass @@ -236,7 +236,7 @@ def __discover_config(discover_server = True): if not servers: name = "_ldap._tcp.%s." % config.default_domain try: - servers = resolver.query(name, rdatatype.SRV) + servers = query_srv(name) except DNSException: pass diff --git a/ipapython/dnsutil.py b/ipapython/dnsutil.py index b40302d0e..6157183a0 100644 --- a/ipapython/dnsutil.py +++ b/ipapython/dnsutil.py @@ -17,12 +17,17 @@ # along with this program. If not, see . # +import copy import logging +import operator +import random import dns.name import dns.exception import dns.resolver -import copy +import dns.rdataclass +import dns.rdatatype + import six @@ -373,3 +378,88 @@ def check_zone_overlap(zone, raise_on_error=True): if ns: msg += u" and is handled by server(s): {0}".format(', '.join(ns)) raise ValueError(msg) + + +def _mix_weight(records): + """Weighted population sorting for records with same priority + """ + # trivial case + if len(records) <= 1: + return records + + # Optimization for common case: If all weights are the same (e.g. 0), + # just shuffle the records, which is about four times faster. + if all(rr.weight == records[0].weight for rr in records): + random.shuffle(records) + return records + + noweight = 0.01 # give records with 0 weight a small chance + result = [] + records = set(records) + while len(records) > 1: + # Compute the sum of the weights of those RRs. Then choose a + # uniform random number between 0 and the sum computed (inclusive). + urn = random.uniform(0, sum(rr.weight or noweight for rr in records)) + # Select the RR whose running sum value is the first in the selected + # order which is greater than or equal to the random number selected. + acc = 0. + for rr in records.copy(): + acc += rr.weight or noweight + if acc >= urn: + records.remove(rr) + result.append(rr) + if records: + result.append(records.pop()) + return result + + +def sort_prio_weight(records): + """RFC 2782 sorting algorithm for SRV and URI records + + RFC 2782 defines a sorting algorithms for SRV records, that is also used + for sorting URI records. Records are sorted by priority and than randomly + shuffled according to weight. + + This implementation also removes duplicate entries. + """ + # order records by priority + records = sorted(records, key=operator.attrgetter("priority")) + + # remove duplicate entries + uniquerecords = [] + seen = set() + for rr in records: + # A SRV record has target and port, URI just has target. + target = (rr.target, getattr(rr, "port", None)) + if target not in seen: + uniquerecords.append(rr) + seen.add(target) + + # weighted randomization of entries with same priority + result = [] + sameprio = [] + for rr in uniquerecords: + # add all items with same priority in a bucket + if not sameprio or sameprio[0].priority == rr.priority: + sameprio.append(rr) + else: + # got different priority, shuffle bucket + result.extend(_mix_weight(sameprio)) + # start a new priority list + sameprio = [rr] + # add last batch of records with same priority + if sameprio: + result.extend(_mix_weight(sameprio)) + return result + + +def query_srv(qname, resolver=None, **kwargs): + """Query SRV records and sort reply according to RFC 2782 + + :param qname: query name, _service._proto.domain. + :return: list of dns.rdtypes.IN.SRV.SRV instances + """ + if resolver is None: + resolver = dns.resolver + answer = resolver.query(qname, rdtype=dns.rdatatype.SRV, **kwargs) + return sort_prio_weight(answer) diff --git a/ipaserver/dcerpc.py b/ipaserver/dcerpc.py index 2e5a7bf6c..97a7945c3 100644 --- a/ipaserver/dcerpc.py +++ b/ipaserver/dcerpc.py @@ -32,6 +32,7 @@ from ipalib import api, _ from ipalib import errors from ipapython import ipautil from ipapython.dn import DN +from ipapython.dnsutil import query_srv from ipapython.ipaldap import ldap_initialize from ipaserver.install import installutils from ipaserver.dcerpc_common import (TRUST_BIDIRECTIONAL, @@ -55,7 +56,6 @@ import samba import ldap as _ldap from ipapython import ipaldap from ipapython.dnsutil import DNSName -from dns import resolver, rdatatype from dns.exception import DNSException import pysss_nss_idmap import pysss @@ -799,7 +799,7 @@ class DomainValidator(object): gc_name = '_gc._tcp.%s.' % info['dns_domain'] try: - answers = resolver.query(gc_name, rdatatype.SRV) + answers = query_srv(gc_name) except DNSException as e: answers = [] diff --git a/ipatests/test_ipapython/test_dnsutil.py b/ipatests/test_ipapython/test_dnsutil.py new file mode 100644 index 000000000..36adb077c --- /dev/null +++ b/ipatests/test_ipapython/test_dnsutil.py @@ -0,0 +1,106 @@ +# +# Copyright (C) 2018 FreeIPA Contributors. See COPYING for license +# +import dns.name +import dns.rdataclass +import dns.rdatatype +from dns.rdtypes.IN.SRV import SRV +from dns.rdtypes.ANY.URI import URI + +from ipapython import dnsutil + +import pytest + + +def mksrv(priority, weight, port, target): + return SRV( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.SRV, + priority=priority, + weight=weight, + port=port, + target=dns.name.from_text(target) + ) + + +def mkuri(priority, weight, target): + return URI( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.URI, + priority=priority, + weight=weight, + target=target + ) + + +class TestSortSRV(object): + def test_empty(self): + assert dnsutil.sort_prio_weight([]) == [] + + def test_one(self): + h1 = mksrv(1, 0, 443, u"host1") + assert dnsutil.sort_prio_weight([h1]) == [h1] + + h2 = mksrv(10, 5, 443, u"host2") + assert dnsutil.sort_prio_weight([h2]) == [h2] + + def test_prio(self): + h1 = mksrv(1, 0, 443, u"host1") + h2 = mksrv(2, 0, 443, u"host2") + h3 = mksrv(3, 0, 443, u"host3") + assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3] + assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3] + assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2] + + h380 = mksrv(4, 0, 80, u"host3") + assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380] + + hs = mksrv(-1, 0, 443, u"special") + assert dnsutil.sort_prio_weight([h1, h2, hs]) == [hs, h1, h2] + + def assert_permutations(self, answers, permutations): + seen = set() + for _unused in range(1000): + result = tuple(dnsutil.sort_prio_weight(answers)) + assert result in permutations + seen.add(result) + if seen == permutations: + break + else: + pytest.fail("sorting didn't exhaust all permutations.") + + def test_sameprio(self): + h1 = mksrv(1, 0, 443, u"host1") + h2 = mksrv(1, 0, 443, u"host2") + permutations = { + (h1, h2), + (h2, h1), + } + self.assert_permutations([h1, h2], permutations) + + def test_weight(self): + h1 = mksrv(1, 0, 443, u"host1") + h2_w15 = mksrv(2, 15, 443, u"host2") + h3_w10 = mksrv(2, 10, 443, u"host3") + + permutations = { + (h1, h2_w15, h3_w10), + (h1, h3_w10, h2_w15), + } + self.assert_permutations([h1, h2_w15, h3_w10], permutations) + + def test_large(self): + records = tuple( + mksrv(1, i, 443, "host{}".format(i)) for i in range(1000) + ) + assert len(dnsutil.sort_prio_weight(records)) == len(records) + + +class TestSortURI(object): + def test_prio(self): + h1 = mkuri(1, 0, u"https://host1/api") + h2 = mkuri(2, 0, u"https://host2/api") + h3 = mkuri(3, 0, u"https://host3/api") + assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3] + assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3] + assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]