Improve address family handling in sockets

Many functions use low-level socket interface for connection or
various checks. However, most of the time we don't respect
automatic address family detection but rather try to force our
values. This may cause either redundat connection tries when an
address family is disabled on system tries or even crashes
when socket exceptions are not properly caught.

Instead of forcing address families to socket, rather use
getaddrinfo interface to automatically retrieve a list of all
relevant address families and other connection settings when
connecting to remote/local machine or binding to a local port.
Now, we will also fill correctly all connection parameters like
flowinfo and scopeid for IPv6 connections which will for example
prevent issues with scoped IPv6 addresses.

bind_port_responder function was changed to at first try to bind
to IPv6 wildcard address before IPv4 as IPv6 socket is able to
accept both IPv4 and IPv6 connections (unlike IPv4 socket).

nsslib connection was refactored to use nss.io.AddrInfo class to
get all the available connections. Socket is now not created by
default in NSSConnection class initializer, but rather when the
actual connection is being made, becase we do not an address family
where connection is successful.

https://fedorahosted.org/freeipa/ticket/2913
https://fedorahosted.org/freeipa/ticket/2695
This commit is contained in:
Martin Kosek 2012-07-03 16:49:10 +02:00
parent 5c54dd5b03
commit 4879c68d68
6 changed files with 165 additions and 167 deletions

View File

@ -236,15 +236,15 @@ class PortResponder(threading.Thread):
self._stop_request = True
def port_check(host, port_list):
ip = installutils.resolve_host(host)
if not ip:
raise RuntimeError("Port check failed! Unable to resolve host name '%s'" % host)
ports_failed = []
ports_udp_warning = [] # conncheck could not verify that port is open
for port in port_list:
if ipautil.host_port_open(host, port.port, port.port_type, socket_timeout=CONNECT_TIMEOUT):
try:
port_open = ipautil.host_port_open(host, port.port,
port.port_type, socket_timeout=CONNECT_TIMEOUT)
except socket.gaierror:
raise RuntimeError("Port check failed! Unable to resolve host name '%s'" % host)
if port_open:
result = "OK"
else:
if port.port_type == socket.SOCK_DGRAM:

View File

@ -979,19 +979,36 @@ def configure_ssh(fstore, ssh_dir, options):
def resolve_ipaddress(server):
""" Connect to the server's LDAP port in order to determine what ip
address this machine uses as "public" ip (relative to the server).
Returns a tuple with the IP address and address family when
connection was successful. Socket error is raised otherwise.
"""
last_socket_error = None
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
try:
s.connect((server, 389))
addr, port = s.getsockname()
except socket.gaierror:
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP)
s.connect((server, 389))
addr, port, foo, bar = s.getsockname()
s.close()
for res in socket.getaddrinfo(server, 389, socket.AF_UNSPEC,
socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
try:
s = socket.socket(af, socktype, proto)
except socket.error, e:
last_socket_error = e
s = None
continue
return addr
try:
s.connect(sa)
sockname = s.getsockname()
# For both IPv4 and IPv6 own IP address is always the first item
return (sockname[0], af)
except socket.error, e:
last_socket_error = e
finally:
if s:
s.close()
if last_socket_error is not None:
raise last_socket_error # pylint: disable=E0702
def do_nsupdate(update_txt):
root_logger.debug("Writing nsupdate commands to %s:", UPDATE_FILE)
@ -1037,7 +1054,13 @@ CCACHE_FILE = "/etc/ipa/.dns_ccache"
def update_dns(server, hostname):
ip = resolve_ipaddress(server)
try:
(ip, af) = resolve_ipaddress(server)
except socket.gaierror, e:
root_logger.debug("update_dns: could not connect to server: %s", e)
root_logger.error("Cannot update DNS records! "
"Failed to connect to server '%s'.", server)
return
sub_dict = dict(HOSTNAME=hostname,
IPADDRESS=ip,
@ -1045,9 +1068,9 @@ def update_dns(server, hostname):
ZONE='.'.join(hostname.split('.')[1:])
)
if len(ip.split('.')) == 4:
if af == socket.AF_INET:
template = UPDATE_TEMPLATE_A
elif ':' in ip:
elif af == socket.AF_INET6:
template = UPDATE_TEMPLATE_AAAA
else:
root_logger.info("Failed to determine this machine's ip address.")

View File

@ -48,6 +48,7 @@ from dns.exception import DNSException
from ipapython.ipa_log_manager import *
from ipapython import ipavalidate
from ipapython import config
try:
from subprocess import CalledProcessError
except ImportError:
@ -672,72 +673,103 @@ def get_gsserror(e):
def host_port_open(host, port, socket_type=socket.SOCK_STREAM, socket_timeout=None):
families = (socket.AF_INET, socket.AF_INET6)
success = False
for family in families:
for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket_type):
af, socktype, proto, canonname, sa = res
try:
try:
s = socket.socket(family, socket_type)
s = socket.socket(af, socktype, proto)
except socket.error:
s = None
continue
if socket_timeout is not None:
s.settimeout(socket_timeout)
s.connect((host, port))
s.connect(sa)
if socket_type == socket.SOCK_DGRAM:
s.send('')
s.recv(512)
success = True
return True
except socket.error, e:
pass
finally:
s.close()
if success:
return True
if s:
s.close()
return False
def bind_port_responder(port, socket_type=socket.SOCK_STREAM, socket_timeout=None, responder_data=None):
families = (socket.AF_INET, socket.AF_INET6)
host = None # all available interfaces
last_socket_error = None
host = '' # all available interfaces
# At first try to create IPv6 socket as it is able to accept both IPv6 and
# IPv4 connections (when not turned off)
families = (socket.AF_INET6, socket.AF_INET)
s = None
for family in families:
try:
s = socket.socket(family, socket_type)
addr_infos = socket.getaddrinfo(host, port, family, socket_type, 0,
socket.AI_PASSIVE)
except socket.error, e:
if family == families[-1]: # last available family
raise e
if socket_timeout is not None:
s.settimeout(socket_timeout)
if socket_type == socket.SOCK_STREAM:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
s.bind((host, port))
if socket_type == socket.SOCK_STREAM:
s.listen(1)
connection, client_address = s.accept()
last_socket_error = e
continue
for res in addr_infos:
af, socktype, proto, canonname, sa = res
try:
if responder_data:
connection.sendall(responder_data) #pylint: disable=E1101
finally:
connection.close()
elif socket_type == socket.SOCK_DGRAM:
data, addr = s.recvfrom(1)
s = socket.socket(af, socktype, proto)
except socket.error, e:
last_socket_error = e
s = None
continue
if responder_data:
s.sendto(responder_data, addr)
finally:
s.close()
if socket_timeout is not None:
s.settimeout(1)
if af == socket.AF_INET6:
try:
# Make sure IPv4 clients can connect to IPv6 socket
s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
except socket.error:
pass
if socket_type == socket.SOCK_STREAM:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
s.bind(sa)
while True:
if socket_type == socket.SOCK_STREAM:
s.listen(1)
connection, client_address = s.accept()
try:
if responder_data:
connection.sendall(responder_data) #pylint: disable=E1101
finally:
connection.close()
elif socket_type == socket.SOCK_DGRAM:
data, addr = s.recvfrom(1)
if responder_data:
s.sendto(responder_data, addr)
except socket.timeout:
# Timeout is expectable as it was requested by caller, raise
# the exception back to him
raise
except socket.error, e:
last_socket_error = e
s.close()
s = None
continue
finally:
if s:
s.close()
if s is None and last_socket_error is not None:
raise last_socket_error # pylint: disable=E0702
def is_host_resolvable(fqdn):
for rdtype in (rdatatype.A, rdatatype.AAAA):
@ -1015,34 +1047,24 @@ def utf8_encode_values(values):
def wait_for_open_ports(host, ports, timeout=0):
"""
Wait until the specified port(s) on the remote host are open. Timeout
in seconds may be specified to limit the wait.
in seconds may be specified to limit the wait. If the timeout is
exceeded, socket.timeout exception is raised.
"""
if not isinstance(ports, (tuple, list)):
ports = [ports]
root_logger.debug('wait_for_open_ports: %s %s timeout %d' % (host, ports, timeout))
root_logger.debug('wait_for_open_ports: %s %s timeout %d', host, ports, timeout)
op_timeout = time.time() + timeout
ipv6_failover = False
for port in ports:
while True:
try:
if ipv6_failover:
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
else:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((host, port))
s.close()
port_open = host_port_open(host, port)
if port_open:
break
except socket.error, e:
if e.errno == 111: # 111: Connection refused
if timeout and time.time() > op_timeout: # timeout exceeded
raise e
time.sleep(1)
elif not ipv6_failover: # fallback to IPv6 connection
ipv6_failover = True
else:
raise e
if timeout and time.time() > op_timeout: # timeout exceeded
raise socket.timeout()
time.sleep(1)
def wait_for_open_socket(socket_name, timeout=0):
"""

View File

@ -115,6 +115,12 @@ def client_auth_data_callback(ca_names, chosen_nickname, password, certdb):
return False
return False
_af_dict = {
socket.AF_INET: io.PR_AF_INET,
socket.AF_INET6: io.PR_AF_INET6,
socket.AF_UNSPEC: io.PR_AF_UNSPEC
}
class NSSAddressFamilyFallback(object):
def __init__(self, family):
self.sock_family = family
@ -124,67 +130,39 @@ class NSSAddressFamilyFallback(object):
"""
Translate a family from python socket module to nss family.
"""
if sock_family in [ socket.AF_INET, socket.AF_UNSPEC ]:
return io.PR_AF_INET
elif sock_family == socket.AF_INET6:
return io.PR_AF_INET6
else:
try:
return _af_dict[sock_family]
except KeyError:
raise ValueError('Uknown socket family %d\n', sock_family)
def _get_next_family(self):
if self.sock_family == socket.AF_UNSPEC and \
self.family == io.PR_AF_INET:
return io.PR_AF_INET6
return None
def _create_socket(self):
self.sock = io.Socket(family=self.family)
def _connect_socket_family(self, host, port, family):
root_logger.debug("connect_socket_family: host=%s port=%s family=%s",
host, port, io.addr_family_name(family))
try:
addr_info = [ ai for ai in io.AddrInfo(host) if ai.family == family ]
# No suitable families
if len(addr_info) == 0:
raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR,
"Cannot resolve %s using family %s" % (host, io.addr_family_name(family)))
# Try connecting to the NetworkAddresses
for net_addr in addr_info:
net_addr.port = port
root_logger.debug("connecting: %s", net_addr)
try:
self.sock.connect(net_addr)
except Exception, e:
root_logger.debug("Could not connect socket to %s, error: %s, retrying..",
net_addr, str(e))
continue
else:
return
# Could not connect with any of NetworkAddresses
raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR,
"Could not connect to %s using any address" % host)
except ValueError, e:
raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR, e.message)
def connect_socket(self, host, port):
try:
self._connect_socket_family(host, port, self.family)
except NSPRError, e:
if e.errno == error.PR_ADDRESS_NOT_SUPPORTED_ERROR:
next_family = self._get_next_family()
if next_family:
self.family = next_family
self._create_socket()
self._connect_socket_family(host, port, self.family)
else:
root_logger.debug('No next family to try..')
raise e
else:
raise e
addr_info = io.AddrInfo(host, family=self.family)
except Exception:
raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR,
"Cannot resolve %s using family %s" % (host,
io.addr_family_name(self.family)))
for net_addr in addr_info:
root_logger.debug("Connecting: %s", net_addr)
net_addr.port = port
self.family = net_addr.family
try:
self._create_socket()
self.sock.connect(net_addr)
return
except Exception, e:
root_logger.debug("Could not connect socket to %s, error: %s",
net_addr, str(e))
root_logger.debug("Try to continue with next family...")
continue
raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR,
"Could not connect to %s using any address" % host)
class NSSConnection(httplib.HTTPConnection, NSSAddressFamilyFallback):
default_port = httplib.HTTPSConnection.default_port
@ -218,12 +196,10 @@ class NSSConnection(httplib.HTTPConnection, NSSAddressFamilyFallback):
nss.nss_init(dbdir)
ssl.set_domestic_policy()
nss.set_password_callback(self.password_callback)
self._create_socket()
def _create_socket(self):
#TODO remove the try block once python-nss is guaranteed to
#contain these values
# TODO: remove the try block once python-nss is guaranteed to contain
# these values
try :
ssl_enable_renegotiation = SSL_ENABLE_RENEGOTIATION #pylint: disable=E0602
ssl_require_safe_negotiation = SSL_REQUIRE_SAFE_NEGOTIATION #pylint: disable=E0602

View File

@ -119,8 +119,15 @@ def get_ds_instances():
return instances
def check_ports():
ds_unsecure = installutils.port_available(389)
ds_secure = installutils.port_available(636)
"""
Check of Directory server ports are open.
Returns a tuple with two booleans, one for unsecure port 389 and one for
secure port 636. True means that the port is free, False means that the
port is taken.
"""
ds_unsecure = not ipautil.host_port_open(None, 389)
ds_secure = not ipautil.host_port_open(None, 636)
return (ds_unsecure, ds_secure)
def is_ds_running(server_id=''):

View File

@ -256,36 +256,6 @@ def read_dns_forwarders():
return addrs
def port_available(port):
"""Try to bind to a port on the wildcard host
Return 1 if the port is available
Return 0 if the port is in use
"""
rv = 1
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
fcntl.fcntl(s, fcntl.F_SETFD, fcntl.FD_CLOEXEC)
s.bind(('', port))
s.close()
except socket.error, e:
if e[0] == errno.EADDRINUSE:
rv = 0
if rv:
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
fcntl.fcntl(s, fcntl.F_SETFD, fcntl.FD_CLOEXEC)
s.bind(('', port))
s.close()
except socket.error, e:
if e[0] == errno.EADDRINUSE:
rv = 0
return rv
def get_password(prompt):
if os.isatty(sys.stdin.fileno()):
return getpass.getpass(prompt)