diff --git a/virtManager/baseclass.py b/virtManager/baseclass.py index f769650cf..84c9c79dc 100644 --- a/virtManager/baseclass.py +++ b/virtManager/baseclass.py @@ -36,6 +36,19 @@ from gi.repository import Gtk class vmmGObject(GObject.GObject): _leak_check = True + @staticmethod + def idle_add(func, *args, **kwargs): + """ + Make sure idle functions are run thread safe + """ + def cb(): + try: + return func(*args, **kwargs) + except: + print traceback.format_exc() + return False + return GLib.idle_add(cb) + def __init__(self): GObject.GObject.__init__(self) self.config = config.running_config @@ -141,18 +154,6 @@ class vmmGObject(GObject.GObject): self.idle_add(emitwrap, signal, *args) - def idle_add(self, func, *args, **kwargs): - """ - Make sure idle functions are run thread safe - """ - def cb(): - try: - return func(*args, **kwargs) - except: - print traceback.format_exc() - return False - return GLib.idle_add(cb) - def timeout_add(self, timeout, func, *args): """ Make sure timeout functions are run thread safe diff --git a/virtManager/console.py b/virtManager/console.py index 5a043b5fb..c31d35be4 100644 --- a/virtManager/console.py +++ b/virtManager/console.py @@ -31,10 +31,12 @@ from gi.repository import SpiceClientGLib import libvirt +import logging import os +import Queue import signal import socket -import logging +import threading import virtManager.uihelpers as uihelpers from virtManager.autodrawer import AutoDrawer @@ -114,14 +116,105 @@ class ConnectionInfo(object): return int(self.gport) == -1 -class Tunnel(object): +class _TunnelScheduler(object): + """ + If the user is using Spice + SSH URI + no SSH keys, we need to + serialize connection opening otherwise ssh-askpass gets all angry. + This handles the locking and scheduling. + + It's only instantiated once for the whole app, because we serialize + independent of connection, vm, etc. + """ + def __init__(self): + self._thread = threading.Thread(name="Tunnel thread", + target=self._handle_queue, + args=()) + self._thread.daemon = True + self._queue = Queue.Queue() + self._lock = threading.Lock() + + def _handle_queue(self): + while True: + cb, args, = self._queue.get() + self.lock() + vmmGObject.idle_add(cb, *args) + + def schedule(self, cb, *args): + if not self._thread.is_alive(): + self._thread.start() + self._queue.put((cb, args)) + + def lock(self): + self._lock.acquire() + def unlock(self): + self._lock.release() + +_tunnel_sched = _TunnelScheduler() + + +class _Tunnel(object): def __init__(self): self.outfd = None self.errfd = None self.pid = None + self._outfds = None + self._errfds = None + self.closed = False def open(self, ginfo): - if self.outfd is not None: + self._outfds = socket.socketpair() + self._errfds = socket.socketpair() + + return self._outfds[0].fileno(), self._launch_tunnel, ginfo + + def close(self): + if self.closed: + return + self.closed = True + + logging.debug("Close tunnel PID=%s OUTFD=%s ERRFD=%s", + self.pid, + self.outfd and self.outfd.fileno() or self._outfds, + self.errfd and self.errfd.fileno() or self._errfds) + + if self.outfd: + self.outfd.close() + elif self._outfds: + self._outfds[0].close() + self._outfds[1].close() + self.outfd = None + self._outfds = None + + if self.errfd: + self.errfd.close() + elif self._errfds: + self._errfds[0].close() + self._errfds[1].close() + self.errfd = None + self._errfds = None + + if self.pid: + os.kill(self.pid, signal.SIGKILL) + os.waitpid(self.pid, 0) + self.pid = None + + def get_err_output(self): + errout = "" + while True: + try: + new = self.errfd.recv(1024) + except: + break + + if not new: + break + + errout += new + + return errout + + def _launch_tunnel(self, ginfo): + if self.closed: return -1 host, port, ignore = ginfo.get_conn_host() @@ -168,70 +261,33 @@ class Tunnel(object): argv_str = reduce(lambda x, y: x + " " + y, argv[1:]) logging.debug("Creating SSH tunnel: %s", argv_str) - fds = socket.socketpair() - errorfds = socket.socketpair() - pid = os.fork() if pid == 0: - fds[0].close() - errorfds[0].close() + self._outfds[0].close() + self._errfds[0].close() os.close(0) os.close(1) os.close(2) - os.dup(fds[1].fileno()) - os.dup(fds[1].fileno()) - os.dup(errorfds[1].fileno()) + os.dup(self._outfds[1].fileno()) + os.dup(self._outfds[1].fileno()) + os.dup(self._errfds[1].fileno()) os.execlp(*argv) os._exit(1) # pylint: disable=W0212 else: - fds[1].close() - errorfds[1].close() + self._outfds[1].close() + self._errfds[1].close() - logging.debug("Tunnel PID=%d OUTFD=%d ERRFD=%d", - pid, fds[0].fileno(), errorfds[0].fileno()) - errorfds[0].setblocking(0) + logging.debug("Open tunnel PID=%d OUTFD=%d ERRFD=%d", + pid, self._outfds[0].fileno(), self._errfds[0].fileno()) + self._errfds[0].setblocking(0) - self.outfd = fds[0] - self.errfd = errorfds[0] + self.outfd = self._outfds[0] + self.errfd = self._errfds[0] + self._outfds = None + self._errfds = None self.pid = pid - fd = fds[0].fileno() - if fd < 0: - raise SystemError("can't open a new tunnel: fd=%d" % fd) - return fd - - def close(self): - if self.outfd is None: - return - - logging.debug("Shutting down tunnel PID=%d OUTFD=%d ERRFD=%d", - self.pid, self.outfd.fileno(), - self.errfd.fileno()) - self.outfd.close() - self.outfd = None - self.errfd.close() - self.errfd = None - - os.kill(self.pid, signal.SIGKILL) - os.waitpid(self.pid, 0) - self.pid = None - - def get_err_output(self): - errout = "" - while True: - try: - new = self.errfd.recv(1024) - except: - break - - if not new: - break - - errout += new - - return errout - class Tunnels(object): def __init__(self, ginfo): @@ -239,9 +295,11 @@ class Tunnels(object): self._tunnels = [] def open_new(self): - t = Tunnel() - fd = t.open(self.ginfo) + t = _Tunnel() + fd, cb, args = t.open(self.ginfo) self._tunnels.append(t) + _tunnel_sched.schedule(cb, args) + return fd def close_all(self): @@ -254,6 +312,9 @@ class Tunnels(object): errout += l.get_err_output() return errout + lock = _tunnel_sched.lock + unlock = _tunnel_sched.unlock + class Viewer(vmmGObject): def __init__(self, console): @@ -275,6 +336,12 @@ class Viewer(vmmGObject): def get_pixbuf(self): return self.display.get_pixbuf() + def open_ginfo(self, ginfo): + if ginfo.need_tunnel(): + self.open_fd(self.console.tunnels.open_new()) + else: + self.open_host(ginfo) + def get_grab_keys(self): raise NotImplementedError() @@ -284,10 +351,10 @@ class Viewer(vmmGObject): def send_keys(self, keys): raise NotImplementedError() - def open_host(self, ginfo, password=None): + def open_host(self, ginfo): raise NotImplementedError() - def open_fd(self, fd, password=None): + def open_fd(self, fd): raise NotImplementedError() def get_desktop_resolution(self): @@ -306,6 +373,8 @@ class VNCViewer(Viewer): # Last noticed desktop resolution self.desktop_resolution = None + self._tunnel_unlocked = False + def init_widget(self): self.set_grab_keys() @@ -320,18 +389,32 @@ class VNCViewer(Viewer): self.display.set_pointer_grab(True) self.display.connect("vnc-pointer-grab", self.console.pointer_grabbed) - self.display.connect("vnc-pointer-ungrab", self.console.pointer_ungrabbed) + self.display.connect("vnc-pointer-ungrab", + self.console.pointer_ungrabbed) self.display.connect("vnc-auth-credential", self._auth_credential) - self.display.connect("vnc-initialized", - lambda src: self.console.connected()) - self.display.connect("vnc-disconnected", - lambda src: self.console.disconnected()) + self.display.connect("vnc-initialized", self._connected_cb) + self.display.connect("vnc-disconnected", self._disconnected_cb) self.display.connect("vnc-desktop-resize", self._desktop_resize) - self.display.connect("focus-in-event", self.console.viewer_focus_changed) - self.display.connect("focus-out-event", self.console.viewer_focus_changed) + self.display.connect("focus-in-event", + self.console.viewer_focus_changed) + self.display.connect("focus-out-event", + self.console.viewer_focus_changed) self.display.show() + def _unlock_tunnel(self): + if self.console.tunnels and not self._tunnel_unlocked: + self.console.tunnels.unlock() + self._tunnel_unlocked = True + + def _connected_cb(self, ignore): + self._unlock_tunnel() + self.console.connected() + + def _disconnected_cb(self, ignore): + self._unlock_tunnel() + self.console.disconnected() + def get_grab_keys(self): return self.display.get_grab_keys().as_string() @@ -421,7 +504,7 @@ class VNCViewer(Viewer): def is_open(self): return self.display.is_open() - def open_host(self, ginfo, password=None): + def open_host(self, ginfo): host, port, ignore = ginfo.get_conn_host() if not ginfo.gsocket: @@ -444,8 +527,7 @@ class VNCViewer(Viewer): ginfo.gsocket) + " fd=%s" % fd) self.open_fd(fd) - def open_fd(self, fd, password=None): - ignore = password + def open_fd(self, fd): self.display.open_fd(fd) def set_credential_username(self, cred): @@ -469,8 +551,10 @@ class SpiceViewer(Viewer): self.console.refresh_scaling() self.display.realize() - self.display.connect("mouse-grab", lambda src, g: g and self.console.pointer_grabbed(src)) - self.display.connect("mouse-grab", lambda src, g: g or self.console.pointer_ungrabbed(src)) + self.display.connect("mouse-grab", + lambda src, g: g and self.console.pointer_grabbed(src)) + self.display.connect("mouse-grab", + lambda src, g: g or self.console.pointer_ungrabbed(src)) self.display.connect("focus-in-event", self.console.viewer_focus_changed) @@ -534,11 +618,19 @@ class SpiceViewer(Viewer): logging.debug("Spice channel event error: %s", event) self.console.disconnected() + def _fd_channel_event_cb(self, channel, event): + # When we see any event from the channel, release the + # associated tunnel lock + channel.disconnect_by_func(self._fd_channel_event_cb) + self.console.tunnels.unlock() + def _channel_open_fd_request(self, channel, tls_ignore): if not self.console.tunnels: raise SystemError("Got fd request with no configured tunnel!") logging.debug("Opening tunnel for channel: %s", channel) + channel.connect_after("channel-event", self._fd_channel_event_cb) + fd = self.console.tunnels.open_new() channel.open_fd(fd) @@ -547,6 +639,8 @@ class SpiceViewer(Viewer): self._channel_open_fd_request) if type(channel) == SpiceClientGLib.MainChannel: + if self.console.tunnels: + self.console.tunnels.unlock() channel.connect_after("channel-event", self._main_channel_event_cb) return @@ -584,6 +678,9 @@ class SpiceViewer(Viewer): gtk_session = SpiceClientGtk.GtkSession.get(self.spice_session) gtk_session.set_property("auto-clipboard", True) + GObject.GObject.connect(self.spice_session, "channel-new", + self._channel_new_cb) + self.usbdev_manager = SpiceClientGLib.UsbDeviceManager.get( self.spice_session) self.usbdev_manager.connect("auto-connect-failed", @@ -595,26 +692,19 @@ class SpiceViewer(Viewer): if autoredir: gtk_session.set_property("auto-usbredir", True) - def open_host(self, ginfo, password=None): + def open_host(self, ginfo): host, port, tlsport = ginfo.get_conn_host() - self._create_spice_session() + self.spice_session.set_property("host", str(host)) self.spice_session.set_property("port", str(port)) if tlsport: self.spice_session.set_property("tls-port", str(tlsport)) - if password: - self.spice_session.set_property("password", password) - GObject.GObject.connect(self.spice_session, "channel-new", - self._channel_new_cb) + self.spice_session.connect() - def open_fd(self, fd, password=None): + def open_fd(self, fd): self._create_spice_session() - if password: - self.spice_session.set_property("password", password) - GObject.GObject.connect(self.spice_session, "channel-new", - self._channel_new_cb) self.spice_session.open_fd(fd) def set_credential_password(self, cred): @@ -1254,15 +1344,8 @@ class vmmConsolePages(vmmGObjectUI): self.set_enable_accel() if ginfo.need_tunnel(): - if self.tunnels: - # Tunnel already open, no need to continue - return - self.tunnels = Tunnels(ginfo) - self.viewer.open_fd(self.tunnels.open_new()) - else: - self.viewer.open_host(ginfo) - + self.viewer.open_ginfo(ginfo) except Exception, e: logging.exception("Error connection to graphical console") self.activate_unavailable_page(