diff --git a/ipapython/dnsutil.py b/ipapython/dnsutil.py index aca506120..16549c8f6 100644 --- a/ipapython/dnsutil.py +++ b/ipapython/dnsutil.py @@ -24,7 +24,7 @@ import copy import six -from ipapython.ipautil import CheckedIPAddress +from ipapython.ipautil import UnsafeIPAddress from ipapython.ipa_log_manager import root_logger if six.PY3: @@ -323,18 +323,12 @@ def resolve_rrsets(fqdn, rdtypes): def resolve_ip_addresses(fqdn): """Get IP addresses from DNS A/AAAA records for given host (using DNS). :returns: - list of IP addresses as CheckedIPAddress objects + list of IP addresses as UnsafeIPAddress objects """ rrsets = resolve_rrsets(fqdn, ['A', 'AAAA']) ip_addresses = set() for rrset in rrsets: - ip_addresses.update({CheckedIPAddress(ip, # accept whatever is in DNS - parse_netmask=False, - allow_network=True, - allow_loopback=True, - allow_broadcast=True, - allow_multicast=True) - for ip in rrset}) + ip_addresses.update({UnsafeIPAddress(ip) for ip in rrset}) return ip_addresses diff --git a/ipapython/ipautil.py b/ipapython/ipautil.py index 8506bf2d5..763a99c11 100644 --- a/ipapython/ipautil.py +++ b/ipapython/ipautil.py @@ -74,102 +74,133 @@ def get_domain_name(): return domain_name -class CheckedIPAddress(netaddr.IPAddress): + +class UnsafeIPAddress(netaddr.IPAddress): + """Any valid IP address with or without netmask.""" # Use inet_pton() rather than inet_aton() for IP address parsing. We # will use the same function in IPv4/IPv6 conversions + be stricter # and don't allow IP addresses such as '1.1.1' in the same time netaddr_ip_flags = netaddr.INET_PTON + def __init__(self, addr): + if isinstance(addr, UnsafeIPAddress): + self._net = addr._net + super(UnsafeIPAddress, self).__init__(addr, + flags=self.netaddr_ip_flags) + return + + elif isinstance(addr, netaddr.IPAddress): + self._net = None # no information about netmask + super(UnsafeIPAddress, self).__init__(addr, + flags=self.netaddr_ip_flags) + return + + elif isinstance(addr, netaddr.IPNetwork): + self._net = addr + super(UnsafeIPAddress, self).__init__(self._net.ip, + flags=self.netaddr_ip_flags) + return + + # option of last resort: parse it as string + self._net = None + addr = str(addr) + try: + try: + addr = netaddr.IPAddress(addr, flags=self.netaddr_ip_flags) + except netaddr.AddrFormatError: + # netaddr.IPAddress doesn't handle zone indices in textual + # IPv6 addresses. Try removing zone index and parse the + # address again. + addr, sep, foo = addr.partition('%') + if sep != '%': + raise + addr = netaddr.IPAddress(addr, flags=self.netaddr_ip_flags) + if addr.version != 6: + raise + except ValueError: + self._net = netaddr.IPNetwork(addr, flags=self.netaddr_ip_flags) + addr = self._net.ip + super(UnsafeIPAddress, self).__init__(addr, + flags=self.netaddr_ip_flags) + + +class CheckedIPAddress(UnsafeIPAddress): + """IPv4 or IPv6 address with additional constraints. + + Reserved or link-local addresses are never accepted. + """ def __init__(self, addr, match_local=False, parse_netmask=True, allow_network=False, allow_loopback=False, allow_broadcast=False, allow_multicast=False): + + super(CheckedIPAddress, self).__init__(addr) if isinstance(addr, CheckedIPAddress): - super(CheckedIPAddress, self).__init__(addr, flags=self.netaddr_ip_flags) self.prefixlen = addr.prefixlen return - net = None - iface = None + if not parse_netmask and self._net: + raise ValueError( + "netmask and prefix length not allowed here: {}".format(addr)) - if isinstance(addr, netaddr.IPNetwork): - net = addr - addr = net.ip - elif isinstance(addr, netaddr.IPAddress): - pass - else: - try: - try: - addr = netaddr.IPAddress(str(addr), flags=self.netaddr_ip_flags) - except netaddr.AddrFormatError: - # netaddr.IPAddress doesn't handle zone indices in textual - # IPv6 addresses. Try removing zone index and parse the - # address again. - if not isinstance(addr, six.string_types): - raise - addr, sep, foo = addr.partition('%') - if sep != '%': - raise - addr = netaddr.IPAddress(str(addr), flags=self.netaddr_ip_flags) - if addr.version != 6: - raise - except ValueError: - net = netaddr.IPNetwork(str(addr), flags=self.netaddr_ip_flags) - if not parse_netmask: - raise ValueError("netmask and prefix length not allowed here") - addr = net.ip + if self.version not in (4, 6): + raise ValueError("unsupported IP version {}".format(self.version)) - if addr.version not in (4, 6): - raise ValueError("unsupported IP version") + if not allow_loopback and self.is_loopback(): + raise ValueError("cannot use loopback IP address {}".format(addr)) + if (not self.is_loopback() and self.is_reserved()) \ + or self in netaddr.ip.IPV4_6TO4: + raise ValueError( + "cannot use IANA reserved IP address {}".format(addr)) - if not allow_loopback and addr.is_loopback(): - raise ValueError("cannot use loopback IP address") - if (not addr.is_loopback() and addr.is_reserved()) \ - or addr in netaddr.ip.IPV4_6TO4: - raise ValueError("cannot use IANA reserved IP address") - - if addr.is_link_local(): - raise ValueError("cannot use link-local IP address") - if not allow_multicast and addr.is_multicast(): - raise ValueError("cannot use multicast IP address") + if self.is_link_local(): + raise ValueError( + "cannot use link-local IP address {}".format(addr)) + if not allow_multicast and self.is_multicast(): + raise ValueError("cannot use multicast IP address {}".format(addr)) if match_local: - if addr.version == 4: + if self.version == 4: family = netifaces.AF_INET - elif addr.version == 6: + elif self.version == 6: family = netifaces.AF_INET6 else: raise ValueError( - "Unsupported address family ({})".format(addr.version) + "Unsupported address family ({})".format(self.version) ) + iface = None for interface in netifaces.interfaces(): for ifdata in netifaces.ifaddresses(interface).get(family, []): ifnet = netaddr.IPNetwork('{addr}/{netmask}'.format( addr=ifdata['addr'], netmask=ifdata['netmask'] )) - if ifnet == net or (net is None and ifnet.ip == addr): - net = ifnet + if ifnet == self._net or ( + self._net is None and ifnet.ip == self): + self._net = ifnet iface = interface break if iface is None: - raise ValueError('No network interface matches the provided IP address and netmask') + raise ValueError('no network interface matches the IP address ' + 'and netmask {}'.format(addr)) - if net is None: - if addr.version == 4: - net = netaddr.IPNetwork(netaddr.cidr_abbrev_to_verbose(str(addr))) - elif addr.version == 6: - net = netaddr.IPNetwork(str(addr) + '/64') + if self._net is None: + if self.version == 4: + self._net = netaddr.IPNetwork( + netaddr.cidr_abbrev_to_verbose(str(self))) + elif self.version == 6: + self._net = netaddr.IPNetwork(str(self) + '/64') - if not allow_network and addr == net.network: - raise ValueError("cannot use IP network address") - if not allow_broadcast and addr.version == 4 and addr == net.broadcast: - raise ValueError("cannot use broadcast IP address") + if not allow_network and self == self._net.network: + raise ValueError("cannot use IP network address {}".format(addr)) + if not allow_broadcast and (self.version == 4 and + self == self._net.broadcast): + raise ValueError("cannot use broadcast IP address {}".format(addr)) + + self.prefixlen = self._net.prefixlen - super(CheckedIPAddress, self).__init__(addr, flags=self.netaddr_ip_flags) - self.prefixlen = net.prefixlen def valid_ip(addr): return netaddr.valid_ipv4(addr) or netaddr.valid_ipv6(addr) diff --git a/ipaserver/install/installutils.py b/ipaserver/install/installutils.py index a15571f92..25f48aed1 100644 --- a/ipaserver/install/installutils.py +++ b/ipaserver/install/installutils.py @@ -448,7 +448,7 @@ def create_keytab(path, principal): def resolve_ip_addresses_nss(fqdn): """Get list of IP addresses for given host (using NSS/getaddrinfo). :returns: - list of IP addresses as CheckedIPAddress objects + list of IP addresses as UnsafeIPAddress objects """ # make sure the name is fully qualified # so search path from resolv.conf does not apply @@ -468,13 +468,7 @@ def resolve_ip_addresses_nss(fqdn): ip_addresses = set() for ai in addrinfos: try: - ip = ipautil.CheckedIPAddress(ai[4][0], - parse_netmask=False, - # these are unreliable, disable them - allow_network=True, - allow_loopback=True, - allow_broadcast=True, - allow_multicast=True) + ip = ipautil.UnsafeIPAddress(ai[4][0]) except ValueError as ex: # getaddinfo may return link-local address other similar oddities # which are not accepted by CheckedIPAddress - skip these @@ -501,8 +495,7 @@ def get_host_name(no_host_dns): def get_server_ip_address(host_name, unattended, setup_dns, ip_addresses): hostaddr = resolve_ip_addresses_nss(host_name) if hostaddr.intersection( - {ipautil.CheckedIPAddress(ip, allow_loopback=True) - for ip in ['127.0.0.1', '::1']}): + {ipautil.UnsafeIPAddress(ip) for ip in ['127.0.0.1', '::1']}): print("The hostname resolves to the localhost address (127.0.0.1/::1)", file=sys.stderr) print("Please change your /etc/hosts file so that the hostname", file=sys.stderr) print("resolves to the ip address of your network interface.", file=sys.stderr)