test_integration.host: Move transport-related functionality to a new module

This will make it possible to use a different mechanism for cases like
- Paramiko is not available
- Hosts without SSH servers (e.g. Windows)

Add BaseHost, Transport & Command base classes that define the interface
and common functionality, and Host, ParamikoTransport & SSHCommand with
specific details.

The {get,put}_file_contents methods are left on Host for convenience;
all other Transport methods must be now accessed through the transport.

Part of the work for https://fedorahosted.org/freeipa/ticket/3890
This commit is contained in:
Petr Viktorin 2013-09-17 12:24:40 +02:00
parent 7d2d1cb59d
commit 758c73e149
4 changed files with 371 additions and 214 deletions

View File

@ -21,121 +21,17 @@
import os
import socket
import threading
import subprocess
from contextlib import contextmanager
import errno
import paramiko
from ipapython.ipaldap import IPAdmin
from ipapython import ipautil
from ipapython.ipa_log_manager import log_mgr
from ipatests.test_integration import transport
class RemoteCommand(object):
"""A Popen-style object representing a remote command
Unlike subprocess.Popen, this does not run the given command; instead
it only starts a shell. The command must be written to stdin manually.
The standard error and output are handled by this class. They're not
available for file-like reading. They are logged by default.
To make sure reading doesn't stall after one buffer fills up, they are read
in parallel using threads.
After calling wait(), stdout_text and stderr_text attributes will be
strings containing the output, and returncode will contain the
exit code.
:param host: The Host on which the command is run
:param argv: The command that will be run (for logging only)
:param index: An identification number added to the logs
:param log_stdout: If false, stdout will not be logged
"""
def __init__(self, host, argv, index, log_stdout=True):
self.returncode = None
self.host = host
self.argv = argv
self._stdout_lines = []
self._stderr_lines = []
self.running_threads = set()
self.logger_name = '%s.cmd%s' % (self.host.logger_name, index)
self.log = log_mgr.get_logger(self.logger_name)
self.log.info('RUN %s', argv)
self._ssh = host.transport.open_channel('session')
self._ssh.invoke_shell()
stdin = self.stdin = self._ssh.makefile('wb')
stdout = self._ssh.makefile('rb')
stderr = self._ssh.makefile_stderr('rb')
self._start_pipe_thread(self._stdout_lines, stdout, 'out', log_stdout)
self._start_pipe_thread(self._stderr_lines, stderr, 'err', True)
self._done = False
def wait(self, raiseonerr=True):
"""Wait for the remote process to exit
Raises an excption if the exit code is not 0.
"""
if self._done:
return self.returncode
self._ssh.shutdown_write()
while self.running_threads:
self.running_threads.pop().join()
self.stdout_text = ''.join(self._stdout_lines)
self.stderr_text = ''.join(self._stderr_lines)
self.returncode = self._ssh.recv_exit_status()
self._ssh.close()
self._done = True
if raiseonerr and self.returncode:
self.log.error('Exit code: %s', self.returncode)
raise subprocess.CalledProcessError(self.returncode, self.argv)
else:
self.log.debug('Exit code: %s', self.returncode)
return self.returncode
def _start_pipe_thread(self, result_list, stream, name, do_log=True):
log = log_mgr.get_logger('%s.%s' % (self.logger_name, name))
def read_stream():
for line in stream:
if do_log:
log.debug(line.rstrip('\n'))
result_list.append(line)
thread = threading.Thread(target=read_stream)
self.running_threads.add(thread)
thread.start()
return thread
@contextmanager
def sftp_open(sftp, filename, mode='r'):
"""Context manager that provides a file-like object over a SFTP channel
This provides compatibility with older Paramiko versions.
(In Paramiko 1.10+, file objects from `sftp.open` are directly usable as
context managers).
"""
file = sftp.open(filename, mode)
try:
yield file
finally:
file.close()
class Host(object):
class BaseHost(object):
"""Representation of a remote IPA host"""
transport_class = None
def __init__(self, domain, hostname, role, index, ip=None):
self.domain = domain
self.role = role
@ -175,8 +71,6 @@ class Host(object):
self.env_sh_path = os.path.join(domain.config.test_dir, 'env.sh')
self._command_index = 0
self.log_collectors = []
def __str__(self):
@ -224,13 +118,48 @@ class Host(object):
return env
@property
def transport(self):
try:
return self._transport
except AttributeError:
cls = self.transport_class
if cls:
# transport_class is None in the base class and must be
# set in subclasses.
# Pylint reports that calling None will fail
self._transport = cls(self) # pylint: disable=E1102
else:
raise NotImplementedError('transport class not available')
return self._transport
def get_file_contents(self, filename):
"""Shortcut for transport.get_file_contents"""
return self.transport.get_file_contents(filename)
def put_file_contents(self, filename, contents):
"""Shortcut for transport.put_file_contents"""
self.transport.put_file_contents(filename, contents)
def ldap_connect(self):
"""Return an LDAPClient authenticated to this host as directory manager
"""
ldap = IPAdmin(self.external_hostname)
ldap.do_simple_bind(self.config.dirman_dn,
self.config.dirman_password)
return ldap
def collect_log(self, filename):
for collector in self.log_collectors:
collector(self, filename)
def run_command(self, argv, set_env=True, stdin_text=None,
log_stdout=True, raiseonerr=True,
cwd=None):
"""Run the given command on this host
Returns a RemoteCommand instance. The command will have already run
when this method returns, so its stdout_text, stderr_text, and
Returns a Shell instance. The command will have already run in the
shell when this method returns, so its stdout_text, stderr_text, and
returncode attributes will be available.
:param argv: Command to run, as either a Popen-style list, or a string
@ -242,30 +171,43 @@ class Host(object):
(but will still be available as cmd.stdout_text)
:param raiseonerr: If true, an exception will be raised if the command
does not exit with return code 0
:param cwd: The working directory for the command
"""
assert self.transport
raise NotImplementedError()
self._command_index += 1
command = RemoteCommand(self, argv, index=self._command_index,
log_stdout=log_stdout)
class Host(BaseHost):
"""A Unix host"""
transport_class = transport.ParamikoTransport
def run_command(self, argv, set_env=True, stdin_text=None,
log_stdout=True, raiseonerr=True,
cwd=None):
# This will give us a Bash shell
command = self.transport.start_shell(argv, log_stdout=log_stdout)
# Set working directory
if cwd is None:
cwd = self.config.test_dir
command.stdin.write('cd %s\n' % ipautil.shell_quote(cwd))
# Set the environment
if set_env:
command.stdin.write('. %s\n' %
ipautil.shell_quote(self.env_sh_path))
command.stdin.write('set -e\n')
if isinstance(argv, basestring):
# Run a shell command given as a string
command.stdin.write('(')
command.stdin.write(argv)
command.stdin.write(')')
else:
# Run a command given as a popen-style list (no shell expansion)
for arg in argv:
command.stdin.write(ipautil.shell_quote(arg))
command.stdin.write(' ')
command.stdin.write(';exit\n')
if stdin_text:
command.stdin.write(stdin_text)
@ -273,92 +215,3 @@ class Host(object):
command.wait(raiseonerr=raiseonerr)
return command
@property
def transport(self):
"""Paramiko Transport connected to this host"""
try:
return self._transport
except AttributeError:
sock = socket.create_connection((self.external_hostname,
self.ssh_port))
self._transport = transport = paramiko.Transport(sock)
transport.connect(hostkey=self.host_key)
if self.root_ssh_key_filename:
self.log.debug('Authenticating with private RSA key')
filename = os.path.expanduser(self.root_ssh_key_filename)
key = paramiko.RSAKey.from_private_key_file(filename)
transport.auth_publickey(username='root', key=key)
elif self.root_password:
self.log.debug('Authenticating with password')
transport.auth_password(username='root',
password=self.root_password)
else:
self.log.critical('No SSH credentials configured')
raise RuntimeError('No SSH credentials configured')
return transport
@property
def sftp(self):
"""Paramiko SFTPClient connected to this host"""
try:
return self._sftp
except AttributeError:
transport = self.transport
self._sftp = paramiko.SFTPClient.from_transport(transport)
return self._sftp
def ldap_connect(self):
"""Return an LDAPClient authenticated to this host as directory manager
"""
ldap = IPAdmin(self.external_hostname)
ldap.do_simple_bind(self.config.dirman_dn,
self.config.dirman_password)
return ldap
def mkdir_recursive(self, path):
"""`mkdir -p` on the remote host"""
try:
self.sftp.chdir(path or '/')
except IOError as e:
if not path or path == '/':
raise
self.mkdir_recursive(os.path.dirname(path))
self.sftp.mkdir(path)
self.sftp.chdir(path)
def get_file_contents(self, filename):
"""Read the named remote file and return the contents as a string"""
self.log.debug('READ %s', filename)
with sftp_open(self.sftp, filename) as f:
return f.read()
def put_file_contents(self, filename, contents):
"""Write the given string to the named remote file"""
self.log.info('WRITE %s', filename)
with sftp_open(self.sftp, filename, 'w') as f:
f.write(contents)
def file_exists(self, filename):
"""Return true if the named remote file exists"""
self.log.debug('STAT %s', filename)
try:
self.sftp.stat(filename)
except IOError, e:
if e.errno == errno.ENOENT:
return False
else:
raise
return True
def get_file(self, remotepath, localpath):
self.log.debug('GET %s', remotepath)
self.sftp.get(remotepath, localpath)
def put_file(self, localpath, remotepath):
self.log.info('PUT %s', remotepath)
self.sftp.put(localpath, remotepath)
def collect_log(self, filename):
for collector in self.log_collectors:
collector(self, filename)

View File

@ -40,7 +40,7 @@ log = log_mgr.get_logger(__name__)
def prepare_host(host):
env_filename = os.path.join(host.config.test_dir, 'env.sh')
host.collect_log(env_filename)
host.mkdir_recursive(host.config.test_dir)
host.transport.mkdir_recursive(host.config.test_dir)
host.put_file_contents(env_filename, env_to_script(host.to_env()))
@ -51,10 +51,10 @@ def apply_common_fixes(host):
def backup_file(host, filename):
if host.file_exists(filename):
if host.transport.file_exists(filename):
backupname = os.path.join(host.config.test_dir, 'file_backup',
filename.lstrip('/'))
host.mkdir_recursive(os.path.dirname(backupname))
host.transport.mkdir_recursive(os.path.dirname(backupname))
host.run_command(['cp', '-af', filename, backupname])
return True
else:
@ -63,7 +63,7 @@ def backup_file(host, filename):
ipautil.shell_quote(filename),
ipautil.shell_quote(rmname)))
contents = host.get_file_contents(rmname)
host.mkdir_recursive(os.path.dirname(rmname))
host.transport.mkdir_recursive(os.path.dirname(rmname))
return False

View File

@ -103,7 +103,7 @@ class CALessBase(IntegrationTest):
host.mkdir_recursive(cls.crl_path)
for source in glob.glob(os.path.join(base, '*.crl')):
dest = os.path.join(cls.crl_path, os.path.basename(source))
host.put_file(source, dest)
host.transport.put_file(source, dest)
@classmethod
def uninstall(cls):
@ -174,8 +174,8 @@ class CALessBase(IntegrationTest):
@classmethod
def copy_cert(cls, host, filename):
host.put_file(os.path.join(cls.cert_dir, filename),
os.path.join(host.config.test_dir, filename))
host.transport.put_file(os.path.join(cls.cert_dir, filename),
os.path.join(host.config.test_dir, filename))
@classmethod
def uninstall_server(self, host=None):
@ -211,8 +211,9 @@ class CALessBase(IntegrationTest):
if dirsrv_pkcs12_exists:
files_to_copy.append(dirsrv_pkcs12)
for filename in set(files_to_copy):
master.put_file(os.path.join(self.cert_dir, filename),
os.path.join(master.config.test_dir, filename))
master.transport.put_file(
os.path.join(self.cert_dir, filename),
os.path.join(master.config.test_dir, filename))
self.collect_log(replica, '/var/log/ipareplica-install.log')
self.collect_log(replica, '/var/log/ipaclient-install.log')

View File

@ -0,0 +1,303 @@
# Authors:
# Petr Viktorin <pviktori@redhat.com>
#
# Copyright (C) 2013 Red Hat
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Objects for communicating with remote hosts"""
import os
import socket
import threading
import subprocess
from contextlib import contextmanager
import errno
from ipapython.ipa_log_manager import log_mgr
try:
import paramiko
have_paramiko = True
except ImportError:
have_paramiko = False
class Transport(object):
"""Mechanism for communicating with remote hosts
The Transport can manipulate files on a remote host, and open a Command.
The base class defines an interface that specific subclasses implement.
"""
def __init__(self, host):
self.host = host
self.logger_name = '%s.%s' % (host.logger_name, type(self).__name__)
self.log = log_mgr.get_logger(self.logger_name)
self._command_index = 0
def get_file_contents(self, filename):
"""Read the named remote file and return the contents as a string"""
raise NotImplementedError('Transport.get_file_contents')
def put_file_contents(self, filename, contents):
"""Write the given string to the named remote file"""
raise NotImplementedError('Transport.put_file_contents')
def file_exists(self, filename):
"""Return true if the named remote file exists"""
raise NotImplementedError('Transport.file_exists')
def mkdir(self, path):
"""Make the named directory"""
raise NotImplementedError('Transport.mkdir')
def start_shell(self, argv, log_stdout=True):
"""Start a Shell
:param argv: The command this shell is intended to run (used for
logging only)
:param log_stdout: If false, the stdout will not be logged (useful when
binary output is expected)
Given a `shell` from this method, the caller can then use
``shell.stdin.write()`` to input any command(s), call ``shell.wait()``
to let the command run, and then inspect ``returncode``,
``stdout_text`` or ``stderr_text``.
"""
raise NotImplementedError('Transport.start_shell')
def mkdir_recursive(self, path):
"""`mkdir -p` on the remote host"""
if not path or path == '/':
raise ValueError('Invalid path')
if not self.file_exists(path or '/'):
self.mkdir_recursive(os.path.dirname(path))
self.mkdir(path)
def get_file(self, remotepath, localpath):
"""Copy a file from the remote host to a local file"""
contents = self.get_file_contents(remotepath)
with open(localpath, 'wb') as local_file:
local_file.write(contents)
def put_file(self, localpath, remotepath):
"""Copy a local file to the remote host"""
with open(localpath, 'rb') as local_file:
contents = local_file.read()
self.put_file_contents(remotepath, contents)
def get_next_command_logger_name(self):
self._command_index += 1
return '%s.cmd%s' % (self.host.logger_name, self._command_index)
class Command(object):
"""A Popen-style object representing a remote command
Instances of this class should only be created via method of a concrete
Transport, such as start_shell.
The standard error and output are handled by this class. They're not
available for file-like reading, and are logged by default.
To make sure reading doesn't stall after one buffer fills up, they are read
in parallel using threads.
After calling wait(), ``stdout_text`` and ``stderr_text`` attributes will
be strings containing the output, and ``returncode`` will contain the
exit code.
"""
def __init__(self, argv, logger_name=None, log_stdout=True):
self.returncode = None
self.argv = argv
self._done = False
if logger_name:
self.logger_name = logger_name
else:
self.logger_name = '%s.%s' % (self.__module__, type(self).__name__)
self.log = log_mgr.get_logger(self.logger_name)
def wait(self, raiseonerr=True):
"""Wait for the remote process to exit
Raises an excption if the exit code is not 0, unless raiseonerr is
true.
"""
if self._done:
return self.returncode
self._end_process()
self._done = True
if raiseonerr and self.returncode:
self.log.error('Exit code: %s', self.returncode)
raise subprocess.CalledProcessError(self.returncode, self.argv)
else:
self.log.debug('Exit code: %s', self.returncode)
return self.returncode
def _end_process(self):
"""Wait until the process exits and output is received, close channel
Called from wait()
"""
raise NotImplementedError()
class ParamikoTransport(Transport):
"""Transport that uses the Paramiko SSH2 library"""
def __init__(self, host):
super(ParamikoTransport, self).__init__(host)
sock = socket.create_connection((host.external_hostname,
host.ssh_port))
self._transport = transport = paramiko.Transport(sock)
transport.connect(hostkey=host.host_key)
if host.root_ssh_key_filename:
self.log.debug('Authenticating with private RSA key')
filename = os.path.expanduser(host.root_ssh_key_filename)
key = paramiko.RSAKey.from_private_key_file(filename)
transport.auth_publickey(username='root', key=key)
elif host.root_password:
self.log.debug('Authenticating with password')
transport.auth_password(username='root',
password=host.root_password)
else:
self.log.critical('No SSH credentials configured')
raise RuntimeError('No SSH credentials configured')
@contextmanager
def sftp_open(self, filename, mode='r'):
"""Context manager that provides a file-like object over a SFTP channel
This provides compatibility with older Paramiko versions.
(In Paramiko 1.10+, file objects from `sftp.open` are directly usable
as context managers).
"""
file = self.sftp.open(filename, mode)
try:
yield file
finally:
file.close()
@property
def sftp(self):
"""Paramiko SFTPClient connected to this host"""
try:
return self._sftp
except AttributeError:
transport = self._transport
self._sftp = paramiko.SFTPClient.from_transport(transport)
return self._sftp
def get_file_contents(self, filename):
"""Read the named remote file and return the contents as a string"""
self.log.debug('READ %s', filename)
with self.sftp_open(filename) as f:
return f.read()
def put_file_contents(self, filename, contents):
"""Write the given string to the named remote file"""
self.log.info('WRITE %s', filename)
with self.sftp_open(filename, 'w') as f:
f.write(contents)
def file_exists(self, filename):
"""Return true if the named remote file exists"""
self.log.debug('STAT %s', filename)
try:
self.sftp.stat(filename)
except IOError, e:
if e.errno == errno.ENOENT:
return False
else:
raise
return True
def mkdir(self, path):
self.log.info('MKDIR %s', path)
self.sftp.mkdir(path)
def start_shell(self, argv, log_stdout=True):
logger_name = self.get_next_command_logger_name()
ssh = self._transport.open_channel('session')
self.log.info('RUN %s', argv)
return SSHCommand(ssh, argv, logger_name=logger_name,
log_stdout=log_stdout)
def get_file(self, remotepath, localpath):
self.log.debug('GET %s', remotepath)
self.sftp.get(remotepath, localpath)
def put_file(self, localpath, remotepath):
self.log.info('PUT %s', remotepath)
self.sftp.put(localpath, remotepath)
class SSHCommand(Command):
"""Command implementation for ParamikoTransport and OpenSSHTranspport"""
def __init__(self, ssh, argv, logger_name, log_stdout=True,
collect_output=True):
super(SSHCommand, self).__init__(argv, logger_name,
log_stdout=log_stdout)
self._stdout_lines = []
self._stderr_lines = []
self.running_threads = set()
self._ssh = ssh
self.log.debug('RUN %s', argv)
self._ssh.invoke_shell()
stdin = self.stdin = self._ssh.makefile('wb')
stdout = self._ssh.makefile('rb')
stderr = self._ssh.makefile_stderr('rb')
if collect_output:
self._start_pipe_thread(self._stdout_lines, stdout, 'out',
log_stdout)
self._start_pipe_thread(self._stderr_lines, stderr, 'err', True)
def _end_process(self, raiseonerr=True):
self._ssh.shutdown_write()
while self.running_threads:
self.running_threads.pop().join()
self.stdout_text = ''.join(self._stdout_lines)
self.stderr_text = ''.join(self._stderr_lines)
self.returncode = self._ssh.recv_exit_status()
self._ssh.close()
def _start_pipe_thread(self, result_list, stream, name, do_log=True):
"""Start a thread that copies lines from ``stream`` to ``result_list``
If do_log is true, also logs the lines under ``name``
The thread is added to ``self.running_threads``.
"""
log = log_mgr.get_logger('%s.%s' % (self.logger_name, name))
def read_stream():
for line in stream:
if do_log:
log.debug(line.rstrip('\n'))
result_list.append(line)
thread = threading.Thread(target=read_stream)
self.running_threads.add(thread)
thread.start()
return thread