diff --git a/api.go b/api.go index d733fc0..23c5347 100644 --- a/api.go +++ b/api.go @@ -274,9 +274,14 @@ func (a *Api) GetTunnel(tokenData TokenData, params url.Values) (Tunnel, error) return Tunnel{}, errors.New("Invalid domain parameter") } - tun, exists := a.db.GetTunnel(domain) + client := params.Get("client") + if client == "" { + return Tunnel{}, errors.New("Invalid client parameter") + } + + tun, exists := a.db.GetTunnel(domain + "|" + client) if !exists { - return Tunnel{}, errors.New("Tunnel doesn't exist for domain") + return Tunnel{}, errors.New("Tunnel doesn't exist for this domain|client combination") } user, _ := a.db.GetUser(tokenData.Owner) @@ -389,6 +394,7 @@ func (a *Api) CreateTunnel(tokenData TokenData, params url.Values) (*Tunnel, err AuthUsername: username, AuthPassword: password, TlsTermination: tlsTerm, + Used: false, } tunnel, err := a.tunMan.RequestCreateTunnel(request) @@ -406,7 +412,12 @@ func (a *Api) DeleteTunnel(tokenData TokenData, params url.Values) error { return errors.New("Invalid domain parameter") } - tun, exists := a.db.GetTunnel(domain) + client := params.Get("client") + if client == "" { + return errors.New("Invalid client parameter") + } + + tun, exists := a.db.GetTunnel(domain + "|" + client) if !exists { return errors.New("Tunnel doesn't exist") } @@ -418,7 +429,7 @@ func (a *Api) DeleteTunnel(tokenData TokenData, params url.Values) error { } } - a.tunMan.DeleteTunnel(domain) + a.tunMan.DeleteTunnel(domain + "|" + client) return nil } diff --git a/boringproxy.go b/boringproxy.go index 410f696..92d6c5f 100644 --- a/boringproxy.go +++ b/boringproxy.go @@ -281,7 +281,7 @@ func Listen() { } } else { - tunnel, exists := db.GetTunnel(hostDomain) + tunnel, exists := db.SelectLoadBalancedTunnel(hostDomain) if !exists { errMessage := fmt.Sprintf("No tunnel attached to %s", hostDomain) w.WriteHeader(500) @@ -342,7 +342,7 @@ func (p *Server) handleConnection(clientConn net.Conn) { passConn := NewProxyConn(clientConn, clientReader) - tunnel, exists := p.db.GetTunnel(clientHello.ServerName) + tunnel, exists := p.db.SelectLoadBalancedTunnel(clientHello.ServerName) if exists && (tunnel.TlsTermination == "client" || tunnel.TlsTermination == "passthrough") || tunnel.TlsTermination == "client-tls" { p.passthroughRequest(passConn, tunnel) diff --git a/database.go b/database.go index 761dbc5..edc7b2e 100644 --- a/database.go +++ b/database.go @@ -5,6 +5,7 @@ import ( "errors" "io/ioutil" "log" + "strings" "sync" "github.com/takingnames/namedrop-go" @@ -58,6 +59,7 @@ type Tunnel struct { ClientName string `json:"client_name"` AuthUsername string `json:"auth_username"` AuthPassword string `json:"auth_password"` + Used bool } func NewDatabase() (*Database, error) { @@ -223,6 +225,44 @@ func (d *Database) GetTunnels() map[string]Tunnel { return tunnels } +func (d *Database) SelectLoadBalancedTunnel(domain string) (Tunnel, bool) { + d.mutex.Lock() + defer d.mutex.Unlock() + + tunnels := make(map[string]Tunnel) + for dom, tun := range d.Tunnels { + keys := strings.Split(dom, "|") + if domain == keys[0] { + tunnels[dom] = tun + } + } + + // Load balance + for _, t := range tunnels { + if !t.Used { + t.Used = true + log.Printf("Routing to : %s|%s", t.Domain, t.ClientName) + return t, true + } + } + + // Reset all used flags + count := 1 + for _, t := range tunnels { + t.Used = false + + if count == len(tunnels) { + t.Used = true + log.Printf("Routing to : %s|%s", t.Domain, t.ClientName) + return t, true + } + count += 1 + } + + // len(tunnels) is 0 + return Tunnel{}, false +} + func (d *Database) GetTunnel(domain string) (Tunnel, bool) { d.mutex.Lock() defer d.mutex.Unlock() diff --git a/templates/tunnel.tmpl b/templates/tunnel.tmpl index 8753e78..67ccbad 100644 --- a/templates/tunnel.tmpl +++ b/templates/tunnel.tmpl @@ -30,8 +30,8 @@
- Download Private Key - Delete + Download Private Key + Delete
{{ template "footer.tmpl" . }} diff --git a/templates/tunnels.tmpl b/templates/tunnels.tmpl index 4284314..be40ecf 100644 --- a/templates/tunnels.tmpl +++ b/templates/tunnels.tmpl @@ -4,7 +4,7 @@
Domain:
- +
Client:
@@ -15,8 +15,8 @@
{{$tunnel.ClientAddress}}:{{$tunnel.ClientPort}}
- View - Delete + View + Delete
{{ end }} @@ -36,14 +36,14 @@ {{range $domain, $tunnel:= .Tunnels}} - {{$domain}} + {{$tunnel.Domain}} {{$tunnel.ClientName}} {{$tunnel.ClientAddress}}:{{$tunnel.ClientPort}}
- View - Delete + View + Delete
diff --git a/tunnel_manager.go b/tunnel_manager.go index 11efec7..05f5bdb 100644 --- a/tunnel_manager.go +++ b/tunnel_manager.go @@ -8,14 +8,15 @@ import ( "encoding/pem" "errors" "fmt" - "github.com/caddyserver/certmagic" - "golang.org/x/crypto/ssh" "io/ioutil" "log" "os" "os/user" "strings" "sync" + + "github.com/caddyserver/certmagic" + "golang.org/x/crypto/ssh" ) type TunnelManager struct { @@ -34,9 +35,9 @@ func NewTunnelManager(config *Config, db *Database, certConfig *certmagic.Config } if config.autoCerts { - for domainName, tun := range db.GetTunnels() { + for _, tun := range db.GetTunnels() { if tun.TlsTermination == "server" { - err = certConfig.ManageSync(context.Background(), []string{domainName}) + err = certConfig.ManageSync(context.Background(), []string{tun.Domain}) if err != nil { log.Println("CertMagic error at startup") log.Println(err) @@ -84,8 +85,8 @@ func (m *TunnelManager) RequestCreateTunnel(tunReq Tunnel) (Tunnel, error) { } for _, tun := range m.db.GetTunnels() { - if tunReq.Domain == tun.Domain { - return Tunnel{}, errors.New("Tunnel domain already in use") + if tunReq.Domain == tun.Domain && tunReq.ClientName == tun.ClientName { + return Tunnel{}, errors.New("Tunnel domain and client name combination already in use") } if tunReq.TunnelPort == tun.TunnelPort { @@ -104,7 +105,7 @@ func (m *TunnelManager) RequestCreateTunnel(tunReq Tunnel) (Tunnel, error) { tunReq.Username = m.user.Username tunReq.TunnelPrivateKey = privKey - m.db.SetTunnel(tunReq.Domain, tunReq) + m.db.SetTunnel(tunReq.Domain+"|"+tunReq.ClientName, tunReq) return tunReq, nil } diff --git a/ui_handler.go b/ui_handler.go index 5230c9c..c779909 100644 --- a/ui_handler.go +++ b/ui_handler.go @@ -3,12 +3,15 @@ package boringproxy import ( "embed" "encoding/base64" + //"encoding/json" "fmt" - qrcode "github.com/skip2/go-qrcode" "html/template" "io" "net/http" + + qrcode "github.com/skip2/go-qrcode" + //"net/url" //"os" "strings" @@ -142,10 +145,17 @@ func (h *WebUiHandler) handleWebUiRequest(w http.ResponseWriter, r *http.Request } domain := r.Form["domain"][0] + if len(r.Form["client"]) != 1 { + w.WriteHeader(400) + w.Write([]byte("Invalid client parameter")) + return + } + client := r.Form["client"][0] + data := &ConfirmData{ Head: h.headHtml, - Message: fmt.Sprintf("Are you sure you want to delete %s?", domain), - ConfirmUrl: fmt.Sprintf("/delete-tunnel?domain=%s", domain), + Message: fmt.Sprintf("Are you sure you want to delete %s|%s?", domain, client), + ConfirmUrl: fmt.Sprintf("/delete-tunnel?domain=%s&client=%s", domain, client), CancelUrl: "/tunnels", } @@ -291,15 +301,17 @@ func (h *WebUiHandler) handleWebUiRequest(w http.ResponseWriter, r *http.Request parts := strings.Split(r.URL.Path, "/") - if len(parts) != 3 { + if len(parts) != 4 { w.WriteHeader(400) h.alertDialog(w, r, "Invalid path", "/tunnels") return } domain := parts[2] + client := parts[3] r.Form.Set("domain", domain) + r.Form.Set("client", client) tunnel, err := h.api.GetTunnel(tokenData, r.Form) if err != nil {