Merge pull request #6680 from mhinz/listen/localhost

Use uv_getaddrinfo() for servers
This commit is contained in:
James McCoy 2017-05-28 13:26:06 +00:00 committed by GitHub
commit 9cc185dc6d
6 changed files with 169 additions and 78 deletions

View File

@ -6460,11 +6460,20 @@ serverlist() *serverlist()*
nvim --cmd "echo serverlist()" --cmd "q" nvim --cmd "echo serverlist()" --cmd "q"
< <
serverstart([{address}]) *serverstart()* serverstart([{address}]) *serverstart()*
Opens a named pipe or TCP socket at {address} for clients to Opens a TCP socket (IPv4/IPv6), Unix domain socket (Unix),
connect to and returns {address}. If no address is given, it or named pipe (Windows) at {address} for clients to connect
is equivalent to: > to and returns {address}.
If {address} contains `:`, a TCP socket is used. Everything in
front of the last occurrence of `:` is the IP or hostname,
everything after it the port. If the port is empty or `0`,
a random port will be assigned.
If no address is given, it is equivalent to: >
:call serverstart(tempname()) :call serverstart(tempname())
< |$NVIM_LISTEN_ADDRESS| is set to {address} if not already set. < |$NVIM_LISTEN_ADDRESS| is set to {address} if not already set.
*--servername* *--servername*
The Vim command-line option `--servername` can be imitated: > The Vim command-line option `--servername` can be imitated: >
nvim --cmd "let g:server_addr = serverstart('foo')" nvim --cmd "let g:server_addr = serverstart('foo')"

View File

@ -14321,22 +14321,39 @@ static void f_serverstart(typval_T *argvars, typval_T *rettv, FunPtr fptr)
return; return;
} }
char *address;
// If the user supplied an address, use it, otherwise use a temp. // If the user supplied an address, use it, otherwise use a temp.
if (argvars[0].v_type != VAR_UNKNOWN) { if (argvars[0].v_type != VAR_UNKNOWN) {
if (argvars[0].v_type != VAR_STRING) { if (argvars[0].v_type != VAR_STRING) {
EMSG(_(e_invarg)); EMSG(_(e_invarg));
return; return;
} else { } else {
rettv->vval.v_string = (char_u *)xstrdup(tv_get_string(argvars)); address = xstrdup(tv_get_string(argvars));
} }
} else { } else {
rettv->vval.v_string = (char_u *)server_address_new(); address = server_address_new();
} }
int result = server_start((char *) rettv->vval.v_string); int result = server_start(address);
xfree(address);
if (result != 0) { if (result != 0) {
EMSG2("Failed to start server: %s", uv_strerror(result)); EMSG2("Failed to start server: %s",
result > 0 ? "Unknonwn system error" : uv_strerror(result));
return;
} }
// Since it's possible server_start adjusted the given {address} (e.g.,
// "localhost:" will now have a port), return the final value to the user.
size_t n;
char **addrs = server_address_list(&n);
rettv->vval.v_string = (char_u *)addrs[n - 1];
n--;
for (size_t i = 0; i < n; i++) {
xfree(addrs[i]);
}
xfree(addrs);
} }
/// "serverstop()" function /// "serverstop()" function

View File

@ -17,60 +17,53 @@
#include "nvim/path.h" #include "nvim/path.h"
#include "nvim/memory.h" #include "nvim/memory.h"
#include "nvim/macros.h" #include "nvim/macros.h"
#include "nvim/charset.h"
#include "nvim/log.h"
#ifdef INCLUDE_GENERATED_DECLARATIONS #ifdef INCLUDE_GENERATED_DECLARATIONS
# include "event/socket.c.generated.h" # include "event/socket.c.generated.h"
#endif #endif
#define NVIM_DEFAULT_TCP_PORT 7450 int socket_watcher_init(Loop *loop, SocketWatcher *watcher,
const char *endpoint)
void socket_watcher_init(Loop *loop, SocketWatcher *watcher, FUNC_ATTR_NONNULL_ALL
const char *endpoint, void *data)
FUNC_ATTR_NONNULL_ARG(1) FUNC_ATTR_NONNULL_ARG(2) FUNC_ATTR_NONNULL_ARG(3)
{ {
// Trim to `ADDRESS_MAX_SIZE` xstrlcpy(watcher->addr, endpoint, sizeof(watcher->addr));
if (xstrlcpy(watcher->addr, endpoint, sizeof(watcher->addr)) char *addr = watcher->addr;
>= sizeof(watcher->addr)) { char *host_end = strrchr(addr, ':');
// TODO(aktau): since this is not what the user wanted, perhaps we
// should return an error here
WLOG("Address was too long, truncated to %s", watcher->addr);
}
bool tcp = true; if (host_end && addr != host_end) {
char ip[16], *ip_end = xstrchrnul(watcher->addr, ':'); // Split user specified address into two strings, addr(hostname) and port.
// The port part in watcher->addr will be updated later.
*host_end = '\0';
char *port = host_end + 1;
intmax_t iport;
// (ip_end - addr) is always > 0, so convert to size_t int ret = getdigits_safe(&(char_u *){ (char_u *)port }, &iport);
size_t addr_len = (size_t)(ip_end - watcher->addr); if (ret == FAIL || iport < 0 || iport > UINT16_MAX) {
ELOG("Invalid port: %s", port);
if (addr_len > sizeof(ip) - 1) { return UV_EINVAL;
// Maximum length of an IPv4 address buffer is 15 (eg: 255.255.255.255)
addr_len = sizeof(ip) - 1;
}
// Extract the address part
xstrlcpy(ip, watcher->addr, addr_len + 1);
int port = NVIM_DEFAULT_TCP_PORT;
if (*ip_end == ':') {
// Extract the port
long lport = strtol(ip_end + 1, NULL, 10); // NOLINT
if (lport <= 0 || lport > 0xffff) {
// Invalid port, treat as named pipe or unix socket
tcp = false;
} else {
port = (int) lport;
} }
}
if (tcp) { if (*port == NUL) {
// Try to parse ip address // When no port is given, (uv_)getaddrinfo expects NULL otherwise the
if (uv_ip4_addr(ip, port, &watcher->uv.tcp.addr)) { // implementation may attempt to lookup the service by name (and fail)
// Invalid address, treat as named pipe or unix socket port = NULL;
tcp = false;
} }
}
if (tcp) { uv_getaddrinfo_t request;
int retval = uv_getaddrinfo(&loop->uv, &request, NULL, addr, port,
&(struct addrinfo){
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM,
});
if (retval != 0) {
ELOG("Host lookup failed: %s", endpoint);
return retval;
}
watcher->uv.tcp.addrinfo = request.addrinfo;
uv_tcp_init(&loop->uv, &watcher->uv.tcp.handle); uv_tcp_init(&loop->uv, &watcher->uv.tcp.handle);
watcher->stream = STRUCT_CAST(uv_stream_t, &watcher->uv.tcp.handle); watcher->stream = STRUCT_CAST(uv_stream_t, &watcher->uv.tcp.handle);
} else { } else {
@ -82,33 +75,60 @@ void socket_watcher_init(Loop *loop, SocketWatcher *watcher,
watcher->cb = NULL; watcher->cb = NULL;
watcher->close_cb = NULL; watcher->close_cb = NULL;
watcher->events = NULL; watcher->events = NULL;
watcher->data = NULL;
return 0;
} }
int socket_watcher_start(SocketWatcher *watcher, int backlog, socket_cb cb) int socket_watcher_start(SocketWatcher *watcher, int backlog, socket_cb cb)
FUNC_ATTR_NONNULL_ALL FUNC_ATTR_NONNULL_ALL
{ {
watcher->cb = cb; watcher->cb = cb;
int result; int result = UV_EINVAL;
if (watcher->stream->type == UV_TCP) { if (watcher->stream->type == UV_TCP) {
result = uv_tcp_bind(&watcher->uv.tcp.handle, struct addrinfo *ai = watcher->uv.tcp.addrinfo;
(const struct sockaddr *)&watcher->uv.tcp.addr, 0);
for (; ai; ai = ai->ai_next) {
result = uv_tcp_bind(&watcher->uv.tcp.handle, ai->ai_addr, 0);
if (result != 0) {
continue;
}
result = uv_listen(watcher->stream, backlog, connection_cb);
if (result == 0) {
struct sockaddr_storage sas;
// When the endpoint in socket_watcher_init() didn't specify a port
// number, a free random port number will be assigned. sin_port will
// contain 0 in this case, unless uv_tcp_getsockname() is used first.
uv_tcp_getsockname(&watcher->uv.tcp.handle, (struct sockaddr *)&sas,
&(int){ sizeof(sas) });
uint16_t port = (uint16_t)((sas.ss_family == AF_INET)
? ((struct sockaddr_in *)&sas)->sin_port
: ((struct sockaddr_in6 *)&sas)->sin6_port);
// v:servername uses the string from watcher->addr
size_t len = strlen(watcher->addr);
snprintf(watcher->addr+len, sizeof(watcher->addr)-len, ":%" PRIu16,
ntohs(port));
break;
}
}
uv_freeaddrinfo(watcher->uv.tcp.addrinfo);
} else { } else {
result = uv_pipe_bind(&watcher->uv.pipe.handle, watcher->addr); result = uv_pipe_bind(&watcher->uv.pipe.handle, watcher->addr);
} if (result == 0) {
result = uv_listen(watcher->stream, backlog, connection_cb);
if (result == 0) { }
result = uv_listen(watcher->stream, backlog, connection_cb);
} }
assert(result <= 0); // libuv should return negative error code or zero. assert(result <= 0); // libuv should return negative error code or zero.
if (result < 0) { if (result < 0) {
if (result == -EACCES) { if (result == UV_EACCES) {
// Libuv converts ENOENT to EACCES for Windows compatibility, but if // Libuv converts ENOENT to EACCES for Windows compatibility, but if
// the parent directory does not exist, ENOENT would be more accurate. // the parent directory does not exist, ENOENT would be more accurate.
*path_tail((char_u *)watcher->addr) = NUL; *path_tail((char_u *)watcher->addr) = NUL;
if (!os_path_exists((char_u *)watcher->addr)) { if (!os_path_exists((char_u *)watcher->addr)) {
result = -ENOENT; result = UV_ENOENT;
} }
} }
return result; return result;

View File

@ -20,7 +20,7 @@ struct socket_watcher {
union { union {
struct { struct {
uv_tcp_t handle; uv_tcp_t handle;
struct sockaddr_in addr; struct addrinfo *addrinfo;
} tcp; } tcp;
struct { struct {
uv_pipe_t handle; uv_pipe_t handle;

View File

@ -97,37 +97,47 @@ char *server_address_new(void)
#endif #endif
} }
/// Starts listening for API calls on the TCP address or pipe path `endpoint`. /// Starts listening for API calls.
/// The socket type is determined by parsing `endpoint`: If it's a valid IPv4
/// address in 'ip[:port]' format, then it will be TCP socket. The port is
/// optional and if omitted defaults to NVIM_DEFAULT_TCP_PORT. Otherwise it
/// will be a unix socket or named pipe.
/// ///
/// @param endpoint Address of the server. Either a 'ip[:port]' string or an /// The socket type is determined by parsing `endpoint`: If it's a valid IPv4
/// arbitrary identifier (trimmed to 256 bytes) for the unix socket or /// or IPv6 address in 'ip:[port]' format, then it will be a TCP socket.
/// named pipe. /// Otherwise it will be a Unix socket or named pipe (Windows).
///
/// If no port is given, a random one will be assigned.
///
/// @param endpoint Address of the server. Either a 'ip:[port]' string or an
/// arbitrary identifier (trimmed to 256 bytes) for the Unix
/// socket or named pipe.
/// @returns 0 on success, 1 on a regular error, and negative errno /// @returns 0 on success, 1 on a regular error, and negative errno
/// on failure to bind or connect. /// on failure to bind or listen.
int server_start(const char *endpoint) int server_start(const char *endpoint)
{ {
if (endpoint == NULL) { if (endpoint == NULL || endpoint[0] == '\0') {
ELOG("Attempting to start server on NULL endpoint"); ELOG("Empty or NULL endpoint");
return 1; return 1;
} }
SocketWatcher *watcher = xmalloc(sizeof(SocketWatcher)); SocketWatcher *watcher = xmalloc(sizeof(SocketWatcher));
socket_watcher_init(&main_loop, watcher, endpoint, NULL);
int result = socket_watcher_init(&main_loop, watcher, endpoint);
if (result < 0) {
xfree(watcher);
return result;
}
// Check if a watcher for the endpoint already exists // Check if a watcher for the endpoint already exists
for (int i = 0; i < watchers.ga_len; i++) { for (int i = 0; i < watchers.ga_len; i++) {
if (!strcmp(watcher->addr, ((SocketWatcher **)watchers.ga_data)[i]->addr)) { if (!strcmp(watcher->addr, ((SocketWatcher **)watchers.ga_data)[i]->addr)) {
ELOG("Already listening on %s", watcher->addr); ELOG("Already listening on %s", watcher->addr);
if (watcher->stream->type == UV_TCP) {
uv_freeaddrinfo(watcher->uv.tcp.addrinfo);
}
socket_watcher_close(watcher, free_server); socket_watcher_close(watcher, free_server);
return 1; return 1;
} }
} }
int result = socket_watcher_start(watcher, MAX_CONNECTIONS, connection_cb); result = socket_watcher_start(watcher, MAX_CONNECTIONS, connection_cb);
if (result < 0) { if (result < 0) {
ELOG("Failed to start server: %s", uv_strerror(result)); ELOG("Failed to start server: %s", uv_strerror(result));
socket_watcher_close(watcher, free_server); socket_watcher_close(watcher, free_server);

View File

@ -1,20 +1,27 @@
local helpers = require('test.functional.helpers')(after_each) local helpers = require('test.functional.helpers')(after_each)
local nvim, eq, neq, eval = helpers.nvim, helpers.eq, helpers.neq, helpers.eval local eq, neq, eval = helpers.eq, helpers.neq, helpers.eval
local command = helpers.command
local clear, funcs, meths = helpers.clear, helpers.funcs, helpers.meths local clear, funcs, meths = helpers.clear, helpers.funcs, helpers.meths
local os_name = helpers.os_name local os_name = helpers.os_name
local function clear_serverlist()
for _, server in pairs(funcs.serverlist()) do
funcs.serverstop(server)
end
end
describe('serverstart(), serverstop()', function() describe('serverstart(), serverstop()', function()
before_each(clear) before_each(clear)
it('sets $NVIM_LISTEN_ADDRESS on first invocation', function() it('sets $NVIM_LISTEN_ADDRESS on first invocation', function()
-- Unset $NVIM_LISTEN_ADDRESS -- Unset $NVIM_LISTEN_ADDRESS
nvim('command', 'let $NVIM_LISTEN_ADDRESS = ""') command('let $NVIM_LISTEN_ADDRESS = ""')
local s = eval('serverstart()') local s = eval('serverstart()')
assert(s ~= nil and s:len() > 0, "serverstart() returned empty") assert(s ~= nil and s:len() > 0, "serverstart() returned empty")
eq(s, eval('$NVIM_LISTEN_ADDRESS')) eq(s, eval('$NVIM_LISTEN_ADDRESS'))
nvim('command', "call serverstop('"..s.."')") command("call serverstop('"..s.."')")
eq('', eval('$NVIM_LISTEN_ADDRESS')) eq('', eval('$NVIM_LISTEN_ADDRESS'))
end) end)
@ -47,10 +54,38 @@ describe('serverstart(), serverstop()', function()
end) end)
it('serverstop() ignores invalid input', function() it('serverstop() ignores invalid input', function()
nvim('command', "call serverstop('')") command("call serverstop('')")
nvim('command', "call serverstop('bogus-socket-name')") command("call serverstop('bogus-socket-name')")
end) end)
it('parses endpoints correctly', function()
clear_serverlist()
eq({}, funcs.serverlist())
local s = funcs.serverstart('127.0.0.1:0') -- assign random port
assert(string.match(s, '127.0.0.1:%d+'))
eq(s, funcs.serverlist()[1])
clear_serverlist()
s = funcs.serverstart('127.0.0.1:') -- assign random port
assert(string.match(s, '127.0.0.1:%d+'))
eq(s, funcs.serverlist()[1])
clear_serverlist()
funcs.serverstart('127.0.0.1:12345')
funcs.serverstart('127.0.0.1:12345') -- exists already; ignore
funcs.serverstart('::1:12345')
funcs.serverstart('::1:12345') -- exists already; ignore
local expected = {
'127.0.0.1:12345',
'::1:12345',
}
eq(expected, funcs.serverlist())
clear_serverlist()
funcs.serverstart('127.0.0.1:65536') -- invalid port
eq({}, funcs.serverlist())
end)
end) end)
describe('serverlist()', function() describe('serverlist()', function()
@ -75,7 +110,7 @@ describe('serverlist()', function()
-- The new servers should be at the end of the list. -- The new servers should be at the end of the list.
for i = 1, #servs do for i = 1, #servs do
eq(servs[i], new_servs[i + n]) eq(servs[i], new_servs[i + n])
nvim('command', "call serverstop('"..servs[i].."')") command("call serverstop('"..servs[i].."')")
end end
-- After serverstop() the servers should NOT be in the list. -- After serverstop() the servers should NOT be in the list.
eq(n, eval('len(serverlist())')) eq(n, eval('len(serverlist())'))