cert-request: more specific errors in IP address validation

Update the IP address validation to raise different error messages
for:

- inability to reach IP address from a DNS name
- missing PTR records for IP address
- asymmetric PTR / forward records

If multiple scenarios apply, indicate the first error (from list
above).

The code should now be a bit easier to follow.  We first build dicts
of forward and reverse DNS relationships, keyed by IP address.  Then
we check that entries for each iPAddressName are present in both
dicts.  Finally we check for PTR-A/AAAA symmetry.

Update the tests to check that raised ValidationErrors indicate the
expected error.

Part of: https://pagure.io/freeipa/issue/7451

Reviewed-By: Florence Blanc-Renaud <flo@redhat.com>
This commit is contained in:
Fraser Tweedale 2019-02-21 13:54:42 +11:00 committed by Florence Blanc-Renaud
parent 474a2e6952
commit a65c12d042
2 changed files with 79 additions and 61 deletions

View File

@ -1115,39 +1115,63 @@ def _validate_san_ips(san_ipaddrs, san_dnsnames):
address. address.
""" """
san_ip_set = frozenset(unicode(ip) for ip in san_ipaddrs)
# Collect the IP addresses for each SAN dNSName # Build a dict of IPs that are reachable from the SAN dNSNames
san_dns_ips = set() reachable = {}
for name in san_dnsnames: for name in san_dnsnames:
san_dns_ips.update(_san_dnsname_ips(name, cname_depth=1)) _san_ip_update_reachable(reachable, name, cname_depth=1)
# Each SAN iPAddressName must appear in the addresses we just collected # Each iPAddressName must be reachable from a dNSName
unmatched_ips = set(unicode(ip) for ip in san_ipaddrs) - san_dns_ips unreachable_ips = san_ip_set - six.viewkeys(reachable)
if len(unmatched_ips) > 0: if len(unreachable_ips) > 0:
raise errors.ValidationError( raise errors.ValidationError(
name='csr', name='csr',
error=_( error=_(
"IP address in subjectAltName (%s) does not match any DNS name" "IP address in subjectAltName (%s) unreachable from DNS names"
) % ', '.join(unmatched_ips) ) % ', '.join(unreachable_ips)
)
# Collect PTR records for each IP address
ptrs_by_ip = {}
for ip in san_ipaddrs:
ptrs = _ip_ptr_records(unicode(ip))
if len(ptrs) > 0:
ptrs_by_ip[unicode(ip)] = set(s.rstrip('.') for s in ptrs)
# Each iPAddressName must have a corresponding PTR record.
missing_ptrs = san_ip_set - six.viewkeys(ptrs_by_ip)
if len(missing_ptrs) > 0:
raise errors.ValidationError(
name='csr',
error=_(
"IP address in subjectAltName (%s) does not have PTR record"
) % ', '.join(missing_ptrs)
)
# PTRs and forward records must form a loop
for ip, ptrs in ptrs_by_ip.items():
# PTR value must appear in the set of names that resolve to
# this IP address (via A/AAAA records)
if len(ptrs - reachable.get(ip, set())) > 0:
raise errors.ValidationError(
name='csr',
error=_(
"PTR record for SAN IP (%s) does not match A/AAAA records"
) % ip
) )
def _san_dnsname_ips(dnsname, cname_depth): def _san_ip_update_reachable(reachable, dnsname, cname_depth):
""" """
Resolve a DNS name to its IP address(es). Update dict of reachable IPs and the names that reach them.
The name is assumed to be fully qualified. :param reachable: the dict to update. Keys are IP addresses,
values are sets of DNS names.
Returns a set of IP addresses, managed by this IPA instance, :param dnsname: the DNS name to resolve
that correspond to the DNS name (from the subjectAltName). :param cname_depth: How many levels of CNAME indirection are permitted.
:param dnsname: The DNS name (text) for which to resolve the IP addresses
:param cname_depth: How many cnames are we allowed to follow?
:return: The set of IP addresses resolved from the DNS name
""" """
ips = set()
fqdn = dnsutil.DNSName(dnsname).make_absolute() fqdn = dnsutil.DNSName(dnsname).make_absolute()
zone = dnsutil.DNSName(resolver.zone_for_name(fqdn)) zone = dnsutil.DNSName(resolver.zone_for_name(fqdn))
name = fqdn.relativize(zone) name = fqdn.relativize(zone)
@ -1155,34 +1179,27 @@ def _san_dnsname_ips(dnsname, cname_depth):
result = api.Command['dnsrecord_show'](zone, name)['result'] result = api.Command['dnsrecord_show'](zone, name)['result']
except errors.NotFound as nf: except errors.NotFound as nf:
logger.debug("Skipping IPs for %s: %s", dnsname, nf) logger.debug("Skipping IPs for %s: %s", dnsname, nf)
return ips return # nothing to do
for ip in itertools.chain(result.get('arecord', ()), for ip in itertools.chain(result.get('arecord', ()),
result.get('aaaarecord', ())): result.get('aaaarecord', ())):
if _ip_rdns_ok(ip, fqdn): # add this forward relationship to the 'reachable' dict
ips.add(ip) names = reachable.get(ip, set())
names.add(dnsname.rstrip('.'))
reachable[ip] = names
if cname_depth > 0: if cname_depth > 0:
for cname in result.get('cnamerecord', []): for cname in result.get('cnamerecord', []):
if not cname.endswith('.'): if not cname.endswith('.'):
cname = u'%s.%s' % (cname, zone) cname = u'%s.%s' % (cname, zone)
ips.update(_san_dnsname_ips(cname, cname_depth=cname_depth - 1)) _san_ip_update_reachable(reachable, cname, cname_depth - 1)
return ips
def _ip_rdns_ok(ip, fqdn): def _ip_ptr_records(ip):
""" """
Check an IP address's reverse DNS record. Look up PTR record(s) for IP address.
Determines whether the IP address has a reverse DNS entry (managed :return: a ``set`` of IP addresses, possibly empty.
by this IPA instance) that points to the FQDN.
:param ip: The IP address to check
:param fqdn: The FQDN (A/AAAA record) to which the reverse record should
point
:return: True if the IP address's reverse DNS record checks out, False if
it does not
""" """
rname = dnsutil.DNSName(reversename.from_address(ip)) rname = dnsutil.DNSName(reversename.from_address(ip))
@ -1191,16 +1208,10 @@ def _ip_rdns_ok(ip, fqdn):
try: try:
result = api.Command['dnsrecord_show'](zone, name)['result'] result = api.Command['dnsrecord_show'](zone, name)['result']
except errors.NotFound: except errors.NotFound:
logger.debug("Skipping IP %s: reverse DNS record not found", ip) ptrs = set()
return False
# Require the PTR record to match the expected hostname
if any(ptr == fqdn.to_unicode() for ptr in result.get('ptrrecord', [])):
return True
else: else:
logger.debug("Skipping IP: %s: reverse DNS doesn't match FQDN %s", ptrs = set(result.get('ptrrecord', []))
ip, fqdn) return ptrs
return False
@register() @register()

View File

@ -219,6 +219,13 @@ def user_alice(request):
return user.make_fixture(request) return user.make_fixture(request)
# Patterns for ValidationError messages
PAT_FWD = "unreachable from DNS names"
PAT_REV = "does not have PTR record"
PAT_LOOP = "does not match A/AAAA records"
PAT_USER = "forbidden for user principals"
@pytest.mark.tier1 @pytest.mark.tier1
class TestIPAddressSANIssuance(XMLRPC_test): class TestIPAddressSANIssuance(XMLRPC_test):
""" """
@ -241,19 +248,19 @@ class TestIPAddressSANIssuance(XMLRPC_test):
host.run_command('cert_request', csr_ipv4_ipv6, principal=host_princ) host.run_command('cert_request', csr_ipv4_ipv6, principal=host_princ)
def test_failure_extra_ip(self, host, ipv4_a, ipv4_ptr, csr_extra_ipv4): def test_failure_extra_ip(self, host, ipv4_a, ipv4_ptr, csr_extra_ipv4):
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command( host.run_command(
'cert_request', csr_extra_ipv4, principal=host_princ) 'cert_request', csr_extra_ipv4, principal=host_princ)
def test_failure_no_dnsname(self, host, ipv4_a, ipv4_ptr, csr_no_dnsname): def test_failure_no_dnsname(self, host, ipv4_a, ipv4_ptr, csr_no_dnsname):
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command( host.run_command(
'cert_request', csr_no_dnsname, principal=host_princ) 'cert_request', csr_no_dnsname, principal=host_princ)
def test_failure_user_princ( def test_failure_user_princ(
self, host, ipv4_a, ipv4_ptr, csr_alice, user_alice): self, host, ipv4_a, ipv4_ptr, csr_alice, user_alice):
user_alice.ensure_exists() user_alice.ensure_exists()
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_USER):
host.run_command( host.run_command(
'cert_request', csr_alice, principal=user_alice.uid) 'cert_request', csr_alice, principal=user_alice.uid)
@ -267,7 +274,7 @@ class TestIPAddressSANMissingARecord(XMLRPC_test):
def test_issuance_ipv4( def test_issuance_ipv4(
self, host, ipv6_aaaa, ipv6_ptr, ipv4_ptr, csr_ipv4): self, host, ipv6_aaaa, ipv6_ptr, ipv4_ptr, csr_ipv4):
"""Issuing with IPv4 address fails.""" """Issuing with IPv4 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command('cert_request', csr_ipv4, principal=host_princ) host.run_command('cert_request', csr_ipv4, principal=host_princ)
def test_issuance_ipv6( def test_issuance_ipv6(
@ -278,7 +285,7 @@ class TestIPAddressSANMissingARecord(XMLRPC_test):
def test_issuance_ipv4_ipv6( def test_issuance_ipv4_ipv6(
self, host, ipv6_aaaa, ipv4_ptr, ipv6_ptr, csr_ipv4_ipv6): self, host, ipv6_aaaa, ipv4_ptr, ipv6_ptr, csr_ipv4_ipv6):
"""Issuing with IPv4 *and* IPv6 address fails.""" """Issuing with IPv4 *and* IPv6 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command( host.run_command(
'cert_request', csr_ipv4_ipv6, principal=host_princ) 'cert_request', csr_ipv4_ipv6, principal=host_princ)
@ -297,13 +304,13 @@ class TestIPAddressSANMissingAAAARecord(XMLRPC_test):
def test_issuance_ipv6( def test_issuance_ipv6(
self, host, ipv4_a, ipv6_ptr, ipv4_ptr, csr_ipv6): self, host, ipv4_a, ipv6_ptr, ipv4_ptr, csr_ipv6):
"""Issuing with IPv6 address fails.""" """Issuing with IPv6 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command('cert_request', csr_ipv6, principal=host_princ) host.run_command('cert_request', csr_ipv6, principal=host_princ)
def test_issuance_ipv4_ipv6( def test_issuance_ipv4_ipv6(
self, host, ipv4_a, ipv4_ptr, ipv6_ptr, csr_ipv4_ipv6): self, host, ipv4_a, ipv4_ptr, ipv6_ptr, csr_ipv4_ipv6):
"""Issuing with IPv4 *and* IPv6 address fails.""" """Issuing with IPv4 *and* IPv6 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command( host.run_command(
'cert_request', csr_ipv4_ipv6, principal=host_princ) 'cert_request', csr_ipv4_ipv6, principal=host_princ)
@ -317,7 +324,7 @@ class TestIPAddressSANMissingIPv4Ptr(XMLRPC_test):
def test_issuance_ipv4( def test_issuance_ipv4(
self, host, ipv4_a, ipv6_aaaa, ipv6_ptr, csr_ipv4): self, host, ipv4_a, ipv6_aaaa, ipv6_ptr, csr_ipv4):
"""Issuing with IPv4 address fails.""" """Issuing with IPv4 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_REV):
host.run_command('cert_request', csr_ipv4, principal=host_princ) host.run_command('cert_request', csr_ipv4, principal=host_princ)
def test_issuance_ipv6( def test_issuance_ipv6(
@ -328,7 +335,7 @@ class TestIPAddressSANMissingIPv4Ptr(XMLRPC_test):
def test_issuance_ipv4_ipv6( def test_issuance_ipv4_ipv6(
self, host, ipv4_a, ipv6_aaaa, ipv6_ptr, csr_ipv4_ipv6): self, host, ipv4_a, ipv6_aaaa, ipv6_ptr, csr_ipv4_ipv6):
"""Issuing with IPv4 *and* IPv6 address fails.""" """Issuing with IPv4 *and* IPv6 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_REV):
host.run_command( host.run_command(
'cert_request', csr_ipv4_ipv6, principal=host_princ) 'cert_request', csr_ipv4_ipv6, principal=host_princ)
@ -347,13 +354,13 @@ class TestIPAddressSANMissingIPv6Ptr(XMLRPC_test):
def test_issuance_ipv6( def test_issuance_ipv6(
self, host, ipv4_a, ipv6_aaaa, ipv4_ptr, csr_ipv6): self, host, ipv4_a, ipv6_aaaa, ipv4_ptr, csr_ipv6):
"""Issuing with IPv6 address fails.""" """Issuing with IPv6 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_REV):
host.run_command('cert_request', csr_ipv6, principal=host_princ) host.run_command('cert_request', csr_ipv6, principal=host_princ)
def test_issuance_ipv4_ipv6( def test_issuance_ipv4_ipv6(
self, host, ipv4_a, ipv6_aaaa, ipv4_ptr, csr_ipv4_ipv6): self, host, ipv4_a, ipv6_aaaa, ipv4_ptr, csr_ipv4_ipv6):
"""Issuing with IPv4 *and* IPv6 address fails.""" """Issuing with IPv4 *and* IPv6 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_REV):
host.run_command( host.run_command(
'cert_request', csr_ipv4_ipv6, principal=host_princ) 'cert_request', csr_ipv4_ipv6, principal=host_princ)
@ -374,13 +381,13 @@ class TestIPAddressSANOtherForwardRecords(XMLRPC_test):
def test_issuance_ipv4( def test_issuance_ipv4(
self, host, other_forward_records, ipv4_ptr, ipv6_ptr, csr_ipv4): self, host, other_forward_records, ipv4_ptr, ipv6_ptr, csr_ipv4):
"""Issuing with IPv4 address fails.""" """Issuing with IPv4 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command('cert_request', csr_ipv4, principal=host_princ) host.run_command('cert_request', csr_ipv4, principal=host_princ)
def test_issuance_ipv6( def test_issuance_ipv6(
self, host, other_forward_records, ipv4_ptr, ipv6_ptr, csr_ipv6): self, host, other_forward_records, ipv4_ptr, ipv6_ptr, csr_ipv6):
"""Issuing with IPv6 address fails.""" """Issuing with IPv6 address fails."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command('cert_request', csr_ipv6, principal=host_princ) host.run_command('cert_request', csr_ipv6, principal=host_princ)
@ -403,7 +410,7 @@ class TestIPAddressPTRLoopback(XMLRPC_test):
def test_failure(self, host, ipv4_a, ipv4_ptr_other, csr_iptest_other): def test_failure(self, host, ipv4_a, ipv4_ptr_other, csr_iptest_other):
"""The A and PTR records are not symmetric.""" """The A and PTR records are not symmetric."""
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_LOOP):
host.run_command( host.run_command(
'cert_request', csr_iptest_other, principal=host_princ) 'cert_request', csr_iptest_other, principal=host_princ)
@ -441,5 +448,5 @@ class TestIPAddressCNAME(XMLRPC_test):
host.run_command('cert_request', csr_cname1, principal=host_princ) host.run_command('cert_request', csr_cname1, principal=host_princ)
def test_two_levels(self, host, csr_cname2): def test_two_levels(self, host, csr_cname2):
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError, match=PAT_FWD):
host.run_command('cert_request', csr_cname2, principal=host_princ) host.run_command('cert_request', csr_cname2, principal=host_princ)