diff --git a/api.go b/api.go index d733fc0..23c5347 100644 --- a/api.go +++ b/api.go @@ -274,9 +274,14 @@ func (a *Api) GetTunnel(tokenData TokenData, params url.Values) (Tunnel, error) return Tunnel{}, errors.New("Invalid domain parameter") } - tun, exists := a.db.GetTunnel(domain) + client := params.Get("client") + if client == "" { + return Tunnel{}, errors.New("Invalid client parameter") + } + + tun, exists := a.db.GetTunnel(domain + "|" + client) if !exists { - return Tunnel{}, errors.New("Tunnel doesn't exist for domain") + return Tunnel{}, errors.New("Tunnel doesn't exist for this domain|client combination") } user, _ := a.db.GetUser(tokenData.Owner) @@ -389,6 +394,7 @@ func (a *Api) CreateTunnel(tokenData TokenData, params url.Values) (*Tunnel, err AuthUsername: username, AuthPassword: password, TlsTermination: tlsTerm, + Used: false, } tunnel, err := a.tunMan.RequestCreateTunnel(request) @@ -406,7 +412,12 @@ func (a *Api) DeleteTunnel(tokenData TokenData, params url.Values) error { return errors.New("Invalid domain parameter") } - tun, exists := a.db.GetTunnel(domain) + client := params.Get("client") + if client == "" { + return errors.New("Invalid client parameter") + } + + tun, exists := a.db.GetTunnel(domain + "|" + client) if !exists { return errors.New("Tunnel doesn't exist") } @@ -418,7 +429,7 @@ func (a *Api) DeleteTunnel(tokenData TokenData, params url.Values) error { } } - a.tunMan.DeleteTunnel(domain) + a.tunMan.DeleteTunnel(domain + "|" + client) return nil } diff --git a/boringproxy.go b/boringproxy.go index 410f696..92d6c5f 100644 --- a/boringproxy.go +++ b/boringproxy.go @@ -281,7 +281,7 @@ func Listen() { } } else { - tunnel, exists := db.GetTunnel(hostDomain) + tunnel, exists := db.SelectLoadBalancedTunnel(hostDomain) if !exists { errMessage := fmt.Sprintf("No tunnel attached to %s", hostDomain) w.WriteHeader(500) @@ -342,7 +342,7 @@ func (p *Server) handleConnection(clientConn net.Conn) { passConn := NewProxyConn(clientConn, clientReader) - tunnel, exists := p.db.GetTunnel(clientHello.ServerName) + tunnel, exists := p.db.SelectLoadBalancedTunnel(clientHello.ServerName) if exists && (tunnel.TlsTermination == "client" || tunnel.TlsTermination == "passthrough") || tunnel.TlsTermination == "client-tls" { p.passthroughRequest(passConn, tunnel) diff --git a/database.go b/database.go index 761dbc5..edc7b2e 100644 --- a/database.go +++ b/database.go @@ -5,6 +5,7 @@ import ( "errors" "io/ioutil" "log" + "strings" "sync" "github.com/takingnames/namedrop-go" @@ -58,6 +59,7 @@ type Tunnel struct { ClientName string `json:"client_name"` AuthUsername string `json:"auth_username"` AuthPassword string `json:"auth_password"` + Used bool } func NewDatabase() (*Database, error) { @@ -223,6 +225,44 @@ func (d *Database) GetTunnels() map[string]Tunnel { return tunnels } +func (d *Database) SelectLoadBalancedTunnel(domain string) (Tunnel, bool) { + d.mutex.Lock() + defer d.mutex.Unlock() + + tunnels := make(map[string]Tunnel) + for dom, tun := range d.Tunnels { + keys := strings.Split(dom, "|") + if domain == keys[0] { + tunnels[dom] = tun + } + } + + // Load balance + for _, t := range tunnels { + if !t.Used { + t.Used = true + log.Printf("Routing to : %s|%s", t.Domain, t.ClientName) + return t, true + } + } + + // Reset all used flags + count := 1 + for _, t := range tunnels { + t.Used = false + + if count == len(tunnels) { + t.Used = true + log.Printf("Routing to : %s|%s", t.Domain, t.ClientName) + return t, true + } + count += 1 + } + + // len(tunnels) is 0 + return Tunnel{}, false +} + func (d *Database) GetTunnel(domain string) (Tunnel, bool) { d.mutex.Lock() defer d.mutex.Unlock() diff --git a/templates/tunnel.tmpl b/templates/tunnel.tmpl index 8753e78..67ccbad 100644 --- a/templates/tunnel.tmpl +++ b/templates/tunnel.tmpl @@ -30,8 +30,8 @@
{{ template "footer.tmpl" . }} diff --git a/templates/tunnels.tmpl b/templates/tunnels.tmpl index 4284314..be40ecf 100644 --- a/templates/tunnels.tmpl +++ b/templates/tunnels.tmpl @@ -4,7 +4,7 @@