Sort and shuffle SRV record by priority and weight

On multiple occasions, SRV query answers were not properly sorted by
priority. Records with same priority weren't randomized and shuffled.
This caused FreeIPA to contact the same remote peer instead of
distributing the load across all available servers.

Two new helper functions now take care of SRV queries. sort_prio_weight()
sorts SRV and URI records. query_srv() combines SRV lookup with
sort_prio_weight().

Fixes: https://pagure.io/freeipa/issue/7475
Signed-off-by: Christian Heimes <cheimes@redhat.com>
Reviewed-By: Rob Crittenden <rcritten@redhat.com>
This commit is contained in:
Christian Heimes
2018-06-15 17:03:29 +02:00
parent eda831dba1
commit f90e137a17
7 changed files with 217 additions and 30 deletions

View File

@@ -20,7 +20,6 @@
from __future__ import absolute_import
import logging
import operator
import socket
import six
@@ -28,6 +27,7 @@ import six
from dns import resolver, rdatatype
from dns.exception import DNSException
from ipalib import errors
from ipapython.dnsutil import query_srv
from ipapython import ipaldap
from ipaplatform.paths import paths
from ipapython.ipautil import valid_ip, realm_to_suffix
@@ -498,8 +498,7 @@ class IPADiscovery(object):
logger.debug("Search DNS for SRV record of %s", qname)
try:
answers = resolver.query(qname, rdatatype.SRV)
answers = sorted(answers, key=operator.attrgetter('priority'))
answers = query_srv(qname)
except DNSException as e:
logger.debug("DNS record not found: %s", e.__class__.__name__)
answers = []

View File

@@ -45,7 +45,6 @@ import gzip
from cryptography import x509 as crypto_x509
import gssapi
from dns import resolver, rdatatype
from dns.exception import DNSException
from ssl import SSLError
import six
@@ -61,7 +60,7 @@ from ipalib.x509 import Encoding as x509_Encoding
from ipapython import ipautil
from ipapython import session_storage
from ipapython.cookie import Cookie
from ipapython.dnsutil import DNSName
from ipapython.dnsutil import DNSName, query_srv
from ipalib.text import _
from ipalib.util import create_https_connection
from ipalib.krb_utils import KRB5KDC_ERR_S_PRINCIPAL_UNKNOWN, KRB5KRB_AP_ERR_TKT_EXPIRED, \
@@ -878,7 +877,7 @@ class RPCClient(Connectible):
name = '_ldap._tcp.%s.' % self.env.domain
try:
answers = resolver.query(name, rdatatype.SRV)
answers = query_srv(name)
except DNSException:
answers = []
@@ -886,17 +885,11 @@ class RPCClient(Connectible):
server = str(answer.target).rstrip(".")
servers.append('https://%s%s' % (ipautil.format_netloc(server), path))
servers = list(set(servers))
# the list/set conversion won't preserve order so stick in the
# local config file version here.
cfg_server = rpc_uri
if cfg_server in servers:
# make sure the configured master server is there just once and
# it is the first one
servers.remove(cfg_server)
servers.insert(0, cfg_server)
else:
servers.insert(0, cfg_server)
# it is the first one.
if rpc_uri in servers:
servers.remove(rpc_uri)
servers.insert(0, rpc_uri)
return servers

View File

@@ -981,14 +981,13 @@ def detect_dns_zone_realm_type(api, domain):
try:
# The presence of this record is enough, return foreign in such case
result = resolver.query(ad_specific_record_name, rdatatype.SRV)
resolver.query(ad_specific_record_name, rdatatype.SRV)
except DNSException:
# If we could not detect type with certainty, return unknown
return 'unknown'
else:
return 'foreign'
except DNSException:
pass
# If we could not detect type with certainity, return unknown
return 'unknown'
def has_managed_topology(api):
domainlevel = api.Command['domainlevel_get']().get('result', DOMAIN_LEVEL_0)

View File

@@ -26,7 +26,6 @@ from copy import copy
import socket
import functools
from dns import resolver, rdatatype
from dns.exception import DNSException
import dns.name
# pylint: disable=import-error
@@ -36,6 +35,7 @@ from six.moves.urllib.parse import urlsplit
from ipaplatform.paths import paths
from ipapython.dn import DN
from ipapython.dnsutil import query_srv
from ipapython.ipautil import CheckedIPAddress, CheckedIPAddressLoopback
@@ -210,7 +210,7 @@ def __discover_config(discover_server = True):
name = "_ldap._tcp." + domain
try:
servers = resolver.query(name, rdatatype.SRV)
servers = query_srv(name)
except DNSException:
# try cycling on domain components of FQDN
try:
@@ -225,7 +225,7 @@ def __discover_config(discover_server = True):
return False
name = "_ldap._tcp.%s" % domain
try:
servers = resolver.query(name, rdatatype.SRV)
servers = query_srv(name)
break
except DNSException:
pass
@@ -236,7 +236,7 @@ def __discover_config(discover_server = True):
if not servers:
name = "_ldap._tcp.%s." % config.default_domain
try:
servers = resolver.query(name, rdatatype.SRV)
servers = query_srv(name)
except DNSException:
pass

View File

@@ -17,12 +17,17 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import copy
import logging
import operator
import random
import dns.name
import dns.exception
import dns.resolver
import copy
import dns.rdataclass
import dns.rdatatype
import six
@@ -373,3 +378,88 @@ def check_zone_overlap(zone, raise_on_error=True):
if ns:
msg += u" and is handled by server(s): {0}".format(', '.join(ns))
raise ValueError(msg)
def _mix_weight(records):
"""Weighted population sorting for records with same priority
"""
# trivial case
if len(records) <= 1:
return records
# Optimization for common case: If all weights are the same (e.g. 0),
# just shuffle the records, which is about four times faster.
if all(rr.weight == records[0].weight for rr in records):
random.shuffle(records)
return records
noweight = 0.01 # give records with 0 weight a small chance
result = []
records = set(records)
while len(records) > 1:
# Compute the sum of the weights of those RRs. Then choose a
# uniform random number between 0 and the sum computed (inclusive).
urn = random.uniform(0, sum(rr.weight or noweight for rr in records))
# Select the RR whose running sum value is the first in the selected
# order which is greater than or equal to the random number selected.
acc = 0.
for rr in records.copy():
acc += rr.weight or noweight
if acc >= urn:
records.remove(rr)
result.append(rr)
if records:
result.append(records.pop())
return result
def sort_prio_weight(records):
"""RFC 2782 sorting algorithm for SRV and URI records
RFC 2782 defines a sorting algorithms for SRV records, that is also used
for sorting URI records. Records are sorted by priority and than randomly
shuffled according to weight.
This implementation also removes duplicate entries.
"""
# order records by priority
records = sorted(records, key=operator.attrgetter("priority"))
# remove duplicate entries
uniquerecords = []
seen = set()
for rr in records:
# A SRV record has target and port, URI just has target.
target = (rr.target, getattr(rr, "port", None))
if target not in seen:
uniquerecords.append(rr)
seen.add(target)
# weighted randomization of entries with same priority
result = []
sameprio = []
for rr in uniquerecords:
# add all items with same priority in a bucket
if not sameprio or sameprio[0].priority == rr.priority:
sameprio.append(rr)
else:
# got different priority, shuffle bucket
result.extend(_mix_weight(sameprio))
# start a new priority list
sameprio = [rr]
# add last batch of records with same priority
if sameprio:
result.extend(_mix_weight(sameprio))
return result
def query_srv(qname, resolver=None, **kwargs):
"""Query SRV records and sort reply according to RFC 2782
:param qname: query name, _service._proto.domain.
:return: list of dns.rdtypes.IN.SRV.SRV instances
"""
if resolver is None:
resolver = dns.resolver
answer = resolver.query(qname, rdtype=dns.rdatatype.SRV, **kwargs)
return sort_prio_weight(answer)

View File

@@ -32,6 +32,7 @@ from ipalib import api, _
from ipalib import errors
from ipapython import ipautil
from ipapython.dn import DN
from ipapython.dnsutil import query_srv
from ipapython.ipaldap import ldap_initialize
from ipaserver.install import installutils
from ipaserver.dcerpc_common import (TRUST_BIDIRECTIONAL,
@@ -55,7 +56,6 @@ import samba
import ldap as _ldap
from ipapython import ipaldap
from ipapython.dnsutil import DNSName
from dns import resolver, rdatatype
from dns.exception import DNSException
import pysss_nss_idmap
import pysss
@@ -799,7 +799,7 @@ class DomainValidator(object):
gc_name = '_gc._tcp.%s.' % info['dns_domain']
try:
answers = resolver.query(gc_name, rdatatype.SRV)
answers = query_srv(gc_name)
except DNSException as e:
answers = []

View File

@@ -0,0 +1,106 @@
#
# Copyright (C) 2018 FreeIPA Contributors. See COPYING for license
#
import dns.name
import dns.rdataclass
import dns.rdatatype
from dns.rdtypes.IN.SRV import SRV
from dns.rdtypes.ANY.URI import URI
from ipapython import dnsutil
import pytest
def mksrv(priority, weight, port, target):
return SRV(
rdclass=dns.rdataclass.IN,
rdtype=dns.rdatatype.SRV,
priority=priority,
weight=weight,
port=port,
target=dns.name.from_text(target)
)
def mkuri(priority, weight, target):
return URI(
rdclass=dns.rdataclass.IN,
rdtype=dns.rdatatype.URI,
priority=priority,
weight=weight,
target=target
)
class TestSortSRV(object):
def test_empty(self):
assert dnsutil.sort_prio_weight([]) == []
def test_one(self):
h1 = mksrv(1, 0, 443, u"host1")
assert dnsutil.sort_prio_weight([h1]) == [h1]
h2 = mksrv(10, 5, 443, u"host2")
assert dnsutil.sort_prio_weight([h2]) == [h2]
def test_prio(self):
h1 = mksrv(1, 0, 443, u"host1")
h2 = mksrv(2, 0, 443, u"host2")
h3 = mksrv(3, 0, 443, u"host3")
assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]
h380 = mksrv(4, 0, 80, u"host3")
assert dnsutil.sort_prio_weight([h1, h3, h380]) == [h1, h3, h380]
hs = mksrv(-1, 0, 443, u"special")
assert dnsutil.sort_prio_weight([h1, h2, hs]) == [hs, h1, h2]
def assert_permutations(self, answers, permutations):
seen = set()
for _unused in range(1000):
result = tuple(dnsutil.sort_prio_weight(answers))
assert result in permutations
seen.add(result)
if seen == permutations:
break
else:
pytest.fail("sorting didn't exhaust all permutations.")
def test_sameprio(self):
h1 = mksrv(1, 0, 443, u"host1")
h2 = mksrv(1, 0, 443, u"host2")
permutations = {
(h1, h2),
(h2, h1),
}
self.assert_permutations([h1, h2], permutations)
def test_weight(self):
h1 = mksrv(1, 0, 443, u"host1")
h2_w15 = mksrv(2, 15, 443, u"host2")
h3_w10 = mksrv(2, 10, 443, u"host3")
permutations = {
(h1, h2_w15, h3_w10),
(h1, h3_w10, h2_w15),
}
self.assert_permutations([h1, h2_w15, h3_w10], permutations)
def test_large(self):
records = tuple(
mksrv(1, i, 443, "host{}".format(i)) for i in range(1000)
)
assert len(dnsutil.sort_prio_weight(records)) == len(records)
class TestSortURI(object):
def test_prio(self):
h1 = mkuri(1, 0, u"https://host1/api")
h2 = mkuri(2, 0, u"https://host2/api")
h3 = mkuri(3, 0, u"https://host3/api")
assert dnsutil.sort_prio_weight([h3, h2, h1]) == [h1, h2, h3]
assert dnsutil.sort_prio_weight([h3, h3, h3]) == [h3]
assert dnsutil.sort_prio_weight([h2, h2, h1, h1]) == [h1, h2]