Implement client tunnel synchronization

Client now polls server state and updates tunnels to match.
This commit is contained in:
Anders Pitman 2020-10-10 09:55:07 -06:00
parent 41bd4759eb
commit eb4d6903c7
2 changed files with 106 additions and 35 deletions

View File

@ -221,7 +221,7 @@ func (p *BoringProxy) handleCreateTunnel(w http.ResponseWriter, r *http.Request)
_, err = p.tunMan.CreateTunnelForClient(domain, clientName, clientPort) _, err = p.tunMan.CreateTunnelForClient(domain, clientName, clientPort)
if err != nil { if err != nil {
w.WriteHeader(400) w.WriteHeader(400)
io.WriteString(w, "Failed to get cert. Ensure your domain is valid") io.WriteString(w, err.Error())
return return
} }

139
client.go
View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
@ -12,18 +13,22 @@ import (
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
"os/signal"
"sync" "sync"
"time"
) )
type BoringProxyClient struct { type BoringProxyClient struct {
httpClient *http.Client
tunnels map[string]Tunnel
previousEtag string
server string
token string
clientName string
cancelFuncs map[string]context.CancelFunc
cancelFuncsMutex *sync.Mutex
} }
func NewBoringProxyClient() *BoringProxyClient { func NewBoringProxyClient() *BoringProxyClient {
return &BoringProxyClient{}
}
func (c *BoringProxyClient) RunPuppetClient() {
flagSet := flag.NewFlagSet(os.Args[0], flag.ExitOnError) flagSet := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
server := flagSet.String("server", "", "boringproxy server") server := flagSet.String("server", "", "boringproxy server")
token := flagSet.String("token", "", "Access token") token := flagSet.String("token", "", "Access token")
@ -31,58 +36,113 @@ func (c *BoringProxyClient) RunPuppetClient() {
flagSet.Parse(os.Args[2:]) flagSet.Parse(os.Args[2:])
httpClient := &http.Client{} httpClient := &http.Client{}
tunnels := make(map[string]Tunnel)
cancelFuncs := make(map[string]context.CancelFunc)
cancelFuncsMutex := &sync.Mutex{}
url := fmt.Sprintf("https://%s/api/tunnels?client-name=%s", *server, *name) return &BoringProxyClient{
httpClient: httpClient,
tunnels: tunnels,
previousEtag: "",
server: *server,
token: *token,
clientName: *name,
cancelFuncs: cancelFuncs,
cancelFuncsMutex: cancelFuncsMutex,
}
}
func (c *BoringProxyClient) RunPuppetClient() {
for {
c.PollTunnels()
time.Sleep(2 * time.Second)
}
}
func (c *BoringProxyClient) PollTunnels() {
url := fmt.Sprintf("https://%s/api/tunnels?client-name=%s", c.server, c.clientName)
listenReq, err := http.NewRequest("GET", url, nil) listenReq, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
log.Fatal("Failed making request", err) log.Fatal("Failed making request", err)
} }
if len(*token) > 0 { if len(c.token) > 0 {
listenReq.Header.Add("Authorization", "bearer "+*token) listenReq.Header.Add("Authorization", "bearer "+c.token)
} }
resp, err := httpClient.Do(listenReq) resp, err := c.httpClient.Do(listenReq)
if err != nil { if err != nil {
log.Fatal("Failed make tunnel request", err) log.Fatal("Failed listen request", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
log.Fatal("Failed to create tunnel: " + string(body)) log.Fatal("Failed to listen (not 200 status)")
} }
tunnels := make(map[string]Tunnel) etag := resp.Header["Etag"][0]
err = json.Unmarshal(body, &tunnels) if etag != c.previousEtag {
if err != nil {
log.Fatal("Failed to parse response", err) body, err := ioutil.ReadAll(resp.Body)
tunnels := make(map[string]Tunnel)
err = json.Unmarshal(body, &tunnels)
if err != nil {
log.Fatal("Failed to parse response", err)
}
c.SyncTunnels(tunnels)
c.previousEtag = etag
} }
for _, tun := range tunnels { }
go c.BoreTunnel(tun)
func (c *BoringProxyClient) SyncTunnels(serverTunnels map[string]Tunnel) {
fmt.Println("SyncTunnels")
// update tunnels to match server
for k, newTun := range serverTunnels {
tun, exists := c.tunnels[k]
if !exists {
log.Println("New tunnel", k)
c.tunnels[k] = newTun
cancel := c.BoreTunnel(newTun)
c.cancelFuncs[k] = cancel
} else if newTun != tun {
log.Println("Restart tunnel", k)
c.cancelFuncs[k]()
cancel := c.BoreTunnel(newTun)
c.cancelFuncs[k] = cancel
}
} }
//go c.BoreTunnel(tunnels["apitman.com"]) // delete any tunnels that no longer exist on server
for k, _ := range c.tunnels {
sigChan := make(chan os.Signal, 1) _, exists := serverTunnels[k]
signal.Notify(sigChan, os.Interrupt) if !exists {
for range sigChan { log.Println("Kill tunnel", k)
break c.cancelFuncs[k]()
delete(c.tunnels, k)
delete(c.cancelFuncs, k)
}
} }
} }
func (c *BoringProxyClient) BoreTunnel(tun Tunnel) { func (c *BoringProxyClient) BoreTunnel(tun Tunnel) context.CancelFunc {
//log.Println("BoreTunnel", tun)
privKeyFile, err := ioutil.TempFile("", "") privKeyFile, err := ioutil.TempFile("", "")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer os.Remove(privKeyFile.Name())
if _, err := privKeyFile.Write([]byte(tun.TunnelPrivateKey)); err != nil { if _, err := privKeyFile.Write([]byte(tun.TunnelPrivateKey)); err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -93,12 +153,23 @@ func (c *BoringProxyClient) BoreTunnel(tun Tunnel) {
tunnelSpec := fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", tun.TunnelPort, tun.ClientPort) tunnelSpec := fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", tun.TunnelPort, tun.ClientPort)
sshLogin := fmt.Sprintf("%s@%s", tun.Username, tun.ServerAddress) sshLogin := fmt.Sprintf("%s@%s", tun.Username, tun.ServerAddress)
serverPortStr := fmt.Sprintf("%d", tun.ServerPort) serverPortStr := fmt.Sprintf("%d", tun.ServerPort)
fmt.Println(tunnelSpec, sshLogin, serverPortStr)
cmd := exec.Command("ssh", "-i", privKeyFile.Name(), "-NR", tunnelSpec, sshLogin, "-p", serverPortStr) ctx, cancelFunc := context.WithCancel(context.Background())
err = cmd.Run()
if err != nil { privKeyPath := privKeyFile.Name()
log.Fatal(err)
} go func() {
// TODO: Clean up private key files on exit
defer os.Remove(privKeyPath)
fmt.Println(privKeyPath, tunnelSpec, sshLogin, serverPortStr)
cmd := exec.CommandContext(ctx, "ssh", "-i", privKeyPath, "-NR", tunnelSpec, sshLogin, "-p", serverPortStr)
err = cmd.Run()
if err != nil {
log.Print(err)
}
}()
return cancelFunc
} }
func (c *BoringProxyClient) Run() { func (c *BoringProxyClient) Run() {