diff --git a/ipatests/test_integration/host.py b/ipatests/test_integration/host.py index 38d9ae6f5..7e5fca1ce 100644 --- a/ipatests/test_integration/host.py +++ b/ipatests/test_integration/host.py @@ -178,7 +178,7 @@ class BaseHost(object): class Host(BaseHost): """A Unix host""" - transport_class = transport.ParamikoTransport + transport_class = transport.SSHTransport def run_command(self, argv, set_env=True, stdin_text=None, log_stdout=True, raiseonerr=True, diff --git a/ipatests/test_integration/transport.py b/ipatests/test_integration/transport.py index 52b689a1d..a0bd3700a 100644 --- a/ipatests/test_integration/transport.py +++ b/ipatests/test_integration/transport.py @@ -17,7 +17,12 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Objects for communicating with remote hosts""" +"""Objects for communicating with remote hosts + +This class defines "SSHTransport" as ParamikoTransport (by default), or as +OpenSSHTransport (if Paramiko is not importable, or the IPA_TEST_SSH_TRANSPORT +environment variable is set to "openssh"). +""" import os import socket @@ -27,6 +32,7 @@ from contextlib import contextmanager import errno from ipapython.ipa_log_manager import log_mgr +from ipatests import util try: import paramiko @@ -247,6 +253,132 @@ class ParamikoTransport(Transport): self.sftp.put(localpath, remotepath) +class OpenSSHTransport(Transport): + """Transport that uses the `ssh` binary""" + def __init__(self, host): + super(OpenSSHTransport, self).__init__(host) + self.control_dir = util.TempDir() + + self.ssh_argv = self._get_ssh_argv() + + # Run a "control master" process. This serves two purposes: + # - Establishes a control socket; other SSHs will connect to it + # and reuse the same connection. This way the slow handshake + # only needs to be done once + # - Writes the host to known_hosts so stderr of "real" connections + # doesn't contain the "unknown host" warning + # Popen closes the stdin pipe when it's garbage-collected, so + # this process will exit when it's no longer needed + command = ['-o', 'ControlMaster=yes', '/usr/bin/cat'] + self.control_master = self._run(command, collect_output=False) + + def _get_ssh_argv(self): + """Return the path to SSH and options needed for every call""" + control_file = os.path.join(self.control_dir.path, 'control') + known_hosts_file = os.path.join(self.control_dir.path, 'known_hosts') + + argv = ['/usr/bin/ssh', + '-l', 'root', + '-o', 'ControlPath=%s' % control_file, + '-o', 'StrictHostKeyChecking=no', + '-o', 'UserKnownHostsFile=%s' % known_hosts_file] + + if self.host.root_ssh_key_filename: + argv.extend(['-i', self.host.root_ssh_key_filename]) + elif self.host.root_password: + self.log.critical('Password authentication not supported') + raise RuntimeError('Password authentication not supported') + else: + self.log.critical('No SSH credentials configured') + raise RuntimeError('No SSH credentials configured') + + argv.append(self.host.external_hostname) + self.log.debug('SSH invocation: %s', argv) + + return argv + + def start_shell(self, argv, log_stdout=True): + self.log.info('RUN %s', argv) + command = self._run(['/bin/bash'], argv=argv, log_stdout=log_stdout) + return command + + def _run(self, command, log_stdout=True, argv=None, collect_output=True): + """Run the given command on the remote host + + :param command: Command to run (appended to the common SSH invocation) + :param log_stdout: If false, stdout will not be logged + :param argv: Command to log (if different from ``command`` + :param collect_output: If false, no output will be collected + """ + if argv is None: + argv = command + logger_name = self.get_next_command_logger_name() + ssh = SSHCallWrapper(self.ssh_argv + list(command)) + return SSHCommand(ssh, argv, logger_name, log_stdout=log_stdout, + collect_output=collect_output) + + def file_exists(self, path): + self.log.info('STAT %s', path) + cmd = self._run(['/usr/bin/ls', path], log_stdout=False) + cmd.wait(raiseonerr=False) + return cmd.returncode == 0 + + def mkdir(self, path): + self.log.info('MKDIR %s', path) + cmd = self._run(['/usr/bin/mkdir', path]) + cmd.wait() + + def put_file_contents(self, filename, contents): + self.log.info('PUT %s', filename) + cmd = self._run(['/usr/bin/tee', filename], log_stdout=False) + cmd.stdin.write(contents) + cmd.wait() + assert cmd.stdout_text == contents + + def get_file_contents(self, filename): + self.log.info('GET %s', filename) + cmd = self._run(['/usr/bin/cat', filename], log_stdout=False) + cmd.wait(raiseonerr=False) + if cmd.returncode == 0: + return cmd.stdout_text + else: + raise IOError('File %r could not be read' % filename) + + +class SSHCallWrapper(object): + """Adapts a /usr/bin/ssh call to the paramiko.Channel interface + + This only wraps what SSHCommand needs. + """ + def __init__(self, command): + self.command = command + + def invoke_shell(self): + self.command = subprocess.Popen( + self.command, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + def makefile(self, mode): + return { + 'wb': self.command.stdin, + 'rb': self.command.stdout, + }[mode] + + def makefile_stderr(self, mode): + assert mode == 'rb' + return self.command.stderr + + def shutdown_write(self): + self.command.stdin.close() + + def recv_exit_status(self): + return self.command.wait() + + def close(self): + return self.command.wait() + + class SSHCommand(Command): """Command implementation for ParamikoTransport and OpenSSHTranspport""" def __init__(self, ssh, argv, logger_name, log_stdout=True, @@ -301,3 +433,9 @@ class SSHCommand(Command): self.running_threads.add(thread) thread.start() return thread + + +if not have_paramiko or os.environ.get('IPA_TEST_SSH_TRANSPORT') == 'openssh': + SSHTransport = OpenSSHTransport +else: + SSHTransport = ParamikoTransport