diff --git a/tests/virnetsockettest.c b/tests/virnetsockettest.c index e463d432ff..cccb90d0be 100644 --- a/tests/virnetsockettest.c +++ b/tests/virnetsockettest.c @@ -115,6 +115,56 @@ checkProtocols(bool *hasIPv4, bool *hasIPv6, return ret; } +struct testClientData { + const char *path; + const char *cnode; + const char *portstr; +}; + +static void +testSocketClient(void *opaque) +{ + struct testClientData *data = opaque; + char c; + virNetSocketPtr csock = NULL; + + if (data->path) { + if (virNetSocketNewConnectUNIX(data->path, false, + NULL, &csock) < 0) + return; + } else { + if (virNetSocketNewConnectTCP(data->cnode, data->portstr, + AF_UNSPEC, + &csock) < 0) + return; + } + + virNetSocketSetBlocking(csock, true); + + if (virNetSocketRead(csock, &c, 1) != 1) { + VIR_DEBUG("Cannot read from server"); + goto done; + } + if (virNetSocketWrite(csock, &c, 1) != 1) { + VIR_DEBUG("Cannot write to server"); + goto done; + } + + done: + virObjectUnref(csock); +} + + +static void +testSocketIncoming(virNetSocketPtr sock, + int events ATTRIBUTE_UNUSED, + void *opaque) +{ + virNetSocketPtr *retsock = opaque; + VIR_DEBUG("Incoming sock=%p events=%d\n", sock, events); + *retsock = sock; +} + struct testSocketData { const char *lnode; @@ -122,18 +172,25 @@ struct testSocketData { const char *cnode; }; -static int testSocketAccept(const void *opaque) + +static int +testSocketAccept(const void *opaque) { virNetSocketPtr *lsock = NULL; /* Listen socket */ size_t nlsock = 0, i; virNetSocketPtr ssock = NULL; /* Server socket */ - virNetSocketPtr csock = NULL; /* Client socket */ + virNetSocketPtr rsock = NULL; /* Incoming client socket */ const struct testSocketData *data = opaque; int ret = -1; char portstr[100]; char *tmpdir = NULL; char *path = NULL; char template[] = "/tmp/libvirt_XXXXXX"; + virThread th; + struct testClientData cdata = { 0 }; + bool goodsock = false; + char a = 'a'; + char b = '\0'; if (!data) { virNetSocketPtr usock; @@ -155,50 +212,90 @@ static int testSocketAccept(const void *opaque) lsock[0] = usock; nlsock = 1; + + cdata.path = path; } else { snprintf(portstr, sizeof(portstr), "%d", data->port); if (virNetSocketNewListenTCP(data->lnode, portstr, AF_UNSPEC, &lsock, &nlsock) < 0) goto cleanup; + + cdata.cnode = data->cnode; + cdata.portstr = portstr; } for (i = 0; i < nlsock; i++) { if (virNetSocketListen(lsock[i], 0) < 0) goto cleanup; + + if (virNetSocketAddIOCallback(lsock[i], + VIR_EVENT_HANDLE_READABLE, + testSocketIncoming, + &rsock, + NULL) < 0) { + goto cleanup; + } } - if (!data) { - if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0) - goto cleanup; - } else { - if (virNetSocketNewConnectTCP(data->cnode, portstr, - AF_UNSPEC, - &csock) < 0) - goto cleanup; - } + if (virThreadCreate(&th, true, + testSocketClient, + &cdata) < 0) + goto cleanup; - virObjectUnref(csock); + while (rsock == NULL) + virEventRunDefaultImpl(); for (i = 0; i < nlsock; i++) { - if (virNetSocketAccept(lsock[i], &ssock) != -1 && ssock) { - char c = 'a'; - if (virNetSocketWrite(ssock, &c, 1) != -1 && - virNetSocketRead(ssock, &c, 1) != -1) { - VIR_DEBUG("Unexpected client socket present"); - goto cleanup; - } + if (lsock[i] == rsock) { + goodsock = true; + break; } - virObjectUnref(ssock); - ssock = NULL; } + if (!goodsock) { + virReportError(VIR_ERR_INTERNAL_ERROR, "%s", + "Unexpected server socket seen"); + goto join; + } + + if (virNetSocketAccept(rsock, &ssock) < 0) + goto join; + + if (!ssock) { + virReportError(VIR_ERR_INTERNAL_ERROR, "%s", + "Client went away unexpectedly"); + goto join; + } + + virNetSocketSetBlocking(ssock, true); + + if (virNetSocketWrite(ssock, &a, 1) < 0 || + virNetSocketRead(ssock, &b, 1) < 0) { + goto join; + } + + if (a != b) { + virReportError(VIR_ERR_INTERNAL_ERROR, + "Bad data received '%x' != '%x'", a, b); + goto join; + } + + virObjectUnref(ssock); + ssock = NULL; + ret = 0; + join: + virThreadJoin(&th); + cleanup: virObjectUnref(ssock); - for (i = 0; i < nlsock; i++) + for (i = 0; i < nlsock; i++) { + virNetSocketRemoveIOCallback(lsock[i]); + virNetSocketClose(lsock[i]); virObjectUnref(lsock[i]); + } VIR_FREE(lsock); VIR_FREE(path); if (tmpdir) @@ -431,6 +528,8 @@ mymain(void) signal(SIGPIPE, SIG_IGN); + virEventRegisterDefaultImpl(); + #ifdef HAVE_IFADDRS_H if (checkProtocols(&hasIPv4, &hasIPv6, &freePort) < 0) { fprintf(stderr, "Cannot identify IPv4/6 availability\n");