Fix sending/receiving of FDs when stream returns EAGAIN

The code calling sendfd/recvfd was mistakenly assuming those
calls would never block. They can in fact return EAGAIN and
this is causing us to drop the client connection when blocking
ocurrs while sending/receiving FDs.

Fixing this is a little hairy on the incoming side, since at
the point where we see the EAGAIN, we already thought we had
finished receiving all data for the packet. So we play a little
trick to reset bufferOffset again and go back into polling for
more data.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Update
  virNetSocketSendFD/RecvFD to return 0 on EAGAIN, or 1
  on success
* src/rpc/virnetclient.c: Move decoding of header & fds
  out of virNetClientCallDispatch and into virNetClientIOHandleInput.
  Handling blocking when sending/receiving FDs
* src/rpc/virnetmessage.h: Add a 'donefds' field to track
  how many FDs we've sent / received
* src/rpc/virnetserverclient.c: Handling blocking when
  sending/receiving FDs
This commit is contained in:
Daniel P. Berrange 2011-11-04 16:02:14 +00:00
parent 4d970fd293
commit b2c6231647
5 changed files with 125 additions and 53 deletions

View File

@ -694,10 +694,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
static int static int
virNetClientCallDispatch(virNetClientPtr client) virNetClientCallDispatch(virNetClientPtr client)
{ {
size_t i;
if (virNetMessageDecodeHeader(&client->msg) < 0)
return -1;
PROBE(RPC_CLIENT_MSG_RX, PROBE(RPC_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u", "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
client, client->msg.bufferLength, client, client->msg.bufferLength,
@ -706,15 +702,7 @@ virNetClientCallDispatch(virNetClientPtr client)
switch (client->msg.header.type) { switch (client->msg.header.type) {
case VIR_NET_REPLY: /* Normal RPC replies */ case VIR_NET_REPLY: /* Normal RPC replies */
return virNetClientCallDispatchReply(client);
case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */ case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
if (virNetMessageDecodeNumFDs(&client->msg) < 0)
return -1;
for (i = 0 ; i < client->msg.nfds ; i++) {
if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
return -1;
}
return virNetClientCallDispatchReply(client); return virNetClientCallDispatchReply(client);
case VIR_NET_MESSAGE: /* Async notifications */ case VIR_NET_MESSAGE: /* Async notifications */
@ -737,22 +725,29 @@ static ssize_t
virNetClientIOWriteMessage(virNetClientPtr client, virNetClientIOWriteMessage(virNetClientPtr client,
virNetClientCallPtr thecall) virNetClientCallPtr thecall)
{ {
ssize_t ret; ssize_t ret = 0;
ret = virNetSocketWrite(client->sock, if (thecall->msg->bufferOffset < thecall->msg->bufferLength) {
thecall->msg->buffer + thecall->msg->bufferOffset, ret = virNetSocketWrite(client->sock,
thecall->msg->bufferLength - thecall->msg->bufferOffset); thecall->msg->buffer + thecall->msg->bufferOffset,
if (ret <= 0) thecall->msg->bufferLength - thecall->msg->bufferOffset);
return ret; if (ret <= 0)
return ret;
thecall->msg->bufferOffset += ret; thecall->msg->bufferOffset += ret;
}
if (thecall->msg->bufferOffset == thecall->msg->bufferLength) { if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
size_t i; size_t i;
for (i = 0 ; i < thecall->msg->nfds ; i++) { for (i = thecall->msg->donefds ; i < thecall->msg->nfds ; i++) {
if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0) int rv;
if ((rv = virNetSocketSendFD(client->sock, thecall->msg->fds[i])) < 0)
return -1; return -1;
if (rv == 0) /* Blocking */
return 0;
thecall->msg->donefds++;
} }
thecall->msg->donefds = 0;
thecall->msg->bufferOffset = thecall->msg->bufferLength = 0; thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
if (thecall->expectReply) if (thecall->expectReply)
thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX; thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
@ -821,12 +816,16 @@ virNetClientIOHandleInput(virNetClientPtr client)
* EAGAIN * EAGAIN
*/ */
for (;;) { for (;;) {
ssize_t ret = virNetClientIOReadMessage(client); ssize_t ret;
if (ret < 0) if (client->msg.nfds == 0) {
return -1; ret = virNetClientIOReadMessage(client);
if (ret == 0)
return 0; /* Blocking on read */ if (ret < 0)
return -1;
if (ret == 0)
return 0; /* Blocking on read */
}
/* Check for completion of our goal */ /* Check for completion of our goal */
if (client->msg.bufferOffset == client->msg.bufferLength) { if (client->msg.bufferOffset == client->msg.bufferLength) {
@ -842,6 +841,33 @@ virNetClientIOHandleInput(virNetClientPtr client)
* next iteration. * next iteration.
*/ */
} else { } else {
if (virNetMessageDecodeHeader(&client->msg) < 0)
return -1;
if (client->msg.header.type == VIR_NET_REPLY_WITH_FDS) {
size_t i;
if (virNetMessageDecodeNumFDs(&client->msg) < 0)
return -1;
for (i = client->msg.donefds ; i < client->msg.nfds ; i++) {
int rv;
if ((rv = virNetSocketRecvFD(client->sock, &(client->msg.fds[i]))) < 0)
return -1;
if (rv == 0) /* Blocking */
break;
client->msg.donefds++;
}
if (client->msg.donefds < client->msg.nfds) {
/* Because DecodeHeader/NumFDs reset bufferOffset, we
* put it back to what it was, so everything works
* again next time we run this method
*/
client->msg.bufferOffset = client->msg.bufferLength;
return 0; /* Blocking on more fds */
}
}
ret = virNetClientCallDispatch(client); ret = virNetClientCallDispatch(client);
client->msg.bufferOffset = client->msg.bufferLength = 0; client->msg.bufferOffset = client->msg.bufferLength = 0;
/* /*
@ -1257,6 +1283,7 @@ int virNetClientSend(virNetClientPtr client,
goto cleanup; goto cleanup;
} }
msg->donefds = 0;
if (msg->bufferLength) if (msg->bufferLength)
call->mode = VIR_NET_CLIENT_MODE_WAIT_TX; call->mode = VIR_NET_CLIENT_MODE_WAIT_TX;
else else

View File

@ -48,6 +48,7 @@ struct _virNetMessage {
size_t nfds; size_t nfds;
int *fds; int *fds;
size_t donefds;
virNetMessagePtr next; virNetMessagePtr next;
}; };

View File

@ -771,9 +771,11 @@ static ssize_t virNetServerClientRead(virNetServerClientPtr client)
static void virNetServerClientDispatchRead(virNetServerClientPtr client) static void virNetServerClientDispatchRead(virNetServerClientPtr client)
{ {
readmore: readmore:
if (virNetServerClientRead(client) < 0) { if (client->rx->nfds == 0) {
client->wantClose = true; if (virNetServerClientRead(client) < 0) {
return; /* Error */ client->wantClose = true;
return; /* Error */
}
} }
if (client->rx->bufferOffset < client->rx->bufferLength) if (client->rx->bufferOffset < client->rx->bufferLength)
@ -794,7 +796,7 @@ readmore:
goto readmore; goto readmore;
} else { } else {
/* Grab the completed message */ /* Grab the completed message */
virNetMessagePtr msg = virNetMessageQueueServe(&client->rx); virNetMessagePtr msg = client->rx;
virNetServerClientFilterPtr filter; virNetServerClientFilterPtr filter;
size_t i; size_t i;
@ -805,20 +807,40 @@ readmore:
return; return;
} }
/* Now figure out if we need to read more data to get some
* file descriptors */
if (msg->header.type == VIR_NET_CALL_WITH_FDS && if (msg->header.type == VIR_NET_CALL_WITH_FDS &&
virNetMessageDecodeNumFDs(msg) < 0) { virNetMessageDecodeNumFDs(msg) < 0) {
virNetMessageFree(msg); virNetMessageFree(msg);
client->wantClose = true; client->wantClose = true;
return; return; /* Error */
} }
for (i = 0 ; i < msg->nfds ; i++) {
if ((msg->fds[i] = virNetSocketRecvFD(client->sock)) < 0) { /* Try getting the file descriptors (may fail if blocking) */
for (i = msg->donefds ; i < msg->nfds ; i++) {
int rv;
if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
virNetMessageFree(msg); virNetMessageFree(msg);
client->wantClose = true; client->wantClose = true;
return; return;
} }
if (rv == 0) /* Blocking */
break;
msg->donefds++;
} }
/* Need to poll() until FDs arrive */
if (msg->donefds < msg->nfds) {
/* Because DecodeHeader/NumFDs reset bufferOffset, we
* put it back to what it was, so everything works
* again next time we run this method
*/
client->rx->bufferOffset = client->rx->bufferLength;
return;
}
/* Definitely finished reading, so remove from queue */
virNetMessageQueueServe(&client->rx);
PROBE(RPC_SERVER_CLIENT_MSG_RX, PROBE(RPC_SERVER_CLIENT_MSG_RX,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u", "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
client, msg->bufferLength, client, msg->bufferLength,
@ -912,25 +934,30 @@ static void
virNetServerClientDispatchWrite(virNetServerClientPtr client) virNetServerClientDispatchWrite(virNetServerClientPtr client)
{ {
while (client->tx) { while (client->tx) {
ssize_t ret; if (client->tx->bufferOffset < client->tx->bufferLength) {
ssize_t ret;
ret = virNetServerClientWrite(client); ret = virNetServerClientWrite(client);
if (ret < 0) { if (ret < 0) {
client->wantClose = true; client->wantClose = true;
return; return;
}
if (ret == 0)
return; /* Would block on write EAGAIN */
} }
if (ret == 0)
return; /* Would block on write EAGAIN */
if (client->tx->bufferOffset == client->tx->bufferLength) { if (client->tx->bufferOffset == client->tx->bufferLength) {
virNetMessagePtr msg; virNetMessagePtr msg;
size_t i; size_t i;
for (i = 0 ; i < client->tx->nfds ; i++) { for (i = client->tx->donefds ; i < client->tx->nfds ; i++) {
if (virNetSocketSendFD(client->sock, client->tx->fds[i]) < 0) { int rv;
if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
client->wantClose = true; client->wantClose = true;
return; return;
} }
if (rv == 0) /* Blocking */
return;
client->tx->donefds++;
} }
#if HAVE_SASL #if HAVE_SASL
@ -1041,6 +1068,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client,
msg->bufferLength, msg->bufferOffset); msg->bufferLength, msg->bufferOffset);
virNetServerClientLock(client); virNetServerClientLock(client);
msg->donefds = 0;
if (client->sock && !client->wantClose) { if (client->sock && !client->wantClose) {
PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE, PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
"client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u", "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",

View File

@ -1142,6 +1142,9 @@ ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
} }
/*
* Returns 1 if an FD was sent, 0 if it would block, -1 on error
*/
int virNetSocketSendFD(virNetSocketPtr sock, int fd) int virNetSocketSendFD(virNetSocketPtr sock, int fd)
{ {
int ret = -1; int ret = -1;
@ -1154,12 +1157,15 @@ int virNetSocketSendFD(virNetSocketPtr sock, int fd)
PROBE(RPC_SOCKET_SEND_FD, PROBE(RPC_SOCKET_SEND_FD,
"sock=%p fd=%d", sock, fd); "sock=%p fd=%d", sock, fd);
if (sendfd(sock->fd, fd) < 0) { if (sendfd(sock->fd, fd) < 0) {
virReportSystemError(errno, if (errno == EAGAIN)
_("Failed to send file descriptor %d"), ret = 0;
fd); else
virReportSystemError(errno,
_("Failed to send file descriptor %d"),
fd);
goto cleanup; goto cleanup;
} }
ret = 0; ret = 1;
cleanup: cleanup:
virMutexUnlock(&sock->lock); virMutexUnlock(&sock->lock);
@ -1167,9 +1173,15 @@ cleanup:
} }
int virNetSocketRecvFD(virNetSocketPtr sock) /*
* Returns 1 if an FD was read, 0 if it would block, -1 on error
*/
int virNetSocketRecvFD(virNetSocketPtr sock, int *fd)
{ {
int ret = -1; int ret = -1;
*fd = -1;
if (!virNetSocketHasPassFD(sock)) { if (!virNetSocketHasPassFD(sock)) {
virNetError(VIR_ERR_INTERNAL_ERROR, virNetError(VIR_ERR_INTERNAL_ERROR,
_("Receiving file descriptors is not supported on this socket")); _("Receiving file descriptors is not supported on this socket"));
@ -1177,13 +1189,17 @@ int virNetSocketRecvFD(virNetSocketPtr sock)
} }
virMutexLock(&sock->lock); virMutexLock(&sock->lock);
if ((ret = recvfd(sock->fd, O_CLOEXEC)) < 0) { if ((*fd = recvfd(sock->fd, O_CLOEXEC)) < 0) {
virReportSystemError(errno, "%s", if (errno == EAGAIN)
_("Failed to recv file descriptor")); ret = 0;
else
virReportSystemError(errno, "%s",
_("Failed to recv file descriptor"));
goto cleanup; goto cleanup;
} }
PROBE(RPC_SOCKET_RECV_FD, PROBE(RPC_SOCKET_RECV_FD,
"sock=%p fd=%d", sock, ret); "sock=%p fd=%d", sock, *fd);
ret = 1;
cleanup: cleanup:
virMutexUnlock(&sock->lock); virMutexUnlock(&sock->lock);

View File

@ -97,7 +97,7 @@ ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
int virNetSocketSendFD(virNetSocketPtr sock, int fd); int virNetSocketSendFD(virNetSocketPtr sock, int fd);
int virNetSocketRecvFD(virNetSocketPtr sock); int virNetSocketRecvFD(virNetSocketPtr sock, int *fd);
void virNetSocketSetTLSSession(virNetSocketPtr sock, void virNetSocketSetTLSSession(virNetSocketPtr sock,
virNetTLSSessionPtr sess); virNetTLSSessionPtr sess);